From 00d6a65926df1df50f3325c3d4885bb8031ba83f Mon Sep 17 00:00:00 2001 From: kapsner Date: Wed, 5 Apr 2023 21:28:40 +0200 Subject: [PATCH 001/207] feat: aggregate survshap across multiple observations --- DESCRIPTION | 7 ++- NAMESPACE | 1 + R/surv_shap.R | 89 +++++++++++++++++++++++++---- R/zzz.R | 2 + tests/testthat/test-predict_parts.R | 20 +++++++ 5 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 R/zzz.R diff --git a/DESCRIPTION b/DESCRIPTION index 93e5581c..5fc010cc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9000 +Version: 1.0.0.9001 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), @@ -18,7 +18,7 @@ Description: Survival analysis models are commonly used in medicine and other ar License: GPL (>= 3) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.1 +RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) Imports: DALEX (>= 2.2.1), @@ -26,7 +26,8 @@ Imports: kernelshap, pec, survival, - patchwork + patchwork, + data.table Suggests: censored, covr, diff --git a/NAMESPACE b/NAMESPACE index d055c56b..1dc5c536 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -68,6 +68,7 @@ export(survival_to_cumulative_hazard) export(theme_default_survex) export(theme_vertical_default_survex) export(transform_to_stepfunction) +import(data.table) import(ggplot2) import(patchwork) import(survival) diff --git a/R/surv_shap.R b/R/surv_shap.R index bbb36b36..18775705 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -24,14 +24,32 @@ surv_shap <- function(explainer, B = 25, exact = FALSE ) { + # make this code work for multiple observations + stopifnot(ifelse(!is.null(y_true), + ifelse(is.matrix(y_true), + nrow(new_observation) == nrow(y_true), + is.null(dim(y_true)) && length(y_true) == 2L), + TRUE)) + test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, 
has_survival = TRUE) - new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)] + + # make this code also work for 1-row matrix + col_index <- which(colnames(new_observation) %in% colnames(explainer$data)) + if (is.matrix(new_observation) && nrow(new_observation) == 1) { + new_observation <- as.matrix(t(new_observation[, col_index])) + } else { + new_observation <- new_observation[, col_index] + } + if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") if (!is.null(y_true)) { if (is.matrix(y_true)) { - y_true_ind <- y_true[1, 2] - y_true_time <- y_true[1, 1] + # above, we have already checked that nrows of observations are + # identical to nrows of y_true; thus we do not need to index + # the first row here + y_true_ind <- y_true[, 2] + y_true_time <- y_true[, 1] } else { y_true_ind <- y_true[2] y_true_time <- y_true[1] @@ -40,7 +58,8 @@ surv_shap <- function(explainer, res <- list() res$eval_times <- explainer$times - res$variable_values <- new_observation + # to display final object correctly, when is.matrix(new_observation) == TRUE + res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, "exact_kernel" = shap_kernel(explainer, new_observation, ...), @@ -148,14 +167,64 @@ aggregate_surv_shap <- function(survshap, method) { use_kernelshap <- function(explainer, new_observation, ...){ predfun <- function(model, newdata){ - explainer$predict_survival_function(model, newdata, times=explainer$times) + explainer$predict_survival_function( + model, + newdata, + times = explainer$times + ) } - tmp_res <- kernelshap::kernelshap(explainer$model, new_observation, bg_X = explainer$data, - pred_fun = predfun, verbose=FALSE) + tmp_res_list <- sapply( + X = as.character(seq_len(nrow(new_observation))), + FUN = function(i) { + tmp_res <- kernelshap::kernelshap( + object = explainer$model, + X = new_observation[as.integer(i), ], + 
bg_X = explainer$data, + pred_fun = predfun, + verbose = FALSE + ) + tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) + colnames(tmp_shap_values) <- colnames(tmp_res$X) + rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") + data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE) + }, + USE.NAMES = TRUE, + simplify = FALSE + ) + + shap_values <- aggregate_shap_multiple_observations( + shap_res_list = tmp_res_list, + feature_names = colnames(new_observation) + ) + + return(shap_values) +} + + +aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) { - shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) - colnames(shap_values) <- colnames(tmp_res$X) - rownames(shap_values) <- paste("t=", explainer$times, sep = "") + if (length(shap_res_list) > 1) { + + full_survshap_results <- data.table::rbindlist( + l = shap_res_list, + use.names = TRUE, + idcol = TRUE + ) + + # compute arithmetic mean for each time-point and feature across + # multiple observations + tmp_res <- full_survshap_results[ + , lapply(.SD, mean), by = "rn", .SDcols = feature_names + ] + } else { + # no aggregation required + tmp_res <- shap_res_list[[1]] + } + shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")] + # transform to data.frame to make everything compatible with + # previous code + shap_values <- data.frame(shap_values) + rownames(shap_values) <- tmp_res$rn return(shap_values) } diff --git a/R/zzz.R b/R/zzz.R new file mode 100644 index 00000000..4db9ceea --- /dev/null +++ b/R/zzz.R @@ -0,0 +1,2 @@ +#' @import data.table +NULL diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 20539a9f..43321e05 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -48,6 +48,26 @@ test_that("survshap explanations work", { }) +test_that("global survshap explanations with kernelshap work for ranger", { + veteran <- survival::veteran + + rsf_ranger <- 
ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + parts_ranger <- predict_parts( + rsf_ranger_exp, + veteran[1:40, !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap" + ) + + expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) + expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times)) + expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp$data))) + +}) + test_that("survlime explanations work", { From 3bdc61976b3b133c3e286c069b39c3b06d3a6625 Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 6 Apr 2023 10:54:00 +0200 Subject: [PATCH 002/207] test: added plot for global survshap to unit-test --- tests/testthat/test-predict_parts.R | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 43321e05..15e795f3 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -61,6 +61,7 @@ test_that("global survshap explanations with kernelshap work for ranger", { aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) + plot(parts_ranger) expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times)) From c0c779daa5c3d2875799fed15a8d8ee1c5f011fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 22 May 2023 12:45:04 +0200 Subject: [PATCH 003/207] Add `observation_aggregation_method` argument to allow for observation wise aggregation using custom functions. 
--- R/surv_shap.R | 23 ++++++++++++++--------- man/surv_shap.Rd | 5 ++++- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 18775705..8c09d94e 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -1,11 +1,12 @@ #' Helper functions for `predict_parts.R` #' #' @param explainer an explainer object - model preprocessed by the `explain()` function -#' @param new_observation a new observation for which predictions need to be explained +#' @param new_observation new observations for which predictions need to be explained #' @param ... additional parameters, passed to internal functions #' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting #' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation #' @param aggregation_method a character, either `"mean_absolute"` or `"integral"`, `"max_absolute"`, `"sum_of_squares"` +#' @param observation_aggregation_method a function, if `new_observation` contains multiple observation this function is applied to the same time point of generated shap profiles for each observation. Defaults to `mean`. 
#' #' @return A list, containing the calculated SurvSHAP(t) results in the `result` field #' @@ -20,6 +21,7 @@ surv_shap <- function(explainer, calculation_method = "kernelshap", aggregation_method = "integral", + observation_aggregation_method = mean, path = "average", B = 25, exact = FALSE @@ -29,7 +31,8 @@ surv_shap <- function(explainer, ifelse(is.matrix(y_true), nrow(new_observation) == nrow(y_true), is.null(dim(y_true)) && length(y_true) == 2L), - TRUE)) + TRUE), + is.function(observation_aggregation_method)) test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) @@ -62,8 +65,8 @@ surv_shap <- function(explainer, res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, - "exact_kernel" = shap_kernel(explainer, new_observation, ...), - "kernelshap" = use_kernelshap(explainer, new_observation, ...), + "exact_kernel" = shap_kernel(explainer, new_observation, observation_aggregation_method, ...), + "kernelshap" = use_kernelshap(explainer, new_observation, observation_aggregation_method, ...), stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) @@ -75,7 +78,7 @@ surv_shap <- function(explainer, } -shap_kernel <- function(explainer, new_observation, ...) { +shap_kernel <- function(explainer, new_observation, observation_aggregation_method, ...) 
{ timestamps <- explainer$times p <- ncol(explainer$data) @@ -164,7 +167,7 @@ aggregate_surv_shap <- function(survshap, method) { } -use_kernelshap <- function(explainer, new_observation, ...){ +use_kernelshap <- function(explainer, new_observation, observation_aggregation_method, ...){ predfun <- function(model, newdata){ explainer$predict_survival_function( @@ -195,14 +198,16 @@ use_kernelshap <- function(explainer, new_observation, ...){ shap_values <- aggregate_shap_multiple_observations( shap_res_list = tmp_res_list, - feature_names = colnames(new_observation) + feature_names = colnames(new_observation), + aggregation_function = observation_aggregation_method ) + return(shap_values) } -aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) { +aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { @@ -215,7 +220,7 @@ aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) { # compute arithmetic mean for each time-point and feature across # multiple observations tmp_res <- full_survshap_results[ - , lapply(.SD, mean), by = "rn", .SDcols = feature_names + , lapply(.SD, aggregation_function), by = "rn", .SDcols = feature_names ] } else { # no aggregation required diff --git a/man/surv_shap.Rd b/man/surv_shap.Rd index 0e8457d9..80eec3a2 100644 --- a/man/surv_shap.Rd +++ b/man/surv_shap.Rd @@ -11,6 +11,7 @@ surv_shap( y_true = NULL, calculation_method = "kernelshap", aggregation_method = "integral", + observation_aggregation_method = mean, path = "average", B = 25, exact = FALSE @@ -19,7 +20,7 @@ surv_shap( \arguments{ \item{explainer}{an explainer object - model preprocessed by the \code{explain()} function} -\item{new_observation}{a new observation for which predictions need to be explained} +\item{new_observation}{new observations for which predictions need to be explained} \item{...}{additional parameters, passed to internal functions} 
@@ -28,6 +29,8 @@ surv_shap( \item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} \item{aggregation_method}{a character, either \code{"mean_absolute"} or \code{"integral"}, \code{"max_absolute"}, \code{"sum_of_squares"}} + +\item{observation_aggregation_method}{a function, if \code{new_observation} contains multiple observation this function is applied to the same time point of generated shap profiles for each observation. Defaults to \code{mean}.} } \value{ A list, containing the calculated SurvSHAP(t) results in the \code{result} field From c6a1f1d4e48c8f31787c8ef8c4f536a199cd4ed4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 22 May 2023 14:54:14 +0200 Subject: [PATCH 004/207] Get rid of `data.table` and make `"exact_kernel"` shap method work with multiple observations --- DESCRIPTION | 1 - NAMESPACE | 1 - R/surv_shap.R | 64 ++++++++++++++++++++++++++++++++++++--------------- R/zzz.R | 1 - 4 files changed, 46 insertions(+), 21 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5fc010cc..b1e0c751 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,7 +27,6 @@ Imports: pec, survival, patchwork, - data.table Suggests: censored, covr, diff --git a/NAMESPACE b/NAMESPACE index 1dc5c536..d055c56b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -68,7 +68,6 @@ export(survival_to_cumulative_hazard) export(theme_default_survex) export(theme_vertical_default_survex) export(transform_to_stepfunction) -import(data.table) import(ggplot2) import(patchwork) import(survival) diff --git a/R/surv_shap.R b/R/surv_shap.R index 8c09d94e..13403733 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -45,7 +45,6 @@ surv_shap <- function(explainer, } if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") - if (!is.null(y_true)) { if 
(is.matrix(y_true)) { # above, we have already checked that nrows of observations are @@ -63,9 +62,8 @@ surv_shap <- function(explainer, res$eval_times <- explainer$times # to display final object correctly, when is.matrix(new_observation) == TRUE res$variable_values <- as.data.frame(new_observation) - res$result <- switch(calculation_method, - "exact_kernel" = shap_kernel(explainer, new_observation, observation_aggregation_method, ...), + "exact_kernel" = use_exact_shap(explainer, new_observation, observation_aggregation_method, ...), "kernelshap" = use_kernelshap(explainer, new_observation, observation_aggregation_method, ...), stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")) @@ -77,8 +75,32 @@ surv_shap <- function(explainer, res } +use_exact_shap <- function(explainer, new_observation, observation_aggregation_method, ...){ + + tmp_res_list <- sapply( + X = as.character(seq_len(nrow(new_observation))), + FUN = function(i) { + as.data.frame(shap_kernel(explainer, new_observation[as.integer(i),], ...)) + }, + USE.NAMES = TRUE, + simplify = FALSE + ) + + print(tmp_res_list) + shap_values <- aggregate_shap_multiple_observations( + shap_res_list = tmp_res_list, + feature_names = colnames(new_observation), + aggregation_function = observation_aggregation_method + ) + + + return(shap_values) + +} + + +shap_kernel <- function(explainer, new_observation, ...) { -shap_kernel <- function(explainer, new_observation, observation_aggregation_method, ...) 
{ timestamps <- explainer$times p <- ncol(explainer$data) @@ -91,10 +113,13 @@ shap_kernel <- function(explainer, new_observation, observation_aggregation_meth permutations <- expand.grid(rep(list(0:1), p)) kernel_weights <- generate_shap_kernel_weights(permutations, p) - shap_values <- calculate_shap_values(explainer, explainer$model, baseline_sf, explainer$data, permutations, kernel_weights, new_observation, timestamps) + shap_values <- calculate_shap_values(explainer, explainer$model, baseline_sf, as.data.frame(explainer$data), permutations, kernel_weights, as.data.frame(new_observation), timestamps) + + shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data)) colnames(shap_values) <- paste("t=", timestamps, sep = "") + return (t(shap_values)) } @@ -131,7 +156,6 @@ calculate_shap_values <- function(explainer, model, avg_survival_function, data, make_prediction_for_simplified_input <- function(explainer, model, data, simplified_inputs, new_observation, timestamps) { - preds <- apply(simplified_inputs, 1, function(row) { row <- as.logical(row) @@ -190,7 +214,7 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) colnames(tmp_shap_values) <- colnames(tmp_res$X) rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") - data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE) + tmp_shap_values }, USE.NAMES = TRUE, simplify = FALSE @@ -210,26 +234,30 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { + shap_res_list <- lapply(shap_res_list, function(x){ + x$rn <- rownames(x) + x + }) - full_survshap_results <- data.table::rbindlist( - l = shap_res_list, - use.names = TRUE, - idcol = TRUE - ) + full_survshap_results <- do.call("rbind", shap_res_list) + 
rownames(full_survshap_results) <- NULL # compute arithmetic mean for each time-point and feature across # multiple observations - tmp_res <- full_survshap_results[ - , lapply(.SD, aggregation_function), by = "rn", .SDcols = feature_names - ] + + tmp_res <- aggregate(full_survshap_results[, !colnames(full_survshap_results) %in% c("rn")], + by = list(full_survshap_results$rn), + FUN = aggregation_function) + rownames(tmp_res) <- tmp_res$Group.1 + ordering <- order(as.numeric(substring(rownames(tmp_res),3))) + + tmp_res <- tmp_res[ordering, !colnames(tmp_res) %in% c("rn","Group.1")] } else { # no aggregation required tmp_res <- shap_res_list[[1]] } - shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")] + shap_values <- tmp_res # transform to data.frame to make everything compatible with # previous code - shap_values <- data.frame(shap_values) - rownames(shap_values) <- tmp_res$rn return(shap_values) } diff --git a/R/zzz.R b/R/zzz.R index 4db9ceea..7951defe 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,2 +1 @@ -#' @import data.table NULL From 240d0a8e12c273ce575554b7dd2f1cb66fc25568 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 22 May 2023 15:06:38 +0200 Subject: [PATCH 005/207] Add additional fields in the result object for plotting `aggregated_surv_shap` --- R/surv_shap.R | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 13403733..8ce87571 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -71,8 +71,16 @@ surv_shap <- function(explainer, res$aggregate <- aggregate_surv_shap(res, aggregation_method) - class(res) <- "surv_shap" - res + if(nrow(new_observation) > 1){ + class(res) <- c("aggregated_surv_shap", "surv_shap") + res$observation_aggregation_function <- observation_aggregation_method + res$aggregation_method <- aggregation_method + res$n_observations <- nrow(new_observation) + } else { + class(res) <- "surv_shap" + } + + return(res) } use_exact_shap 
<- function(explainer, new_observation, observation_aggregation_method, ...){ From e1cfcccde9cb4ca5ebb28fcc77fc4bda8146feed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 22 May 2023 15:57:11 +0200 Subject: [PATCH 006/207] Add plotting of `aggregated_surv_shap` --- NAMESPACE | 1 + R/plot_surv_shap.R | 66 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index d055c56b..9f221416 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,7 @@ S3method(model_performance,default) S3method(model_performance,surv_explainer) S3method(model_profile,default) S3method(model_profile,surv_explainer) +S3method(plot,aggregated_surv_shap) S3method(plot,feature_importance_explainer) S3method(plot,model_parts_survival) S3method(plot,model_performance_survival) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index b418b808..82b9cff6 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -89,3 +89,69 @@ plot.surv_shap <- function(x, return(return_plot) } + + + +#'@export +plot.aggregated_surv_shap <- function(x, + ..., + title = "Aggregated SurvSHAP(t)", + subtitle = "default", + xlab_left = "Importance", + ylab_right = "Aggregated SurvSHAP(t) value", + max_vars = 7, + colors = NULL, + rug = "all", + rug_colors = c("#dd0000", "#222222")){ + + right_plot <- plot.surv_shap(x = x, + ... 
= ..., + title = NULL, + subtitle = NULL, + max_vars = max_vars, + colors = colors, + rug = rug, + rug_colors = rug_colors) + + labs(y = ylab_right) + + + dfl <- c(list(x), list(...)) + + df_list <- lapply(dfl, function(x) { + label <- attr(x, "label") + values <- x$aggregate + vars <- names(x$aggregate) + df <- data.frame(label, values, vars) + rownames(df) <- NULL + df + }) + + long_df <- do.call("rbind", df_list) + long_df <- long_df[order(long_df$values, decreasing = TRUE),] + + label <- unique(long_df$label) + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0("created for the ", paste(label, collapse = ", "), " model") + } + + left_plot <- with(long_df, { + ggplot(long_df, aes(x = values, y = reorder(vars, values))) + + geom_col(fill = "#46bac2") + + theme_default_survex() + + facet_wrap(~label, ncol = 1, scales = "free_y") + + labs(x = xlab_left) + + theme(axis.title.y = element_blank()) + + }) + + + pl <- left_plot + + right_plot + + patchwork::plot_layout(widths = c(3,5), guides = "collect") + + patchwork::plot_annotation(title = title, subtitle = subtitle) & + theme(legend.position = "top", + plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), + plot.subtitle = element_text(color = "#371ea3", hjust = 0),) + + pl +} From d7887d5c843265327f8d2005654e10c5e959fd7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Wed, 12 Jul 2023 15:39:41 +0200 Subject: [PATCH 007/207] Add `model_survshap()` function and modify survshap to return all calculated shaps. 
--- NAMESPACE | 2 ++ R/model_survshap.R | 39 +++++++++++++++++++++++++++++++ R/surv_shap.R | 58 +++++++++++++++------------------------------- man/surv_shap.Rd | 6 +---- 4 files changed, 61 insertions(+), 44 deletions(-) create mode 100644 R/model_survshap.R diff --git a/NAMESPACE b/NAMESPACE index 9f221416..79800b51 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,7 @@ S3method(model_performance,default) S3method(model_performance,surv_explainer) S3method(model_profile,default) S3method(model_profile,surv_explainer) +S3method(model_survshap,surv_explainer) S3method(plot,aggregated_surv_shap) S3method(plot,feature_importance_explainer) S3method(plot,model_parts_survival) @@ -60,6 +61,7 @@ export(loss_one_minus_integrated_cd_auc) export(model_parts) export(model_performance) export(model_profile) +export(model_survshap) export(predict_parts) export(predict_profile) export(risk_from_chf) diff --git a/R/model_survshap.R b/R/model_survshap.R new file mode 100644 index 00000000..af9814e4 --- /dev/null +++ b/R/model_survshap.R @@ -0,0 +1,39 @@ +#'@export +model_survshap <- + function(explainer, ...) 
+ UseMethod("model_survshap", explainer) + +#'@export +model_survshap.surv_explainer <- function(explainer, + calculation_method = "kernelshap", + aggregation_method = "integral", + ..., + N = NULL) { + + test_explainer( + explainer, + has_data = TRUE, + has_y = TRUE, + has_survival = TRUE, + function_name = "model_survshap" + ) + observations <- explainer$data + + if (!is.null(N)) { + selected_observations <- sample(1:nrow(observations), N) + observations <- selected_observations + } + + shap_values <- surv_shap( + explainer = explainer, + new_observation = observations, + calculation_method = calculation_method, + aggregation_method = aggregation_method + ) + + attr(shap_values, "label") <- explainer$label + shap_values$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] + shap_values$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] + return(shap_values) + +} diff --git a/R/surv_shap.R b/R/surv_shap.R index 8ce87571..aa70f5a1 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -20,19 +20,14 @@ surv_shap <- function(explainer, y_true = NULL, calculation_method = "kernelshap", - aggregation_method = "integral", - observation_aggregation_method = mean, - path = "average", - B = 25, - exact = FALSE -) { + aggregation_method = "integral") +{ # make this code work for multiple observations stopifnot(ifelse(!is.null(y_true), ifelse(is.matrix(y_true), nrow(new_observation) == nrow(y_true), is.null(dim(y_true)) && length(y_true) == 2L), - TRUE), - is.function(observation_aggregation_method)) + TRUE)) test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) @@ -63,21 +58,22 @@ surv_shap <- function(explainer, # to display final object correctly, when is.matrix(new_observation) == TRUE res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, - "exact_kernel" = use_exact_shap(explainer, new_observation, observation_aggregation_method, ...), - "kernelshap" = 
use_kernelshap(explainer, new_observation, observation_aggregation_method, ...), + "exact_kernel" = use_exact_shap(explainer, new_observation, ...), + "kernelshap" = use_kernelshap(explainer, new_observation, ...), stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) - res$aggregate <- aggregate_surv_shap(res, aggregation_method) + res$aggregate <- lapply(res$result, aggregate_surv_shap, method = aggregation_method, times = res$eval_times) if(nrow(new_observation) > 1){ - class(res) <- c("aggregated_surv_shap", "surv_shap") - res$observation_aggregation_function <- observation_aggregation_method - res$aggregation_method <- aggregation_method + class(res) <- "aggregated_surv_shap" + # res$aggregation_method <- aggregation_method res$n_observations <- nrow(new_observation) } else { class(res) <- "surv_shap" + res$result <- res$result[[1]] + res$aggregate <- res$aggregate[[1]] } return(res) @@ -85,7 +81,7 @@ surv_shap <- function(explainer, use_exact_shap <- function(explainer, new_observation, observation_aggregation_method, ...){ - tmp_res_list <- sapply( + shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { as.data.frame(shap_kernel(explainer, new_observation[as.integer(i),], ...)) @@ -94,14 +90,6 @@ use_exact_shap <- function(explainer, new_observation, observation_aggregation_m simplify = FALSE ) - print(tmp_res_list) - shap_values <- aggregate_shap_multiple_observations( - shap_res_list = tmp_res_list, - feature_names = colnames(new_observation), - aggregation_function = observation_aggregation_method - ) - - return(shap_values) } @@ -180,16 +168,14 @@ make_prediction_for_simplified_input <- function(explainer, model, data, simplif } -aggregate_surv_shap <- function(survshap, method) { - +aggregate_surv_shap <- function(survshap, times, method) { switch(method, - "sum_of_squares" = 
return(apply(survshap$result, 2, function(x) sum(x^2))), - "mean_absolute" = return(apply(survshap$result, 2, function(x) mean(abs(x)))), - "max_absolute" = return(apply(survshap$result, 2, function(x) max(abs(x)))), - "integral" = return(apply(survshap$result, 2, function(x) { + "sum_of_squares" = return(apply(survshap, 2, function(x) sum(x^2))), + "mean_absolute" = return(apply(survshap, 2, function(x) mean(abs(x)))), + "max_absolute" = return(apply(survshap, 2, function(x) max(abs(x)))), + "integral" = return(apply(survshap, 2, function(x) { x <- abs(x) names(x) <- NULL - times <- survshap$eval_times n <- length(x) i <- (x[1:(n - 1)] + x[2:n]) * diff(times) / 2 sum(i) / (max(times) - min(times)) @@ -209,7 +195,7 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m ) } - tmp_res_list <- sapply( + shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { tmp_res <- kernelshap::kernelshap( @@ -228,17 +214,11 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m simplify = FALSE ) - shap_values <- aggregate_shap_multiple_observations( - shap_res_list = tmp_res_list, - feature_names = colnames(new_observation), - aggregation_function = observation_aggregation_method - ) - - return(shap_values) -} +} +#'@internal aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { diff --git a/man/surv_shap.Rd b/man/surv_shap.Rd index 80eec3a2..7bfc4060 100644 --- a/man/surv_shap.Rd +++ b/man/surv_shap.Rd @@ -10,11 +10,7 @@ surv_shap( ..., y_true = NULL, calculation_method = "kernelshap", - aggregation_method = "integral", - observation_aggregation_method = mean, - path = "average", - B = 25, - exact = FALSE + aggregation_method = "integral" ) } \arguments{ From 0015b7f2e112f54a0da002e6aa6b9e0c71e6e7e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Wed, 12 Jul 2023 15:40:31 +0200 
Subject: [PATCH 008/207] Modlify aggregated survshap plotting function. --- R/plot_surv_shap.R | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 82b9cff6..9d682089 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -95,7 +95,7 @@ plot.surv_shap <- function(x, #'@export plot.aggregated_surv_shap <- function(x, ..., - title = "Aggregated SurvSHAP(t)", + title = "Feature importance according to aggregated |SurvSHAP(t)|", subtitle = "default", xlab_left = "Importance", ylab_right = "Aggregated SurvSHAP(t) value", @@ -104,6 +104,11 @@ plot.aggregated_surv_shap <- function(x, rug = "all", rug_colors = c("#dd0000", "#222222")){ + + old_x <- x + x$result <- aggregate_shap_multiple_observations(x$result, colnames(x$result[[1]]), function(x) mean(abs(x))) + x$aggregate <- apply(do.call(rbind, x$aggregate), 2, function(x) mean(abs(x))) + right_plot <- plot.surv_shap(x = x, ... = ..., title = NULL, From d0f325ca48db226af9cc0b67aaa1c729c006ec92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 17 Jul 2023 15:37:09 +0200 Subject: [PATCH 009/207] Begin swarmplot --- DESCRIPTION | 1 + NAMESPACE | 1 + R/plot_surv_shap.R | 106 ++++++++++++++++++++++++++++++++++++++++----- R/surv_shap.R | 2 +- 4 files changed, 99 insertions(+), 11 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index b1e0c751..d57c9b38 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,6 +27,7 @@ Imports: pec, survival, patchwork, + ggbeeswarm Suggests: censored, covr, diff --git a/NAMESPACE b/NAMESPACE index 79800b51..64b4b60a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -79,6 +79,7 @@ importFrom(DALEX,theme_drwhy) importFrom(DALEX,theme_drwhy_vertical) importFrom(DALEX,theme_ema) importFrom(DALEX,theme_ema_vertical) +importFrom(ggbeeswarm,geom_beeswarm) importFrom(stats,aggregate) importFrom(stats,as.formula) importFrom(stats,median) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 9d682089..a088e47c 
100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -94,6 +94,7 @@ plot.surv_shap <- function(x, #'@export plot.aggregated_surv_shap <- function(x, + kind = c("importance", "swarm"), ..., title = "Feature importance according to aggregated |SurvSHAP(t)|", subtitle = "default", @@ -105,7 +106,88 @@ plot.aggregated_surv_shap <- function(x, rug_colors = c("#dd0000", "#222222")){ - old_x <- x + if (kind == "importance"){ + + pl <- plot_shap_global_importance(x = x, + ... = ..., + title = title, + subtitle = subtitle, + xlab_left = xlab_left, + ylab_right = ylab_right, + max_vars = max_vars, + colors = colors, + rug = rug, + rug_colors = rug_colors) + + + } else if (kind == "swarm") { + + pl <- plot_shap_global_swarm(x = x, + ... = ..., + title = title, + subtitle = subtitle, + max_vars = max_vars, + colors = colors, + rug = rug, + rug_colors = rug_colors) + + } else { + stop("Unknown `kind` argument. Please use one of 'importance'") + } + + + return(pl) + + +} + +#'@importFrom ggbeeswarm geom_beeswarm +plot_shap_global_swarm <- function(x, + ..., + title = "Aaaaa", + subtitle = "default", + max_vars = 7, + colors = NULL, + rug = "all", + rug_colors = c("#dd0000", "#222222")){ + + # print(x$variable_values) + + df <- as.data.frame(do.call(rbind, x$aggregate)) + df <- stack(df) + original_values <- as.data.frame(x$variable_values) + # print(nrow(df)) + # print(nrow(original_values)) + print(apply(original_values, 2, function(x) any(is.na((as.numeric(x)))))) + colnames(df) <- c("shap_value", "variable") + cbind(df, label = "TODO label") + + + + + # print(df) + + ggplot(data = df, aes(y = `shap_value`, x = variable)) + + coord_flip()+ + ggbeeswarm::geom_beeswarm() + + + + +} + + +plot_shap_global_importance <- function(x, + ..., + title = "Feature importance according to aggregated |SurvSHAP(t)|", + subtitle = "default", + xlab_left = "Importance", + ylab_right = "Aggregated SurvSHAP(t) value", + max_vars = 7, + colors = NULL, + rug = "all", + rug_colors = 
c("#dd0000", "#222222")){ + x$result <- aggregate_shap_multiple_observations(x$result, colnames(x$result[[1]]), function(x) mean(abs(x))) x$aggregate <- apply(do.call(rbind, x$aggregate), 2, function(x) mean(abs(x))) @@ -117,7 +199,7 @@ plot.aggregated_surv_shap <- function(x, colors = colors, rug = rug, rug_colors = rug_colors) + - labs(y = ylab_right) + labs(y = ylab_right) dfl <- c(list(x), list(...)) @@ -151,12 +233,16 @@ plot.aggregated_surv_shap <- function(x, pl <- left_plot + - right_plot + - patchwork::plot_layout(widths = c(3,5), guides = "collect") + - patchwork::plot_annotation(title = title, subtitle = subtitle) & - theme(legend.position = "top", - plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), - plot.subtitle = element_text(color = "#371ea3", hjust = 0),) - - pl + right_plot + + patchwork::plot_layout(widths = c(3,5), guides = "collect") + + patchwork::plot_annotation(title = title, subtitle = subtitle) & + theme(legend.position = "top", + plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), + plot.subtitle = element_text(color = "#371ea3", hjust = 0),) + + return(pl) } + + + + diff --git a/R/surv_shap.R b/R/surv_shap.R index aa70f5a1..ec539f99 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -218,7 +218,7 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m } -#'@internal +#'@keywords internal aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { From 104d087ea63c0a3f6ba3d73542f672334349fa8e Mon Sep 17 00:00:00 2001 From: kapsner Date: Mon, 24 Jul 2023 17:36:24 +0200 Subject: [PATCH 010/207] refactor: working on implementation of global survshap - make model_survshap working with new data that was not use to build explainer - documentation of model_survshap - minor error handling enhancements - added unit-tests for global survshap - removed global survshap tests from test-predict-parts addresses #75 --- 
DESCRIPTION | 2 +- R/model_survshap.R | 54 +++++++++++++++++++++++++--- R/surv_shap.R | 45 ++++++++++++----------- man/model_survshap.surv_explainer.Rd | 45 +++++++++++++++++++++++ tests/testthat/test-model_survshap.R | 42 ++++++++++++++++++++++ tests/testthat/test-predict_parts.R | 22 ------------ 6 files changed, 163 insertions(+), 47 deletions(-) create mode 100644 man/model_survshap.surv_explainer.Rd create mode 100644 tests/testthat/test-model_survshap.R diff --git a/DESCRIPTION b/DESCRIPTION index b1e0c751..b75479af 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9001 +Version: 1.0.0.9002 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), diff --git a/R/model_survshap.R b/R/model_survshap.R index af9814e4..4c80edca 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -1,15 +1,50 @@ -#'@export +#' Global SHAP Values +#' +#' This function computes global SHAP values. +#' +#' @param N A positive integer indicating the number of observations that should be used to compute global SHAP values. +#' @inheritParams surv_shap +#' +#' @details +#' If specifying `y_true`, also `new_observation` must be specified. +#' Using the argument `new_observation`, global SHAP values are computed for the provided data. Otherwise, +#' global SHAP values are computed for the data, the `explainer` was trained with. +#' +#' +#' @return An object of class `aggregated_surv_shap` containing the computed global SHAP values. +#' +#' @rdname model_survshap.surv_explainer +#' @export model_survshap <- function(explainer, ...) 
UseMethod("model_survshap", explainer) -#'@export +#' @rdname model_survshap.surv_explainer +#' @export model_survshap.surv_explainer <- function(explainer, calculation_method = "kernelshap", aggregation_method = "integral", + new_observation = NULL, + y_true = NULL, ..., N = NULL) { + stopifnot( + "`N` must be a positive integer" = ifelse( + !is.null(N), + is.integer(N) && N > 0L, + TRUE + ), + "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( + !is.null(y_true), + ifelse( + is.matrix(y_true), + nrow(new_observation) == nrow(y_true), + is.null(dim(y_true)) && length(y_true) == 2L + ), + TRUE + )) + test_explainer( explainer, has_data = TRUE, @@ -17,16 +52,27 @@ model_survshap.surv_explainer <- function(explainer, has_survival = TRUE, function_name = "model_survshap" ) - observations <- explainer$data + if (!is.null(new_observation)) { + if (length(setdiff(colnames(new_observation), colnames(explainer$data))) == 0) { + observations <- new_observation + y_true <- y_true + } else { + stop("`new_observation` must have the same column names as the data, the `explainer` was trained with.") + } + } else { + observations <- explainer$data + y_true <- NULL + } if (!is.null(N)) { selected_observations <- sample(1:nrow(observations), N) - observations <- selected_observations + observations <- observations[selected_observations, ] } shap_values <- surv_shap( explainer = explainer, new_observation = observations, + y_true = y_true, calculation_method = calculation_method, aggregation_method = aggregation_method ) diff --git a/R/surv_shap.R b/R/surv_shap.R index aa70f5a1..2a33824a 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -18,16 +18,20 @@ surv_shap <- function(explainer, new_observation, ..., y_true = NULL, - calculation_method = "kernelshap", aggregation_method = "integral") { # make this code work for multiple observations - stopifnot(ifelse(!is.null(y_true), - ifelse(is.matrix(y_true), - 
nrow(new_observation) == nrow(y_true), - is.null(dim(y_true)) && length(y_true) == 2L), - TRUE)) + stopifnot( + "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( + !is.null(y_true), + ifelse( + is.matrix(y_true), + nrow(new_observation) == nrow(y_true), + is.null(dim(y_true)) && length(y_true) == 2L + ), + TRUE + )) test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) @@ -169,19 +173,20 @@ make_prediction_for_simplified_input <- function(explainer, model, data, simplif } aggregate_surv_shap <- function(survshap, times, method) { - switch(method, - "sum_of_squares" = return(apply(survshap, 2, function(x) sum(x^2))), - "mean_absolute" = return(apply(survshap, 2, function(x) mean(abs(x)))), - "max_absolute" = return(apply(survshap, 2, function(x) max(abs(x)))), - "integral" = return(apply(survshap, 2, function(x) { - x <- abs(x) - names(x) <- NULL - n <- length(x) - i <- (x[1:(n - 1)] + x[2:n]) * diff(times) / 2 - sum(i) / (max(times) - min(times)) - })), - stop("aggregation_method has to be one of `sum_of_squares`, `mean_absolute`, `max_absolute` or `integral`")) - + switch( + method, + "sum_of_squares" = return(apply(survshap, 2, function(x) sum(x^2))), + "mean_absolute" = return(apply(survshap, 2, function(x) mean(abs(x)))), + "max_absolute" = return(apply(survshap, 2, function(x) max(abs(x)))), + "integral" = return(apply(survshap, 2, function(x) { + x <- abs(x) + names(x) <- NULL + n <- length(x) + i <- (x[1:(n - 1)] + x[2:n]) * diff(times) / 2 + sum(i) / (max(times) - min(times)) + })), + stop("aggregation_method has to be one of `sum_of_squares`, `mean_absolute`, `max_absolute` or `integral`") + ) } @@ -222,7 +227,7 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { - shap_res_list <- 
lapply(shap_res_list, function(x){ + shap_res_list <- lapply(shap_res_list, function(x) { x$rn <- rownames(x) x }) diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd new file mode 100644 index 00000000..07f4cb12 --- /dev/null +++ b/man/model_survshap.surv_explainer.Rd @@ -0,0 +1,45 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model_survshap.R +\name{model_survshap} +\alias{model_survshap} +\alias{model_survshap.surv_explainer} +\title{Global SHAP Values} +\usage{ +model_survshap(explainer, ...) + +\method{model_survshap}{surv_explainer}( + explainer, + calculation_method = "kernelshap", + aggregation_method = "integral", + new_observation = NULL, + y_true = NULL, + ..., + N = NULL +) +} +\arguments{ +\item{explainer}{an explainer object - model preprocessed by the \code{explain()} function} + +\item{...}{additional parameters, passed to internal functions} + +\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} + +\item{aggregation_method}{a character, either \code{"mean_absolute"} or \code{"integral"}, \code{"max_absolute"}, \code{"sum_of_squares"}} + +\item{new_observation}{new observations for which predictions need to be explained} + +\item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} + +\item{N}{A positive integer indicating the number of observations that should be used to compute global SHAP values.} +} +\value{ +An object of class \code{aggregated_surv_shap} containing the computed global SHAP values. +} +\description{ +This function computes global SHAP values. +} +\details{ +If specifying \code{y_true}, also \code{new_observation} must be specified. 
+Using the argument \code{new_observation}, global SHAP values are computed for the provided data. Otherwise, +global SHAP values are computed for the data, the \code{explainer} was trained with. +} diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R new file mode 100644 index 00000000..d33ae3da --- /dev/null +++ b/tests/testthat/test-model_survshap.R @@ -0,0 +1,42 @@ + +veteran <- survival::veteran +rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) +rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + +test_that("global survshap explanations with kernelshap work for ranger, using new data", { + + ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + new_observation = veteran[1:40, !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap" + ) + plot(ranger_global_survshap) + + expect_s3_class(ranger_global_survshap, c("aggregated_surv_shap", "surv_shap")) + expect_equal(length(ranger_global_survshap$eval_times), length(rsf_ranger_exp$times)) + expect_true(all(names(ranger_global_survshap$variable_values) == colnames(rsf_ranger_exp$data))) + +}) + +test_that("global survshap explanations with kernelshap work for ranger, using explainer data", { + + # using all explainer data + ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + aggregation_method = "mean_absolute", + calculation_method = "kernelshap" + ) + plot(ranger_global_survshap) + + # using only 6 observations + ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + aggregation_method = "mean_absolute", + calculation_method = "kernelshap", + N = 6L + ) + plot(ranger_global_survshap) + +}) diff --git 
a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 15e795f3..c47c8a77 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -48,28 +48,6 @@ test_that("survshap explanations work", { }) -test_that("global survshap explanations with kernelshap work for ranger", { - veteran <- survival::veteran - - rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) - rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) - - parts_ranger <- predict_parts( - rsf_ranger_exp, - veteran[1:40, !colnames(veteran) %in% c("time", "status")], - y_true = Surv(veteran$time[1:40], veteran$status[1:40]), - aggregation_method = "mean_absolute", - calculation_method = "kernelshap" - ) - plot(parts_ranger) - - expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) - expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times)) - expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp$data))) - -}) - - test_that("survlime explanations work", { veteran <- survival::veteran From 72622bf64b025c098c3f3ca6b7dfe82b9a59f7e3 Mon Sep 17 00:00:00 2001 From: kapsner Date: Tue, 25 Jul 2023 14:37:22 +0200 Subject: [PATCH 011/207] docs: added working example to model_survshap and also updated unit-test to allow specifying both, new_observation and N --- R/model_survshap.R | 36 +++++++++++++++++++++++++++- man/model_survshap.surv_explainer.Rd | 30 +++++++++++++++++++++++ tests/testthat/test-model_survshap.R | 10 ++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) diff --git a/R/model_survshap.R b/R/model_survshap.R index 4c80edca..0022158c 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -13,6 +13,35 @@ #' #' @return An object of class `aggregated_surv_shap` containing the computed global SHAP values. 
#' +#' @examples +#' \donttest{ +#' veteran <- survival::veteran +#' rsf_ranger <- ranger::ranger( +#' survival::Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 +#' ) +#' rsf_ranger_exp <- explain( +#' rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = Surv(veteran$time, veteran$status), +#' verbose = FALSE +#' ) +#' +#' ranger_global_survshap <- model_survshap( +#' explainer = rsf_ranger_exp, +#' new_observation = veteran[1:40, !colnames(veteran) %in% c("time", "status")], +#' y_true = Surv(veteran$time[1:40], veteran$status[1:40]), +#' aggregation_method = "mean_absolute", +#' calculation_method = "kernelshap", +#' N = 5L +#' ) +#' plot(ranger_global_survshap) +#' } +#' #' @rdname model_survshap.surv_explainer #' @export model_survshap <- @@ -32,7 +61,7 @@ model_survshap.surv_explainer <- function(explainer, stopifnot( "`N` must be a positive integer" = ifelse( !is.null(N), - is.integer(N) && N > 0L, + is.integer(as.integer(N)) && N > 0L, TRUE ), "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( @@ -45,6 +74,8 @@ model_survshap.surv_explainer <- function(explainer, TRUE )) + N <- as.integer(N) + test_explainer( explainer, has_data = TRUE, @@ -67,6 +98,9 @@ model_survshap.surv_explainer <- function(explainer, if (!is.null(N)) { selected_observations <- sample(1:nrow(observations), N) observations <- observations[selected_observations, ] + if (!is.null(y_true)) { + y_true <- y_true[selected_observations, ] + } } shap_values <- surv_shap( diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 07f4cb12..dc7a3d6e 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -43,3 +43,33 @@ If specifying \code{y_true}, also \code{new_observation} must be specified. 
Using the argument \code{new_observation}, global SHAP values are computed for the provided data. Otherwise, global SHAP values are computed for the data, the \code{explainer} was trained with. } +\examples{ +\donttest{ +veteran <- survival::veteran +rsf_ranger <- ranger::ranger( + survival::Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 +) +rsf_ranger_exp <- explain( + rsf_ranger, + data = veteran[, -c(3, 4)], + y = Surv(veteran$time, veteran$status), + verbose = FALSE +) + +ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + new_observation = veteran[1:40, !colnames(veteran) \%in\% c("time", "status")], + y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap", + N = 5L +) +plot(ranger_global_survshap) +} + +} diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index d33ae3da..77561f5c 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -18,6 +18,16 @@ test_that("global survshap explanations with kernelshap work for ranger, using n expect_equal(length(ranger_global_survshap$eval_times), length(rsf_ranger_exp$times)) expect_true(all(names(ranger_global_survshap$variable_values) == colnames(rsf_ranger_exp$data))) + # test functioning of special case, when providing new observation and specify N + ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + new_observation = veteran[1:40, !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap", + N = 6 + ) + }) test_that("global survshap explanations with kernelshap work for ranger, using explainer data", { From 3a5aadded01e7e29e71f0dabad9cc3476ae80dd6 Mon Sep 17 00:00:00 2001 From: kapsner Date: Tue, 25 Jul 2023 14:46:33 +0200 
Subject: [PATCH 012/207] feat: added n over which shap were aggregated to plot subtitle --- R/plot_surv_shap.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 9d682089..02475b6d 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -135,9 +135,10 @@ plot.aggregated_surv_shap <- function(x, long_df <- long_df[order(long_df$values, decreasing = TRUE),] label <- unique(long_df$label) - if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0("created for the ", paste(label, collapse = ", "), " model") - } + subtitle <- paste0( + "created for the ", paste(label, collapse = ", "), " model ", + "(n=", x$n_observations, ")" + ) left_plot <- with(long_df, { ggplot(long_df, aes(x = values, y = reorder(vars, values))) + From b8dab05f14a08f4da577a6cb3b9ebb94b2d760b7 Mon Sep 17 00:00:00 2001 From: kapsner Date: Tue, 25 Jul 2023 14:48:32 +0200 Subject: [PATCH 013/207] fix: reintroduced missing if-statement for plot title --- R/plot_surv_shap.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 02475b6d..56c3a414 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -135,10 +135,12 @@ plot.aggregated_surv_shap <- function(x, long_df <- long_df[order(long_df$values, decreasing = TRUE),] label <- unique(long_df$label) - subtitle <- paste0( - "created for the ", paste(label, collapse = ", "), " model ", - "(n=", x$n_observations, ")" - ) + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0( + "created for the ", paste(label, collapse = ", "), " model ", + "(n=", x$n_observations, ")" + ) + } left_plot <- with(long_df, { ggplot(long_df, aes(x = values, y = reorder(vars, values))) + From c50c958a19f6bbe1768d16d66347db8a0f2083f9 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 16:49:57 +0200 Subject: [PATCH 014/207] create calculate_integral function as utils --- R/metrics.R | 27 
++------------------------- R/utils.R | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/R/metrics.R b/R/metrics.R index 3fd63a02..9cdb9e3c 100644 --- a/R/metrics.R +++ b/R/metrics.R @@ -27,7 +27,7 @@ loss_integrate <- function(loss_function, ..., normalization = NULL , max_quanti times <- times[quantile_mask] surv <- surv[,quantile_mask] - loss_values <- loss_function(y_true = y_true, risk = risk, surv = surv,times = times) + loss_values <- loss_function(y_true = y_true, risk = risk, surv = surv, times = times) na_mask <- (!is.na(loss_values)) @@ -35,30 +35,7 @@ loss_integrate <- function(loss_function, ..., normalization = NULL , max_quanti loss_values <- loss_values[na_mask] surv <- surv[na_mask] - n <- length(loss_values) - # integral using trapezoid method - - if (is.null(normalization)){ - tmp <- (loss_values[1:(n - 1)] + loss_values[2:n]) * diff(times) / 2 - integrated_metric <- sum(tmp) / (max(times) - min(times)) - return(integrated_metric) - } - else if (normalization == "t_max") { - tmp <- (loss_values[1:(n - 1)] + loss_values[2:n]) * diff(times) / 2 - integrated_metric <- sum(tmp) - return(integrated_metric/max(times)) - } else if (normalization == "survival"){ - - km <- survival::survfit(y_true ~ 1) - estimator <- stepfun(km$time, c(1, km$surv)) - - dwt <- 1 - estimator(times) - - tmp <- (loss_values[1:(n - 1)] + loss_values[2:n]) * diff(dwt) / 2 - integrated_metric <- sum(tmp) - return(integrated_metric/(1 - estimator(max(times)))) - } - + calculate_integral(loss_values, times, normalization, y_true) } attr(integrated_loss_function, "loss_type") <- "integrated" diff --git a/R/utils.R b/R/utils.R index fe489d65..46f4a33a 100644 --- a/R/utils.R +++ b/R/utils.R @@ -185,3 +185,30 @@ add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ } } + +#' @keywords internal +calculate_integral <- function(values, times, normalization = "t_max", ...){ + n <- length(values) + + if (is.null(normalization)){ + 
tmp <- (values[1:(n - 1)] + values[2:n]) * diff(times) / 2 + integrated_metric <- sum(tmp) / (max(times) - min(times)) + return(integrated_metric) + } + else if (normalization == "t_max") { + tmp <- (values[1:(n - 1)] + values[2:n]) * diff(times) / 2 + integrated_metric <- sum(tmp) + return(integrated_metric/max(times)) + } else if (normalization == "survival"){ + y_true <- list(...)$y_true + km <- survival::survfit(y_true ~ 1) + estimator <- stepfun(km$time, c(1, km$surv)) + + dwt <- 1 - estimator(times) + + tmp <- (values[1:(n - 1)] + values[2:n]) * diff(dwt) / 2 + integrated_metric <- sum(tmp) + return(integrated_metric/(1 - estimator(max(times)))) + } +} + From ce3be3b1a5e898df243823379fe9b673ecdce053 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 16:51:25 +0200 Subject: [PATCH 015/207] remove N from model_survshap --- R/model_survshap.R | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/R/model_survshap.R b/R/model_survshap.R index 0022158c..9b7c28cf 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -2,7 +2,6 @@ #' #' This function computes global SHAP values. #' -#' @param N A positive integer indicating the number of observations that should be used to compute global SHAP values. #' @inheritParams surv_shap #' #' @details @@ -51,19 +50,13 @@ model_survshap <- #' @rdname model_survshap.surv_explainer #' @export model_survshap.surv_explainer <- function(explainer, - calculation_method = "kernelshap", - aggregation_method = "integral", new_observation = NULL, y_true = NULL, - ..., - N = NULL) { + calculation_method = "kernelshap", + aggregation_method = "integral", + ...) 
{ stopifnot( - "`N` must be a positive integer" = ifelse( - !is.null(N), - is.integer(as.integer(N)) && N > 0L, - TRUE - ), "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( !is.null(y_true), ifelse( @@ -74,8 +67,6 @@ model_survshap.surv_explainer <- function(explainer, TRUE )) - N <- as.integer(N) - test_explainer( explainer, has_data = TRUE, @@ -92,15 +83,7 @@ model_survshap.surv_explainer <- function(explainer, } } else { observations <- explainer$data - y_true <- NULL - } - - if (!is.null(N)) { - selected_observations <- sample(1:nrow(observations), N) - observations <- observations[selected_observations, ] - if (!is.null(y_true)) { - y_true <- y_true[selected_observations, ] - } + y_true <- explainer$y } shap_values <- surv_shap( From 8ed8b91b26bf8f7c13a772b4a701b729ceca6ea7 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 16:51:53 +0200 Subject: [PATCH 016/207] update aggregation methods in local survshap --- R/surv_shap.R | 15 +++++---------- man/surv_shap.Rd | 2 +- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index b3c86c4b..54f24a02 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -5,7 +5,7 @@ #' @param ... 
additional parameters, passed to internal functions #' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting #' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation -#' @param aggregation_method a character, either `"mean_absolute"` or `"integral"`, `"max_absolute"`, `"sum_of_squares"` +#' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"` #' @param observation_aggregation_method a function, if `new_observation` contains multiple observation this function is applied to the same time point of generated shap profiles for each observation. Defaults to `mean`. #' #' @return A list, containing the calculated SurvSHAP(t) results in the `result` field @@ -172,20 +172,15 @@ make_prediction_for_simplified_input <- function(explainer, model, data, simplif } -aggregate_surv_shap <- function(survshap, times, method) { +aggregate_surv_shap <- function(survshap, times, method, ...) 
{ switch( method, "sum_of_squares" = return(apply(survshap, 2, function(x) sum(x^2))), "mean_absolute" = return(apply(survshap, 2, function(x) mean(abs(x)))), "max_absolute" = return(apply(survshap, 2, function(x) max(abs(x)))), - "integral" = return(apply(survshap, 2, function(x) { - x <- abs(x) - names(x) <- NULL - n <- length(x) - i <- (x[1:(n - 1)] + x[2:n]) * diff(times) / 2 - sum(i) / (max(times) - min(times)) - })), - stop("aggregation_method has to be one of `sum_of_squares`, `mean_absolute`, `max_absolute` or `integral`") + "integral" = return(apply(survshap, 2, function(x) calculate_integral(x, times, normalization = "t_max"))), + "integral_absolute" = return(apply(survshap, 2, function(x) calculate_integral(abs(x), times, normalization = "t_max"))), + stop("aggregation_method has to be one of 'integral', 'integral_absolute', 'mean_absolute', 'max_absolute', or 'sum_of_squares'") ) } diff --git a/man/surv_shap.Rd b/man/surv_shap.Rd index 7bfc4060..4cc684b9 100644 --- a/man/surv_shap.Rd +++ b/man/surv_shap.Rd @@ -24,7 +24,7 @@ surv_shap( \item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} -\item{aggregation_method}{a character, either \code{"mean_absolute"} or \code{"integral"}, \code{"max_absolute"}, \code{"sum_of_squares"}} +\item{aggregation_method}{a character, either \code{"integral"}, \code{"integral_absolute"}, \code{"mean_absolute"}, \code{"max_absolute"}, or \code{"sum_of_squares"}} \item{observation_aggregation_method}{a function, if \code{new_observation} contains multiple observation this function is applied to the same time point of generated shap profiles for each observation. 
Defaults to \code{mean}.} } From 6ceba5e3573e0d7de01518fd68440a4e012da98f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 16:52:23 +0200 Subject: [PATCH 017/207] update plots --- NAMESPACE | 1 - R/plot_surv_shap.R | 242 ++++++++++++++++++++++++++------------------- 2 files changed, 142 insertions(+), 101 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 64b4b60a..79800b51 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -79,7 +79,6 @@ importFrom(DALEX,theme_drwhy) importFrom(DALEX,theme_drwhy_vertical) importFrom(DALEX,theme_ema) importFrom(DALEX,theme_ema_vertical) -importFrom(ggbeeswarm,geom_beeswarm) importFrom(stats,aggregate) importFrom(stats,as.formula) importFrom(stats,median) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 9a6fa6a7..b86f794c 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -94,105 +94,39 @@ plot.surv_shap <- function(x, #'@export plot.aggregated_surv_shap <- function(x, - kind = c("importance", "swarm"), + kind = "importance", ..., - title = "Feature importance according to aggregated |SurvSHAP(t)|", - subtitle = "default", - xlab_left = "Importance", - ylab_right = "Aggregated SurvSHAP(t) value", - max_vars = 7, - colors = NULL, - rug = "all", - rug_colors = c("#dd0000", "#222222")){ - - - if (kind == "importance"){ - - pl <- plot_shap_global_importance(x = x, + colors = NULL){ + switch( + kind, + "importance" = plot_shap_global_importance(x = x, ... = ..., - title = title, - subtitle = subtitle, - xlab_left = xlab_left, - ylab_right = ylab_right, - max_vars = max_vars, - colors = colors, - rug = rug, - rug_colors = rug_colors) - - - } else if (kind == "swarm") { - - pl <- plot_shap_global_swarm(x = x, + colors = colors), + "swarm" = plot_shap_global_swarm(x = x, ... = ..., - title = title, - subtitle = subtitle, - max_vars = max_vars, - colors = colors, - rug = rug, - rug_colors = rug_colors) - - } else { - stop("Unknown `kind` argument. 
Please use one of 'importance'") - } - - - return(pl) - - -} - -#'@importFrom ggbeeswarm geom_beeswarm -plot_shap_global_swarm <- function(x, - ..., - title = "Aaaaa", - subtitle = "default", - max_vars = 7, - colors = NULL, - rug = "all", - rug_colors = c("#dd0000", "#222222")){ - - # print(x$variable_values) - - df <- as.data.frame(do.call(rbind, x$aggregate)) - df <- stack(df) - original_values <- as.data.frame(x$variable_values) - # print(nrow(df)) - # print(nrow(original_values)) - print(apply(original_values, 2, function(x) any(is.na((as.numeric(x)))))) - colnames(df) <- c("shap_value", "variable") - cbind(df, label = "TODO label") - - - - - # print(df) - - ggplot(data = df, aes(y = `shap_value`, x = variable)) + - coord_flip()+ - ggbeeswarm::geom_beeswarm() - - - - + colors = colors), + "profile" = plot_shap_global_profile(x = x, + ... = ..., + colors = colors), + stop("`kind` must be one of 'importance', 'swarm' or 'profile'") + ) } - plot_shap_global_importance <- function(x, ..., title = "Feature importance according to aggregated |SurvSHAP(t)|", subtitle = "default", - xlab_left = "Importance", - ylab_right = "Aggregated SurvSHAP(t) value", max_vars = 7, colors = NULL, rug = "all", - rug_colors = c("#dd0000", "#222222")){ + rug_colors = c("#dd0000", "#222222"), + xlab_left = "Average |aggregated SurvSHAP(t)| value", + ylab_right = "Average |SurvSHAP(t)| value"){ x$result <- aggregate_shap_multiple_observations(x$result, colnames(x$result[[1]]), function(x) mean(abs(x))) x$aggregate <- apply(do.call(rbind, x$aggregate), 2, function(x) mean(abs(x))) right_plot <- plot.surv_shap(x = x, - ... 
= ..., title = NULL, subtitle = NULL, max_vars = max_vars, @@ -201,34 +135,21 @@ plot_shap_global_importance <- function(x, rug_colors = rug_colors) + labs(y = ylab_right) + label <- attr(x, "label") + long_df <- stack(x$aggregate) + long_df <- long_df[order(long_df$values, decreasing = TRUE),][1:min(max_vars, length(x$aggregate)), ] - dfl <- c(list(x), list(...)) - - df_list <- lapply(dfl, function(x) { - label <- attr(x, "label") - values <- x$aggregate - vars <- names(x$aggregate) - df <- data.frame(label, values, vars) - rownames(df) <- NULL - df - }) - - long_df <- do.call("rbind", df_list) - long_df <- long_df[order(long_df$values, decreasing = TRUE),] - - label <- unique(long_df$label) if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0( - "created for the ", paste(label, collapse = ", "), " model ", + "created for the ", label, " model ", "(n=", x$n_observations, ")" ) } left_plot <- with(long_df, { - ggplot(long_df, aes(x = values, y = reorder(vars, values))) + + ggplot(long_df, aes(x = values, y = reorder(ind, values))) + geom_col(fill = "#46bac2") + theme_default_survex() + - facet_wrap(~label, ncol = 1, scales = "free_y") + labs(x = xlab_left) + theme(axis.title.y = element_blank()) @@ -246,6 +167,127 @@ plot_shap_global_importance <- function(x, return(pl) } +plot_shap_global_swarm <- function(x, + ..., + title = "Aggregated SurvSHAP(t) values summary", + subtitle = "default", + max_vars = 7, + colors = NULL){ + + df <- as.data.frame(do.call(rbind, x$aggregate)) + cols <- names(sort(colMeans(abs(df))))[1:min(max_vars, length(df))] + df <- df[,cols] + df <- stack(df) + colnames(df) <- c("shap_value", "variable") + + original_values <- as.data.frame(x$variable_values)[,cols] + var_value <- preprocess_values_to_common_scale(original_values) + df <- cbind(df, var_value) + + label <- attr(x, "label") + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0( + "created for the ", label, " model ", + "(n=", x$n_observations, 
")" + ) + } + ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + + geom_vline(xintercept = 0, color = "#ceced9", linetype="solid") + + geom_jitter(width=0) + + scale_color_gradient2( + name = "Variable value", + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3", + midpoint = 0.5, + limits=c(0,1), + breaks = c(0, 1), + labels=c("", "")) + + labs(title = title, subtitle = subtitle, + x = "Aggregated SurvSHAP(t) value", + y = "Variable") + + theme_default_survex() + + theme(legend.position = "bottom") + + guides(color = guide_colorbar(title.position = "top", title.hjust = 0.5)) +} + + + +plot_shap_global_profile <- function(x, + ..., + variable = NULL, + color_variable = NULL, + title = "Aggregated SurvSHAP(t) profile", + subtitle = "default", + max_vars = 7, + colors = NULL){ + + df <- as.data.frame(do.call(rbind, x$aggregate)) + + if (is.null(variable)){ + variable <- colnames(df)[1] + } + if (is.null(color_variable)){ + color_variable <- variable + } + + shap_val <- df[,variable] + + original_values <- as.data.frame(x$variable_values) + var_vals <- original_values[,c(variable, color_variable)] + + df <- cbind(shap_val, var_vals) + colnames(df) <- c("shap_val", "variable_val", "color_variable_val") + + label <- attr(x, "label") + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0( + "created for the ", label, " model ", + "(n=", x$n_observations, ")" + ) + } + + p <- ggplot(df, aes(x = variable_val, y = shap_val, color = color_variable_val)) + + geom_hline(yintercept = 0, color = "#ceced9", linetype="solid") + + geom_point() + + geom_rug(aes(x = df$variable_val), inherit.aes=F, color = "#ceced9") + + labs(x = paste(variable, "value"), + y = "Aggregated SurvSHAP(t) value", + title = title, + subtitle = subtitle) + + theme_default_survex() + + theme(legend.position = "bottom") + + if (!is.factor(df$color_variable_val)) { + p + scale_color_gradient2( + name = paste(color_variable, "value"), + low = "#9fe5bd", + mid = 
"#46bac2", + high = "#371ea3", + midpoint = median(df$color_variable_val)) + } else { + p + scale_color_manual(name = paste(color_variable, "value"), + values = generate_discrete_color_scale(length(unique(df$color_variable_val)), colors)) + } +} + +preprocess_values_to_common_scale <- function(data) { + # Scale numerical columns to range [0, 1] + num_cols <- sapply(data, is.numeric) + data[num_cols] <- lapply(data[num_cols], function(x) (x - min(x)) / (max(x) - min(x))) + + # Map categorical columns to integers with even differences + cat_cols <- sapply(data, function(x) !is.numeric(x) & is.factor(x)) + data[cat_cols] <- lapply(data[cat_cols], function(x) { + levels_count <- length(levels(x)) + mapped_values <- seq(0, 1, length.out = levels_count) + mapped_values[match(x, levels(x))] + }) + res <- stack(data) + colnames(res) <- c("var_value", "variable") + return(res[,1]) +} + From 47d7f21b9d3fe2cb71e7e6642aec4df5e14a2969 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 16:52:50 +0200 Subject: [PATCH 018/207] update docs --- man/model_survshap.surv_explainer.Rd | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index dc7a3d6e..6d0a4d20 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -9,12 +9,11 @@ model_survshap(explainer, ...) \method{model_survshap}{surv_explainer}( explainer, - calculation_method = "kernelshap", - aggregation_method = "integral", new_observation = NULL, y_true = NULL, - ..., - N = NULL + calculation_method = "kernelshap", + aggregation_method = "integral", + ... ) } \arguments{ @@ -22,15 +21,13 @@ model_survshap(explainer, ...) 
\item{...}{additional parameters, passed to internal functions} -\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} - -\item{aggregation_method}{a character, either \code{"mean_absolute"} or \code{"integral"}, \code{"max_absolute"}, \code{"sum_of_squares"}} - \item{new_observation}{new observations for which predictions need to be explained} \item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} -\item{N}{A positive integer indicating the number of observations that should be used to compute global SHAP values.} +\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} + +\item{aggregation_method}{a character, either \code{"integral"}, \code{"integral_absolute"}, \code{"mean_absolute"}, \code{"max_absolute"}, or \code{"sum_of_squares"}} } \value{ An object of class \code{aggregated_surv_shap} containing the computed global SHAP values. 
From 9b81d48753d9a66065a7c6794826c7ad168602c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Wed, 26 Jul 2023 17:19:43 +0200 Subject: [PATCH 019/207] Begin integration of additional pdp plots --- DESCRIPTION | 2 +- NAMESPACE | 2 + R/plot_model_profile_survival.R | 232 +++++++++++++++++++++++++++----- R/surv_ceteris_paribus.R | 140 ++++++++++--------- 4 files changed, 275 insertions(+), 101 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 93e5581c..0be37ae5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,7 +18,7 @@ Description: Survival analysis models are commonly used in medicine and other ar License: GPL (>= 3) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.1 +RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) Imports: DALEX (>= 2.2.1), diff --git a/NAMESPACE b/NAMESPACE index d055c56b..97a883f1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -24,6 +24,7 @@ S3method(plot,surv_lime) S3method(plot,surv_model_performance) S3method(plot,surv_model_performance_rocs) S3method(plot,surv_shap) +S3method(plot2,model_profile_survival) S3method(predict,surv_explainer) S3method(predict_parts,default) S3method(predict_parts,surv_explainer) @@ -59,6 +60,7 @@ export(loss_one_minus_integrated_cd_auc) export(model_parts) export(model_performance) export(model_profile) +export(plot2) export(predict_parts) export(predict_profile) export(risk_from_chf) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index d8a47053..62f235b4 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -45,53 +45,215 @@ plot.model_profile_survival <- function(x, subtitle = "default", colors = NULL, rug = "all", - rug_colors = c("#dd0000", "#222222")) -{ - - + rug_colors = c("#dd0000", "#222222")) { explanations_list <- c(list(x), list(...)) num_models <- length(explanations_list) - if (num_models == 1){ - result <- prepare_model_profile_plots(x, - variables = variables, - variable_type = variable_type, - 
facet_ncol = facet_ncol, - numerical_plot_type = numerical_plot_type, - title = title, - subtitle = subtitle, - colors = colors, - rug = rug, - rug_colors = rug_colors) + if (num_models == 1) { + result <- prepare_model_profile_plots(x, + variables = variables, + variable_type = variable_type, + facet_ncol = facet_ncol, + numerical_plot_type = numerical_plot_type, + title = title, + subtitle = subtitle, + colors = colors, + rug = rug, + rug_colors = rug_colors + ) return(result) } return_list <- list() labels <- list() - for (i in 1:num_models){ + for (i in 1:num_models) { this_title <- unique(explanations_list[[i]]$result$`_label_`) - return_list[[i]] <- prepare_model_profile_plots(explanations_list[[i]], - variables = variables, - variable_type = variable_type, - facet_ncol = 1, - numerical_plot_type = numerical_plot_type, - title = this_title, - subtitle = NULL, - colors = colors, - rug = rug, - rug_colors = rug_colors) - labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches)-2)) + return_list[[i]] <- prepare_model_profile_plots(explanations_list[[i]], + variables = variables, + variable_type = variable_type, + facet_ncol = 1, + numerical_plot_type = numerical_plot_type, + title = this_title, + subtitle = NULL, + colors = colors, + rug = rug, + rug_colors = rug_colors + ) + labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches) - 2)) } labels <- unlist(labels) - return_plot <- patchwork::wrap_plots(return_list, nrow = 1, tag_level="keep") + + return_plot <- patchwork::wrap_plots(return_list, nrow = 1, tag_level = "keep") + patchwork::plot_annotation(title, tag_levels = list(labels)) & theme_default_survex() return(return_plot) +} + + +#' @export +plot2 <- function(x, ...) 
UseMethod("plot2") + +#' @export +plot2.model_profile_survival <- function(x, + variable, + times = NULL, + marginalize_over_time = FALSE, + plot_type = "pdp+ice", + ..., + title = "Partial dependence profile", + subtitle = "default", + colors = NULL) { + if (is.null(variable) || !is.character(variable)) { + stop("A variable must be specified by name") + } + + if (length(variable) > 1) { + stop("Only one variable can be specified") + } + + if (!variable %in% x$result$`_vname_`) { + stop(paste0("Variable ", variable, " not found")) + } + + if (is.null(times)) { + times <- quantile(x$eval_times, p = 0.5, type = 1) + } + + if (!all(times %in% x$eval_times)) { + stop(paste0( + "For one of the provided times the explanations has not been calculated not found. + Please modify the times argument in your explainer or use only values from the following: ", + paste(x$eval_times, collapse = ", ") + )) + } + + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0("created for the ", unique(x$result$`_label_`), " model") + } + + # Select relevant information from the pdp result + pdp_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c("_x_", "_yhat_")] + colnames(pdp_df) <- c(variable, "pd") + + # Select relevant information from the ceteris paribus profiles + # TODO: REMOVE THIS WHEN ID IS FIXED + # for (i in 1:length(x$cp_profiles)) { + # return(unique(x$cp_profiles$result$`_vname_`)) + # } + + ice_df <- x$cp_profiles$result[(x$cp_profiles$result$`_vname_` == variable) & + (x$cp_profiles$result$`_times_` %in% times), ] + ice_df$`_times_` <- NULL + ice_df$`_vname_` <- NULL + ice_df$`_vtype_` <- NULL + ice_df$`_label_` <- NULL + + colnames(ice_df)[colnames(ice_df) == "_ids_"] <- "id" + colnames(ice_df)[colnames(ice_df) == "_yhat_"] <- "predictions" + + + feature_name_sym <- sym(variable) + y_floor_pd <- floor(min(pdp_df[, "pd"]) * 10) / 10 + y_ceiling_pd <- ceiling(max(pdp_df[, "pd"]) * 10) / 10 + + single_timepoint <- 
((length(times) == 1) || marginalize_over_time) + + return(ice_df) + + if (unique(x$result[x$result$`_vname_` == variable, "_vtype_"]) == "categorical") { + pl <- plot_pdp_cat( + pdp_dt = pdp_df, + ice_dt = ice_df, + data_dt = NULL, + feature_name_sym, + y_floor_ice = NULL, + y_ceiling_ice = NULL, + y_floor_pd = y_floor_pd, + y_ceiling_pd = y_ceiling_pd, + plot_type = plot_type, + single_timepoint = single_timepoint + ) + } else { + pdp_df[, 1] <- as.numeric(as.character(pdp_df[, 1])) + pl <- plot_pdp_num( + pdp_dt = pdp_df, + ice_dt = ice_df, + data_dt = NULL, + feature_name_sym, + y_floor_ice = NULL, + y_ceiling_ice = NULL, + y_floor_pd = y_floor_pd, + y_ceiling_pd = y_ceiling_pd, + plot_type = plot_type, + single_timepoint = single_timepoint + ) + } + pl +} + +plot_pdp_num <- function(pdp_dt, + ice_dt, + data_dt, + feature_name_sym, + y_floor_ice, + y_ceiling_ice, + y_floor_pd, + y_ceiling_pd, + plot_type, + single_timepoint) { + if (single_timepoint == TRUE) { ## single timepoint + if (plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + + geom_line(alpha = 0.2, mapping = aes(group = id)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + } + # PDP + ICE + else if (plot_type == "pdp+ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + + geom_line(mapping = aes(group = id), alpha = 0.2) + + geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = "gold") #+ + # geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + } + # PDP + else if (plot_type == "pdp") { + ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + + geom_line() + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_pd, y_ceiling_pd) + } + } else { ## multiple timepoints + if 
(plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + + geom_line(alpha = 0.2, mapping = aes(group = interaction(id, time), color = time)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_ice, y_ceiling_ice) + } + # PDP + ICE + else if (plot_type == "pdp+ice") { + ggplot() + + geom_line(data = ice_dt, aes(x = !!feature_name_sym, y = predictions, group = interaction(id, time), color = time), alpha = 0.1) + + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time), linewidth = 1.5, lineend = "round", linejoin = "round") + + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, group = time), color = "black", linewidth = 0.5, linetype = "dashed", lineend = "round", linejoin = "round") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_ice, y_ceiling_ice) + } + # PDP + else if (plot_type == "pdp") { + ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + + geom_line(aes(color = time)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_pd, y_ceiling_pd) + } + } +} + +plot_pdp_cat <- function() { } + + prepare_model_profile_plots <- function(x, variables = NULL, variable_type = NULL, @@ -102,7 +264,6 @@ prepare_model_profile_plots <- function(x, colors = NULL, rug = rug, rug_colors = rug_colors) { - rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses), label = unique(x$result$`_label_`)) aggregated_profiles <- x$result class(aggregated_profiles) <- "data.frame" @@ -110,11 +271,12 @@ prepare_model_profile_plots <- function(x, all_variables <- unique(aggregated_profiles$`_vname_`) if (!is.null(variables)) { all_variables <- intersect(all_variables, variables) - if (length(all_variables) == 0) + if 
(length(all_variables) == 0) { stop(paste0( "variables do not overlap with ", paste(all_variables, collapse = ", ") )) + } } aggregated_profiles <- aggregated_profiles[aggregated_profiles$`_vname_` %in% all_variables, ] @@ -127,7 +289,8 @@ prepare_model_profile_plots <- function(x, } if (is.null(variables)) { - variables <- unique(aggregated_profiles$`_vname_`)} + variables <- unique(aggregated_profiles$`_vname_`) + } if (!is.null(variable_type) && variable_type == "numerical") { aggregated_profiles <- aggregated_profiles[aggregated_profiles$`_vtype_` == "numerical", ] @@ -145,7 +308,8 @@ prepare_model_profile_plots <- function(x, pl <- plot_individual_ceteris_paribus_survival(aggregated_profiles, variables, colors, numerical_plot_type, rug_df, rug, rug_colors) patchwork::wrap_plots(pl, ncol = facet_ncol) + - patchwork::plot_annotation(title = title, - subtitle = subtitle) & theme_default_survex() - + patchwork::plot_annotation( + title = title, + subtitle = subtitle + ) & theme_default_survex() } diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index c4549d51..77e93cd1 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -20,14 +20,13 @@ surv_ceteris_paribus <- function(x, ...) UseMethod("surv_ceteris_paribus", x) #' #' @keywords internal surv_ceteris_paribus.surv_explainer <- function(x, - new_observation, - variables = NULL, - categorical_variables = NULL, - variable_splits = NULL, - grid_points = 101, - variable_splits_type = "uniform", - ...) { - + new_observation, + variables = NULL, + categorical_variables = NULL, + variable_splits = NULL, + grid_points = 101, + variable_splits_type = "uniform", + ...) 
{ test_explainer(x, has_data = TRUE, has_survival = TRUE, has_y = TRUE, function_name = "ceteris_paribus_survival") data <- x$data @@ -36,37 +35,36 @@ surv_ceteris_paribus.surv_explainer <- function(x, predict_survival_function <- x$predict_survival_function times <- x$times - surv_ceteris_paribus.default(x = model, - data = data, - predict_survival_function = predict_survival_function, - new_observation = new_observation, - variables = variables, - categorical_variables = categorical_variables, - variable_splits = variable_splits, - grid_points = grid_points, - variable_splits_type = variable_splits_type, - variable_splits_with_obs = TRUE, - label = label, - times = times, - ...) - + surv_ceteris_paribus.default( + x = model, + data = data, + predict_survival_function = predict_survival_function, + new_observation = new_observation, + variables = variables, + categorical_variables = categorical_variables, + variable_splits = variable_splits, + grid_points = grid_points, + variable_splits_type = variable_splits_type, + variable_splits_with_obs = TRUE, + label = label, + times = times, + ... + ) } surv_ceteris_paribus.default <- function(x, - data, - predict_survival_function = NULL, - new_observation, - variables = NULL, - categorical_variables = NULL, - variable_splits = NULL, - grid_points = 101, - variable_splits_type = "uniform", - variable_splits_with_obs = TRUE, - label = NULL, - times = times, - ...) { - - + data, + predict_survival_function = NULL, + new_observation, + variables = NULL, + categorical_variables = NULL, + variable_splits = NULL, + grid_points = 101, + variable_splits_type = "uniform", + variable_splits_with_obs = TRUE, + label = NULL, + times = times, + ...) 
{ if (is.data.frame(data)) { common_variables <- intersect(colnames(new_observation), colnames(data)) new_observation <- new_observation[, common_variables, drop = FALSE] @@ -83,23 +81,28 @@ surv_ceteris_paribus.default <- function(x, # calculate splits if (is.null(variable_splits)) { - if (is.null(data)) + if (is.null(data)) { stop("The ceteris_paribus() function requires explainers created with specified 'data'.") - if (is.null(variables)) + } + if (is.null(variables)) { variables <- colnames(data) - variable_splits <- calculate_variable_split(data, variables = variables, - categorical_variables = categorical_variables, - grid_points = grid_points, - variable_splits_type = variable_splits_type, - new_observation = if (variable_splits_with_obs) new_observation else NA) - } + variable_splits <- calculate_variable_split(data, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + variable_splits_type = variable_splits_type, + new_observation = if (variable_splits_with_obs) new_observation else NA + ) + } - profiles <- calculate_variable_survival_profile(new_observation, - variable_splits, - x, - predict_survival_function, - times) + profiles <- calculate_variable_survival_profile( + new_observation, + variable_splits, + x, + predict_survival_function, + times + ) profiles$`_vtype_` <- ifelse(profiles$`_vname_` %in% categorical_variables, "categorical", "numerical") @@ -110,9 +113,11 @@ surv_ceteris_paribus.default <- function(x, attr(profiles, "times") <- times attr(profiles, "observations") <- new_observation - ret <- list(eval_times = times, - variable_values = new_observation, - result = cbind(profiles, `_label_` = label)) + ret <- list( + eval_times = times, + variable_values = new_observation, + result = cbind(profiles, `_label_` = label) + ) class(ret) <- c("surv_ceteris_paribus", "list") @@ -120,7 +125,7 @@ surv_ceteris_paribus.default <- function(x, } -calculate_variable_split <- function(data, variables = 
colnames(data), categorical_variables = NULL, grid_points = 101, variable_splits_type = "quantiles", new_observation = NA) { +calculate_variable_split <- function(data, variables = colnames(data), categorical_variables = NULL, grid_points = 101, variable_splits_type = "quantiles", new_observation = NA) { UseMethod("calculate_variable_split", data) } @@ -136,14 +141,17 @@ calculate_variable_split.default <- function(data, variables = colnames(data), c } else { selected_splits <- seq(min(selected_column, na.rm = TRUE), max(selected_column, na.rm = TRUE), length.out = grid_points) } - if (!any(is.na(new_observation))) + if (!any(is.na(new_observation))) { selected_splits <- sort(unique(c(selected_splits, na.omit(new_observation[, var])))) + } } else { if (any(is.na(new_observation))) { selected_splits <- sort(unique(selected_column)) } else { - selected_splits <- sort(unique(rbind(data[, var, drop = FALSE], - new_observation[, var, drop = FALSE])[, 1])) + selected_splits <- sort(unique(rbind( + data[, var, drop = FALSE], + new_observation[, var, drop = FALSE] + )[, 1])) } } selected_splits @@ -159,8 +167,6 @@ calculate_variable_survival_profile <- function(data, variable_splits, model, pr } calculate_variable_survival_profile.default <- function(data, variable_splits, model, predict_survival_function = NULL, times = NULL, ...) 
{ - - variables <- names(variable_splits) prog <- progressr::progressor(along = 1:(length(variables))) profiles <- lapply(variables, function(variable) { @@ -168,11 +174,12 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m if (is.null(rownames(data))) { - ids <- rep(1:nrow(data), each = length(split_points)) # it never goes here, because null rownames are automatically setted to 1:n + ids <- rep(1:nrow(data), each = length(times)) # it never goes here, because null rownames are automatically setted to 1:n } else { - ids <- rep(rownames(data), each = length(split_points)) + ids <- rep(rownames(data), each = length(times)) } + print(head(data)) new_data <- data[rep(1:nrow(data), each = length(split_points)), , drop = FALSE] @@ -181,11 +188,13 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m yhat <- c(t(predict_survival_function(model, new_data, times))) new_data <- data.frame(new_data[rep(seq_len(nrow(new_data)), each = length(times)), ], - `_times_` = rep(times, times = nrow(new_data)), - `_yhat_` = yhat, - `_vname_` = variable, - `_ids_` = ids, - check.names = FALSE) + `_times_` = rep(times, times = nrow(new_data)), + `_yhat_` = yhat, + `_vname_` = variable, + `_ids_` = ids, + check.names = FALSE + ) + # print(table(ids)) prog() new_data }) @@ -193,5 +202,4 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m profile <- do.call(rbind, profiles) class(profile) <- c("individual_variable_profile", class(profile)) profile - } From 6aff34c33756d20ab734f337c661f9a380c27f71 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 17:36:45 +0200 Subject: [PATCH 020/207] remove ggbeeswarm from imports --- DESCRIPTION | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index c09211ea..fc9a5754 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,8 +26,7 @@ Imports: kernelshap, pec, survival, - patchwork, - ggbeeswarm + patchwork 
Suggests: censored, covr, From 02e1aab420afc04a71103e714f02d121a28eeb68 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 17:36:59 +0200 Subject: [PATCH 021/207] fix using util integration --- R/metrics.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/metrics.R b/R/metrics.R index 9cdb9e3c..bf9bdd82 100644 --- a/R/metrics.R +++ b/R/metrics.R @@ -35,7 +35,7 @@ loss_integrate <- function(loss_function, ..., normalization = NULL , max_quanti loss_values <- loss_values[na_mask] surv <- surv[na_mask] - calculate_integral(loss_values, times, normalization, y_true) + calculate_integral(loss_values, times, normalization, y_true=y_true) } attr(integrated_loss_function, "loss_type") <- "integrated" From c52f8a37488102289e18d72882bc43ca3c1495e2 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 17:37:12 +0200 Subject: [PATCH 022/207] fix examples --- R/model_survshap.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/model_survshap.R b/R/model_survshap.R index 9b7c28cf..fb5151d9 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -26,17 +26,16 @@ #' rsf_ranger_exp <- explain( #' rsf_ranger, #' data = veteran[, -c(3, 4)], -#' y = Surv(veteran$time, veteran$status), +#' y = survival::Surv(veteran$time, veteran$status), #' verbose = FALSE #' ) #' #' ranger_global_survshap <- model_survshap( #' explainer = rsf_ranger_exp, #' new_observation = veteran[1:40, !colnames(veteran) %in% c("time", "status")], -#' y_true = Surv(veteran$time[1:40], veteran$status[1:40]), +#' y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), #' aggregation_method = "mean_absolute", #' calculation_method = "kernelshap", -#' N = 5L #' ) #' plot(ranger_global_survshap) #' } From 5051763c3f66c9b32bd6d6b91f898e31ca28e2cf Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 17:37:51 +0200 Subject: [PATCH 023/207] ggplot with df --- R/plot_surv_shap.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 
deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index b86f794c..7b1ca0a3 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -192,7 +192,7 @@ plot_shap_global_swarm <- function(x, ) } - + with(df, { ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + geom_vline(xintercept = 0, color = "#ceced9", linetype="solid") + geom_jitter(width=0) + @@ -211,10 +211,10 @@ plot_shap_global_swarm <- function(x, theme_default_survex() + theme(legend.position = "bottom") + guides(color = guide_colorbar(title.position = "top", title.hjust = 0.5)) + } + ) } - - plot_shap_global_profile <- function(x, ..., variable = NULL, @@ -249,7 +249,8 @@ plot_shap_global_profile <- function(x, ) } - p <- ggplot(df, aes(x = variable_val, y = shap_val, color = color_variable_val)) + + p <- with(df, { + ggplot(df, aes(x = variable_val, y = shap_val, color = color_variable_val)) + geom_hline(yintercept = 0, color = "#ceced9", linetype="solid") + geom_point() + geom_rug(aes(x = df$variable_val), inherit.aes=F, color = "#ceced9") + @@ -259,6 +260,7 @@ plot_shap_global_profile <- function(x, subtitle = subtitle) + theme_default_survex() + theme(legend.position = "bottom") + }) if (!is.factor(df$color_variable_val)) { p + scale_color_gradient2( From f0a36b6ed1d78765e9d6f696a650163c8d11144d Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 17:37:58 +0200 Subject: [PATCH 024/207] update docs --- man/model_survshap.surv_explainer.Rd | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 6d0a4d20..47bfbd27 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -54,17 +54,16 @@ rsf_ranger <- ranger::ranger( rsf_ranger_exp <- explain( rsf_ranger, data = veteran[, -c(3, 4)], - y = Surv(veteran$time, veteran$status), + y = survival::Surv(veteran$time, veteran$status), verbose = FALSE ) ranger_global_survshap <- 
model_survshap( explainer = rsf_ranger_exp, new_observation = veteran[1:40, !colnames(veteran) \%in\% c("time", "status")], - y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), aggregation_method = "mean_absolute", calculation_method = "kernelshap", - N = 5L ) plot(ranger_global_survshap) } From 9969a92e8197fabdbe8747dd9084248af70ca912 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 26 Jul 2023 17:38:30 +0200 Subject: [PATCH 025/207] make tests faster --- tests/testthat/test-model_survshap.R | 56 ++++++++++++++-------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index 77561f5c..c86f2163 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -3,50 +3,50 @@ veteran <- survival::veteran rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) +cph <- survival::coxph( + survival::Surv(time, status) ~ ., + data = veteran, + model = TRUE, + x = TRUE, + y = TRUE + ) +cph_exp <- explain(cph, verbose = FALSE) + test_that("global survshap explanations with kernelshap work for ranger, using new data", { ranger_global_survshap <- model_survshap( explainer = rsf_ranger_exp, - new_observation = veteran[1:40, !colnames(veteran) %in% c("time", "status")], - y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + new_observation = veteran[1:10, !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[1:10], veteran$status[1:10]), aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) plot(ranger_global_survshap) + plot(ranger_global_survshap, kind = "swarm") + plot(ranger_global_survshap, kind = "profile") 
+ plot(ranger_global_survshap, kind = "profile", variable = "karno", color_variable = "celltype") + plot(ranger_global_survshap, kind = "profile", variable = "karno", color_variable = "age") + expect_s3_class(ranger_global_survshap, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(ranger_global_survshap$eval_times), length(rsf_ranger_exp$times)) expect_true(all(names(ranger_global_survshap$variable_values) == colnames(rsf_ranger_exp$data))) - - # test functioning of special case, when providing new observation and specify N - ranger_global_survshap <- model_survshap( - explainer = rsf_ranger_exp, - new_observation = veteran[1:40, !colnames(veteran) %in% c("time", "status")], - y_true = Surv(veteran$time[1:40], veteran$status[1:40]), - aggregation_method = "mean_absolute", - calculation_method = "kernelshap", - N = 6 - ) - }) -test_that("global survshap explanations with kernelshap work for ranger, using explainer data", { +test_that("global survshap explanations with kernelshap work for coxph, using explainer data", { # using all explainer data - ranger_global_survshap <- model_survshap( - explainer = rsf_ranger_exp, - aggregation_method = "mean_absolute", + cph_global_survshap <- model_survshap( + explainer = cph_exp, calculation_method = "kernelshap" ) - plot(ranger_global_survshap) - - # using only 6 observations - ranger_global_survshap <- model_survshap( - explainer = rsf_ranger_exp, - aggregation_method = "mean_absolute", - calculation_method = "kernelshap", - N = 6L - ) - plot(ranger_global_survshap) - + plot(cph_global_survshap) + plot(cph_global_survshap, kind = "swarm") + plot(cph_global_survshap, kind = "profile") + plot(cph_global_survshap, kind = "profile", variable = "karno", color_variable = "celltype") + plot(cph_global_survshap, kind = "profile", variable = "karno", color_variable = "age") + + expect_s3_class(cph_global_survshap, c("aggregated_surv_shap", "surv_shap")) + expect_equal(length(cph_global_survshap$eval_times), 
length(cph_exp$times)) + expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data))) }) From f80b8b3c54f740b888c61daf78a71c0bbfb59ba6 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 10:03:21 +0200 Subject: [PATCH 026/207] add ale for numerical variables --- R/surv_model_profiles.R | 207 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 R/surv_model_profiles.R diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R new file mode 100644 index 00000000..0f89f192 --- /dev/null +++ b/R/surv_model_profiles.R @@ -0,0 +1,207 @@ +#' Helper functions for `model_profile.R` +#' +#' @param x an object containing calculated ceteris_paribus profiles +#' @param ... other parameters, ignored +#' @param variable_type character, either `"numerical"` or `"categorical"`, the type of variable to be calculated, if left `NULL` (default), both are calculated +#' @param variables a character vector containing names of variables to be explained +#' @param center logical, if the profiles should be centered before aggregations +#' +#' @return A data.frame with calculated results. 
+#' +#' @keywords internal +surv_aggregate_profiles <- function(x, ..., + variable_type = NULL, + groups = NULL, + variables = NULL, + center = FALSE) { + + + + + all_profiles <- x$result + class(all_profiles) <- "data.frame" + + all_profiles$`_ids_` <- factor(all_profiles$`_ids_`) + + + # variables to use + all_variables <- na.omit(as.character(unique(all_profiles$`_vname_`))) + if (!is.null(variables)) { + all_variables_intersect <- intersect(all_variables, variables) + if (length(all_variables_intersect) == 0) stop(paste0("parameter variables do not overlap with ", paste(all_variables, collapse = ", "))) + all_variables <- all_variables_intersect + } + + if (!is.null(variable_type) && variable_type == "numerical") { + all_profiles <- all_profiles[all_profiles$`_vtype_` == "numerical", ] + } + + if (!is.null(variable_type) && variable_type == "categorical") { + all_profiles <- all_profiles[all_profiles$`_vtype_` == "categorical", ] + } + + all_variables <- intersect(all_variables, unique(all_profiles$`_vname_`)) + + # select only suitable variables + all_profiles <- all_profiles[all_profiles$`_vname_` %in% all_variables, ] + # create _x_ + tmp <- as.character(all_profiles$`_vname_`) + for (viname in unique(tmp)) { + all_profiles$`_x_`[tmp == viname] <- all_profiles[tmp == viname, viname] + } + + if (!inherits(class(all_profiles), "data.frame")) { + all_profiles <- as.data.frame(all_profiles) + } + + # change x column to proper character values + for (variable in all_variables) { + if (variable %in% all_profiles[all_profiles$`_vtype_` == "categorical", "_vname_"]) + all_profiles[all_profiles$`_vname_` == variable, ]$`_x_` <- as.character(apply(all_profiles[all_profiles$`_vname_` == variable, ], 1, function(all_profiles) all_profiles[all_profiles["_vname_"]])) + } + + aggregated_profiles <- surv_aggregate_profiles_partial(all_profiles) + class(aggregated_profiles) <- c("aggregated_survival_profiles_explainer", + "partial_dependence_survival_explainer", + 
"data.frame") + + return(aggregated_profiles) +} + + +surv_aggregate_profiles_partial <- function(all_profiles) { + + tmp <- all_profiles[, c("_vname_", "_vtype_", "_label_", "_x_", "_yhat_", "_times_")] + aggregated_profiles <- aggregate(tmp$`_yhat_`, by = list(tmp$`_vname_`, tmp$`_vtype_`, tmp$`_label_`, tmp$`_x_`, tmp$`_times_`), FUN = mean, na.rm = TRUE) + colnames(aggregated_profiles) <- c("_vname_", "_vtype_", "_label_", "_x_", "_times_", "_yhat_") + aggregated_profiles$`_ids_` <- 0 + + # for factors, keep proper order + # as in https://github.com/ModelOriented/ingredients/issues/82 + if (!is.numeric(all_profiles$`_x_`)) { + aggregated_profiles$`_x_` <- factor(aggregated_profiles$`_x_`, levels = unique(all_profiles$`_x_`)) + aggregated_profiles <- aggregated_profiles[order(aggregated_profiles$`_x_`), ] + } + + aggregated_profiles +} + + + +#' @keywords internal +surv_ale <- function(x, ..., + data, + variables, + categorical_variables, + grid_points) { + + test_explainer(x, has_data = TRUE, has_survival = TRUE, function_name = "surv_ale") + + + # change categorical_features to column names + if (is.numeric(categorical_variables)) categorical_variables <- colnames(data)[categorical_variables] + additional_categorical_variables <- categorical_variables + factor_variables <- colnames(data)[sapply(data, is.factor)] + categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) + + model <- x$model + label <- x$label + predict_survival_function <- x$predict_survival_function + times <- x$times + + profiles <- lapply(variables, function(variable) { + + # Number of quantile points for determined by grid length + quantile_vals <- as.numeric(quantile(data[,variable], + seq(0.01, 1, length.out = grid_points), + type = 1)) + + # Quantile points vector + quantile_vec <- c(min(data[,variable]), quantile_vals) + quantile_vec <- unique(quantile_vec) + + quantile_df <- data.frame(id = 1:length(quantile_vec), value = quantile_vec) + + # Match feature 
instances to quantile intervals + interval_index <- findInterval(data[,variable], quantile_vec, left.open = TRUE) + + # Points in interval 0 should be in interval 1 + interval_index[interval_index == 0] <- 1 + + # Prepare datasets with upper and lower interval limits replacing original feature values + X_lower <- X_upper <- data + X_lower[, variable] <- quantile_vec[interval_index] + X_upper[, variable] <- quantile_vec[interval_index + 1] + + # Get survival predictions for instances of upper and lower interval limits + predictions_lower = predict_survival_function(model = model, + newdata = X_lower, + times = times) + predictions_upper = predict_survival_function(model = model, + newdata = X_upper, + times = times) + + predictions_original = predict_survival_function(model = model, + newdata = data, + times = times) + mean_pred <- colMeans(predictions_original) + + # First order finite differences + prediction_deltas <- predictions_upper - predictions_lower + # Rename columns to timepoints for which predictions were made + colnames(prediction_deltas) <- times + + deltas <- data.frame( + x = rep(X_lower[,variable], each=length(times)), + interval = rep(interval_index, each=length(times)), + time = rep(times, times = nrow(data)), + yhat = c(t(prediction_deltas)) + ) + + deltas <- aggregate(yhat ~ interval + time, data = deltas, FUN = mean) + deltas$yhat_cumsum <- ave(deltas$yhat, deltas$time, FUN = cumsum) + interval_n <- as.numeric(table(interval_index)) + n <- sum(interval_n) + + ale_means <- aggregate(yhat_cumsum ~ time, data = deltas, FUN = function(x) { + sum(((c(0, x[1:(length(x)-1)]) + x) / 2) * interval_n / n) + }) + colnames(ale_means)[2] <- "ale0" + + # Centering the ALEs to obtain final ALE values + ale_values <- merge(deltas, + ale_means, + all.x = TRUE, + by = "time") + + ale_values$ale <- ale_values$yhat_cumsum - ale_values$ale0 + ale_values$interval <- ale_values$interval + 1 + ale_values1 <- ale_values[seq(1, nrow(ale_values), 
length(quantile_vec)-1),] + ale_values1$interval <- 1 + ale_values <- rbind(ale_values, ale_values1) + + ale_values <- merge(ale_values, + quantile_df, + by.x = "interval", + by.y = "id") + ale_values <- ale_values[order(ale_values$interval, ale_values$time), ] + ale_values$ale <- ale_values$ale + mean_pred + + + data.frame(`_vname_` = variable, + `_vtype_` = "numerical", + `_label_` = label, + `_x_` = ale_values$value, + `_times_` = ale_values$time, + `_yhat_` = ale_values$ale, + `_ids_` = 0, + check.names = FALSE + ) + } + ) + profiles <- do.call(rbind, profiles) + class(profiles) <- c("aggregated_survival_profiles_explainer", + "accumulated_local_effects_survival_explainer", + "data.frame") + return(profiles) +} From a283ddea42815518cfea4fc4c487e5a9634d3ff0 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 10:03:40 +0200 Subject: [PATCH 027/207] remove old file --- R/surv_aggregate_profiles.R | 98 ------------------------------------- 1 file changed, 98 deletions(-) delete mode 100644 R/surv_aggregate_profiles.R diff --git a/R/surv_aggregate_profiles.R b/R/surv_aggregate_profiles.R deleted file mode 100644 index 6f203ba4..00000000 --- a/R/surv_aggregate_profiles.R +++ /dev/null @@ -1,98 +0,0 @@ -#' Helper functions for `model_profile.R` -#' -#' @param x an object containing calculated ceteris_paribus profiles -#' @param ... other parameters, ignored -#' @param variable_type character, either `"numerical"` or `"categorical"`, the type of variable to be calculated, if left `NULL` (default), both are calculated -#' @param groups unused, left for compatibility -#' @param type character, only `"partial"` is implemented -#' @param variables a character vector containing names of variables to be explained -#' @param center logical, if the profiles should be centered before aggregations -#' -#' @return A data.frame with calculated results. 
-#' -#' @keywords internal -surv_aggregate_profiles <- function(x, ..., - variable_type = NULL, - groups = NULL, - type = "partial", - variables = NULL, - center = FALSE) { - - - - - all_profiles <- x$result - class(all_profiles) <- "data.frame" - - all_profiles$`_ids_` <- factor(all_profiles$`_ids_`) - - - # variables to use - all_variables <- na.omit(as.character(unique(all_profiles$`_vname_`))) - if (!is.null(variables)) { - all_variables_intersect <- intersect(all_variables, variables) - if (length(all_variables_intersect) == 0) stop(paste0("parameter variables do not overlap with ", paste(all_variables, collapse = ", "))) - all_variables <- all_variables_intersect - } - - if (!is.null(variable_type) && variable_type == "numerical") { - all_profiles <- all_profiles[all_profiles$`_vtype_` == "numerical", ] - } - - if (!is.null(variable_type) && variable_type == "categorical") { - all_profiles <- all_profiles[all_profiles$`_vtype_` == "categorical", ] - } - - all_variables <- intersect(all_variables, unique(all_profiles$`_vname_`)) - - # select only suitable variables - all_profiles <- all_profiles[all_profiles$`_vname_` %in% all_variables, ] - # create _x_ - tmp <- as.character(all_profiles$`_vname_`) - for (viname in unique(tmp)) { - all_profiles$`_x_`[tmp == viname] <- all_profiles[tmp == viname, viname] - } - - if (!inherits(class(all_profiles), "data.frame")) { - all_profiles <- as.data.frame(all_profiles) - } - - # change x column to proper character values - for (variable in all_variables) { - if (variable %in% all_profiles[all_profiles$`_vtype_` == "categorical", "_vname_"]) - all_profiles[all_profiles$`_vname_` == variable, ]$`_x_` <- as.character(apply(all_profiles[all_profiles$`_vname_` == variable, ], 1, function(all_profiles) all_profiles[all_profiles["_vname_"]])) - } - - if (type == "partial") { - aggregated_profiles <- surv_aggregate_profiles_partial(all_profiles) - class(aggregated_profiles) <- c("aggregated_survival_profiles_explainer", - 
"partial_dependence_survival_explainer", - "data.frame") - } - if (type == "conditional") { - stop("Not implemented") - } - if (type == "accumulated") { - stop("Not implemented") - } - - return(aggregated_profiles) -} - - -surv_aggregate_profiles_partial <- function(all_profiles) { - - tmp <- all_profiles[, c("_vname_", "_vtype_", "_label_", "_x_", "_yhat_", "_times_")] - aggregated_profiles <- aggregate(tmp$`_yhat_`, by = list(tmp$`_vname_`, tmp$`_vtype_`, tmp$`_label_`, tmp$`_x_`, tmp$`_times_`), FUN = mean, na.rm = TRUE) - colnames(aggregated_profiles) <- c("_vname_", "_vtype_", "_label_", "_x_", "_times_", "_yhat_") - aggregated_profiles$`_ids_` <- 0 - - # for factors, keep proper order - # as in https://github.com/ModelOriented/ingredients/issues/82 - if (!is.numeric(all_profiles$`_x_`)) { - aggregated_profiles$`_x_` <- factor(aggregated_profiles$`_x_`, levels = unique(all_profiles$`_x_`)) - aggregated_profiles <- aggregated_profiles[order(aggregated_profiles$`_x_`), ] - } - - aggregated_profiles -} From e1cce633f811d0f04e27be620cd8cfce158679dd Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 10:04:11 +0200 Subject: [PATCH 028/207] move error to the right place --- R/surv_ceteris_paribus.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index c4549d51..66707f90 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -79,12 +79,11 @@ surv_ceteris_paribus.default <- function(x, factor_variables <- colnames(data)[sapply(data, is.factor)] categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) - + if (is.null(data)) + stop("The ceteris_paribus() function requires explainers created with specified 'data'.") # calculate splits if (is.null(variable_splits)) { - if (is.null(data)) - stop("The ceteris_paribus() function requires explainers created with specified 'data'.") if (is.null(variables)) variables <- colnames(data) 
variable_splits <- calculate_variable_split(data, variables = variables, From d83188ac9e724e5963ddcda5e8fb0c568b33c066 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 10:04:28 +0200 Subject: [PATCH 029/207] add ale to model_profile --- R/model_profile.R | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/R/model_profile.R b/R/model_profile.R index 478056aa..d73edca2 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -92,20 +92,34 @@ model_profile.surv_explainer <- function(explainer, ndata <- data } - cp_profiles <- surv_ceteris_paribus(explainer, + if (type == "partial"){ + cp_profiles <- surv_ceteris_paribus(explainer, new_observation = ndata, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, ...) - agr_profiles <- surv_aggregate_profiles(cp_profiles, ..., - groups = groups, - type = type, - variables = variables, - center = center) + result <- surv_aggregate_profiles(cp_profiles, ..., + variables = variables, + center = center) + } else if (type == "accumulated"){ + cp_profiles <- NULL - ret <- list(eval_times = unique(agr_profiles$`_times_`), cp_profiles = cp_profiles, result = agr_profiles) + result <- surv_ale(explainer, + data = ndata, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + ...) 
+ } else { + stop("Currently only `partial` and `accumulated` types are implemented") + } + + ret <- list(eval_times = unique(result$`_times_`), + cp_profiles = cp_profiles, + result = result, + type = type) class(ret) <- c("model_profile_survival", "list") ret$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] ret$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] From 045ef560323f7a4f2e04d5b336b7ffb4c7d1c52e Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 10:04:34 +0200 Subject: [PATCH 030/207] add ale plots --- R/plot_model_profile_survival.R | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index d8a47053..6702f845 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -41,7 +41,7 @@ plot.model_profile_survival <- function(x, variable_type = NULL, facet_ncol = NULL, numerical_plot_type = "lines", - title = "Partial dependence survival profile", + title = "default", subtitle = "default", colors = NULL, rug = "all", @@ -51,6 +51,12 @@ plot.model_profile_survival <- function(x, explanations_list <- c(list(x), list(...)) num_models <- length(explanations_list) + if (title == "default"){ + if (x$type == "partial") + title <- "Partial dependence survival profiles" + if (x$type == "accumulated") + title <- "Accumulated local effects survival profiles" + } if (num_models == 1){ result <- prepare_model_profile_plots(x, @@ -97,7 +103,7 @@ prepare_model_profile_plots <- function(x, variable_type = NULL, facet_ncol = NULL, numerical_plot_type = "lines", - title = "Partial dependence survival profile", + title = "default", subtitle = "default", colors = NULL, rug = rug, From 2b0be077af7965237663e6653b06ee9591bf4ad3 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:29:38 +0200 Subject: [PATCH 031/207] add order levels for ALE --- R/utils.R | 80 
+++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/R/utils.R b/R/utils.R index fe489d65..63a4a2b3 100644 --- a/R/utils.R +++ b/R/utils.R @@ -185,3 +185,83 @@ add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ } } + +#' Order levels of a categorical features +#' +#' @description +#' Orders the levels by their similarity in other features. Computes per feature +#' the distance, sums up all distances and does multi-dimensional scaling +#' +#' @details +#' Goal: Compute the distances between two categories. +#' Input: Instances from category 1 and 2 +#' +#' 1. For all features, do (excluding the categorical feature for which we are computing the order): +#' - If the feature is numerical: Take instances from category 1, calculate the +#' empirical cumulative probability distribution function (ecdf) of the +#' feature. The ecdf is a function that tells us for a given feature value, how +#' many values are smaller. Do the same for category 2. The distance is the +#' absolute maximum point-wise distance of the two ecdf. Practically, this +#' value is high when the distribution from one category is strongly shifted +#' far away from the other. This measure is also known as the +#' Kolmogorov-Smirnov distance +#' (\url{https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test}). +#' - If the feature is categorical: Take instances from category 1 and +#' calculate a table with the relative frequency of each category of the other +#' feature. Do the same for instances from category 2. The distance is the sum +#' of the absolute difference of both relative frequency tables. +#' 2. Sum up the distances over all features +#' +#' This algorithm we run for all pairs of categories. +#' Then we have a k times k matrix, when k is the number of categories, where +#' each entry is the distance between two categories. 
Still not enough to have a +#' single order, because, a (dis)similarity tells you the pair-wise distances, +#' but does not give you a one-dimensional ordering of the classes. To kind of +#' force this thing into a single dimension, we have to use a dimension +#' reduction trick called multi-dimensional scaling. This can be solved using +#' multi-dimensional scaling, which takes in a distance matrix and returns a +#' distance matrix with reduced dimension. In our case, we only want 1 dimension +#' left, so that we have a single ordering of the categories and can compute the +#' accumulated local effects. After reducing it to a single ordering, we are +#' done and can use this ordering to compute ALE. This is not the Holy Grail how +#' to order the factors, but one possibility. +#' +#' @param data_dt data.frame with the training data +#' @param feature_name the name of the categorical feature +#' @return the order of the levels (not levels itself) +#' @keywords internal +order_levels <- function(data, variable) { + data[, variable] <- droplevels(data[, variable]) + feature <- data[, variable] + x.count <- as.numeric(table(data[, variable])) + x.prob <- x.count / sum(x.count) + K <- nlevels(data[, variable]) + + dists <- lapply(setdiff(colnames(data), variable), function(x) { + feature.x <- data[, x] + dists <- expand.grid(levels(feature), levels(feature)) + colnames(dists) <- c("from.level", "to.level") + if (inherits(feature.x, "factor")) { + A <- table(feature, feature.x) / x.count + dists$dist <- rowSums(abs(A[dists[, "from.level"], ] - A[dists[, "to.level"], ])) / 2 + } else { + quants <- quantile(feature.x, probs = seq(0, 1, length.out = 100), na.rm = TRUE, names = FALSE) + ecdfs <- data.frame(lapply(levels(feature), function(lev) { + x.ecdf <- ecdf(feature.x[feature == lev])(quants) + })) + colnames(ecdfs) <- levels(feature) + ecdf.dists.all <- abs(ecdfs[, dists$from.level] - ecdfs[, dists$to.level]) + dists$dist <- apply(ecdf.dists.all, 2, max) + } + dists + 
}) + + dists.cumulated.long <- Reduce(function(d1, d2) { + d1$dist <- d1$dist + d2$dist + d1 + }, dists) + dists.cumulated <- xtabs(dist ~ from.level + to.level, dists.cumulated.long) + scaled <- cmdscale(dists.cumulated, k = 1) + order(scaled) +} + From 58b89fe16f7db1d60cb92ad9d4a6aa00362a15f7 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:29:58 +0200 Subject: [PATCH 032/207] change print.model_profile for ALE --- R/print.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/print.R b/R/print.R index 436ececf..635c6c1e 100644 --- a/R/print.R +++ b/R/print.R @@ -1,9 +1,9 @@ #' @export print.model_profile_survival <- function(x, ...) { res <- x$result - + method_name <- ifelse(x$type == "partial", "Partial dependence", "Accumulated local effects") res <- res[order(res$`_vname_`), ] - text <- paste0("Partial dependence for the ", unique(res$`_label_`), " model:\n") + text <- paste0(method_name, " survival profiles for the ", unique(res$`_label_`), " model:\n") cat(text) print.data.frame(res[, !colnames(res) %in% c("_label_", "_ids_")], ...) } From a1c656cc287377e736cfa0e0cd9e8584d61856a6 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:30:16 +0200 Subject: [PATCH 033/207] add categorical ALE --- R/surv_model_profiles.R | 420 ++++++++++++++++++++++++++++------------ 1 file changed, 299 insertions(+), 121 deletions(-) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index 0f89f192..de1b2cbd 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -9,15 +9,12 @@ #' @return A data.frame with calculated results. 
#' #' @keywords internal -surv_aggregate_profiles <- function(x, ..., - variable_type = NULL, - groups = NULL, - variables = NULL, - center = FALSE) { - - - - +surv_aggregate_profiles <- function(x, + ..., + variable_type = NULL, + groups = NULL, + variables = NULL, + center = FALSE) { all_profiles <- x$result class(all_profiles) <- "data.frame" @@ -25,29 +22,39 @@ surv_aggregate_profiles <- function(x, ..., # variables to use - all_variables <- na.omit(as.character(unique(all_profiles$`_vname_`))) + all_variables <- + na.omit(as.character(unique(all_profiles$`_vname_`))) if (!is.null(variables)) { all_variables_intersect <- intersect(all_variables, variables) - if (length(all_variables_intersect) == 0) stop(paste0("parameter variables do not overlap with ", paste(all_variables, collapse = ", "))) + if (length(all_variables_intersect) == 0) + stop(paste0( + "parameter variables do not overlap with ", + paste(all_variables, collapse = ", ") + )) all_variables <- all_variables_intersect } if (!is.null(variable_type) && variable_type == "numerical") { - all_profiles <- all_profiles[all_profiles$`_vtype_` == "numerical", ] + all_profiles <- + all_profiles[all_profiles$`_vtype_` == "numerical",] } if (!is.null(variable_type) && variable_type == "categorical") { - all_profiles <- all_profiles[all_profiles$`_vtype_` == "categorical", ] + all_profiles <- + all_profiles[all_profiles$`_vtype_` == "categorical",] } - all_variables <- intersect(all_variables, unique(all_profiles$`_vname_`)) + all_variables <- + intersect(all_variables, unique(all_profiles$`_vname_`)) # select only suitable variables - all_profiles <- all_profiles[all_profiles$`_vname_` %in% all_variables, ] + all_profiles <- + all_profiles[all_profiles$`_vname_` %in% all_variables,] # create _x_ tmp <- as.character(all_profiles$`_vname_`) for (viname in unique(tmp)) { - all_profiles$`_x_`[tmp == viname] <- all_profiles[tmp == viname, viname] + all_profiles$`_x_`[tmp == viname] <- + all_profiles[tmp == viname, 
viname] } if (!inherits(class(all_profiles), "data.frame")) { @@ -57,30 +64,52 @@ surv_aggregate_profiles <- function(x, ..., # change x column to proper character values for (variable in all_variables) { if (variable %in% all_profiles[all_profiles$`_vtype_` == "categorical", "_vname_"]) - all_profiles[all_profiles$`_vname_` == variable, ]$`_x_` <- as.character(apply(all_profiles[all_profiles$`_vname_` == variable, ], 1, function(all_profiles) all_profiles[all_profiles["_vname_"]])) + all_profiles[all_profiles$`_vname_` == variable,]$`_x_` <- + as.character(apply(all_profiles[all_profiles$`_vname_` == variable,], 1, function(all_profiles) + all_profiles[all_profiles["_vname_"]])) } - aggregated_profiles <- surv_aggregate_profiles_partial(all_profiles) - class(aggregated_profiles) <- c("aggregated_survival_profiles_explainer", - "partial_dependence_survival_explainer", - "data.frame") + aggregated_profiles <- + surv_aggregate_profiles_partial(all_profiles) + class(aggregated_profiles) <- + c( + "aggregated_survival_profiles_explainer", + "partial_dependence_survival_explainer", + "data.frame" + ) return(aggregated_profiles) } surv_aggregate_profiles_partial <- function(all_profiles) { - - tmp <- all_profiles[, c("_vname_", "_vtype_", "_label_", "_x_", "_yhat_", "_times_")] - aggregated_profiles <- aggregate(tmp$`_yhat_`, by = list(tmp$`_vname_`, tmp$`_vtype_`, tmp$`_label_`, tmp$`_x_`, tmp$`_times_`), FUN = mean, na.rm = TRUE) - colnames(aggregated_profiles) <- c("_vname_", "_vtype_", "_label_", "_x_", "_times_", "_yhat_") + tmp <- + all_profiles[, c("_vname_", "_vtype_", "_label_", "_x_", "_yhat_", "_times_")] + aggregated_profiles <- + aggregate( + tmp$`_yhat_`, + by = list( + tmp$`_vname_`, + tmp$`_vtype_`, + tmp$`_label_`, + tmp$`_x_`, + tmp$`_times_` + ), + FUN = mean, + na.rm = TRUE + ) + colnames(aggregated_profiles) <- + c("_vname_", "_vtype_", "_label_", "_x_", "_times_", "_yhat_") aggregated_profiles$`_ids_` <- 0 # for factors, keep proper order # as in 
https://github.com/ModelOriented/ingredients/issues/82 if (!is.numeric(all_profiles$`_x_`)) { - aggregated_profiles$`_x_` <- factor(aggregated_profiles$`_x_`, levels = unique(all_profiles$`_x_`)) - aggregated_profiles <- aggregated_profiles[order(aggregated_profiles$`_x_`), ] + aggregated_profiles$`_x_` <- + factor(aggregated_profiles$`_x_`, + levels = unique(all_profiles$`_x_`)) + aggregated_profiles <- + aggregated_profiles[order(aggregated_profiles$`_x_`),] } aggregated_profiles @@ -89,119 +118,268 @@ surv_aggregate_profiles_partial <- function(all_profiles) { #' @keywords internal -surv_ale <- function(x, ..., +surv_ale <- function(x, + ..., data, variables, categorical_variables, grid_points) { + test_explainer( + x, + has_data = TRUE, + has_survival = TRUE, + function_name = "surv_ale" + ) - test_explainer(x, has_data = TRUE, has_survival = TRUE, function_name = "surv_ale") - + if (is.null(variables)) + variables <- colnames(data) # change categorical_features to column names - if (is.numeric(categorical_variables)) categorical_variables <- colnames(data)[categorical_variables] + if (is.numeric(categorical_variables)) + categorical_variables <- colnames(data)[categorical_variables] additional_categorical_variables <- categorical_variables factor_variables <- colnames(data)[sapply(data, is.factor)] - categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) + categorical_variables <- + unique(c(additional_categorical_variables, factor_variables)) model <- x$model label <- x$label predict_survival_function <- x$predict_survival_function times <- x$times - profiles <- lapply(variables, function(variable) { - - # Number of quantile points for determined by grid length - quantile_vals <- as.numeric(quantile(data[,variable], - seq(0.01, 1, length.out = grid_points), - type = 1)) - - # Quantile points vector - quantile_vec <- c(min(data[,variable]), quantile_vals) - quantile_vec <- unique(quantile_vec) - - quantile_df <- data.frame(id = 
1:length(quantile_vec), value = quantile_vec) + # Make predictions for original levels + predictions_original <- predict_survival_function(model = model, + newdata = data, + times = times) + mean_pred <- colMeans(predictions_original) - # Match feature instances to quantile intervals - interval_index <- findInterval(data[,variable], quantile_vec, left.open = TRUE) - # Points in interval 0 should be in interval 1 - interval_index[interval_index == 0] <- 1 - - # Prepare datasets with upper and lower interval limits replacing original feature values + profiles <- lapply(variables, function(variable) { X_lower <- X_upper <- data - X_lower[, variable] <- quantile_vec[interval_index] - X_upper[, variable] <- quantile_vec[interval_index + 1] - - # Get survival predictions for instances of upper and lower interval limits - predictions_lower = predict_survival_function(model = model, - newdata = X_lower, - times = times) - predictions_upper = predict_survival_function(model = model, - newdata = X_upper, - times = times) - - predictions_original = predict_survival_function(model = model, - newdata = data, - times = times) - mean_pred <- colMeans(predictions_original) - - # First order finite differences - prediction_deltas <- predictions_upper - predictions_lower - # Rename columns to timepoints for which predictions were made - colnames(prediction_deltas) <- times - - deltas <- data.frame( - x = rep(X_lower[,variable], each=length(times)), - interval = rep(interval_index, each=length(times)), - time = rep(times, times = nrow(data)), - yhat = c(t(prediction_deltas)) - ) + variable_values <- data[, variable] + + if (variable %in% categorical_variables) { + if (!is.factor(variable_values)){ + data[, variable] <- as.factor(data[, variable]) + variable_values <- as.factor(variable_values) + } + levels_original <- levels(droplevels(variable_values)) + levels_n <- nlevels(droplevels(variable_values)) + if (inherits(variable_values, "ordered")) { + level_order <- 1:levels_n + } 
else { + level_order <- order_levels(data, variable) + } + + # The new order of the levels + levels_ordered <- levels_original[level_order] + + # The feature with the levels in the new order + x_ordered <- + order(level_order)[as.numeric(droplevels(variable_values))] + + # Filter rows which are not already at maximum or minimum level values + row_ind_increase <- (1:nrow(data))[x_ordered < levels_n] + row_ind_decrease <- (1:nrow(data))[x_ordered > 1] + X_lower[row_ind_decrease, variable] <- + levels_ordered[x_ordered[row_ind_decrease] - 1] + X_upper[row_ind_increase, variable] <- + levels_ordered[x_ordered[row_ind_increase] + 1] + + # Make predictions for decreased levels (excluding minimum levels) + predictions_lower <- + predict_survival_function(model = model, + newdata = X_lower[row_ind_decrease,], + times = times) + + # Make predictions for increased levels (excluding maximum levels) + predictions_upper <- + predict_survival_function(model = model, + newdata = X_upper[row_ind_increase,], + times = times) + + d_increase <- + predictions_upper - predictions_original[row_ind_increase,] + d_decrease <- + predictions_original[row_ind_decrease,] - predictions_lower + prediction_deltas <- rbind(d_increase, d_decrease) + colnames(prediction_deltas) <- times + + + deltas <- data.frame( + interval = rep(c(x_ordered[row_ind_increase], + x_ordered[row_ind_decrease] - 1), + each = length(times)), + time = rep(times, times = nrow(prediction_deltas)), + yhat = c(t(prediction_deltas)) + ) + + deltas <- + aggregate(yhat ~ interval + time, + data = deltas, + FUN = mean) + deltas1 <- deltas[deltas$interval == 1,] + deltas1$yhat <- 0 + deltas$interval <- deltas$interval + 1 + deltas <- rbind(deltas, deltas1) + deltas <- deltas[order(deltas$time, deltas$interval),] + rownames(deltas) <- NULL + deltas$yhat_cumsum <- + ave(deltas$yhat, deltas$time, FUN = cumsum) + + x_count <- as.numeric(table(variable_values)) + x_prob <- x_count / sum(x_count) + + ale_means <- + aggregate( + 
yhat_cumsum ~ time, + data = deltas, + FUN = function(x) { + sum(x * x_prob[level_order]) + } + ) + colnames(ale_means)[2] <- "ale0" + + ale_values <- merge(deltas, + ale_means, + all.x = TRUE, + by = "time") + + ale_values$ale <- + ale_values$yhat_cumsum - ale_values$ale0 + ale_values$level <- levels_ordered[ale_values$interval] + + ale_values <- + ale_values[order(ale_values$interval, ale_values$time),] + ale_values$ale <- ale_values$ale + mean_pred + + return( + data.frame( + `_vname_` = variable, + `_vtype_` = "categorical", + `_label_` = label, + `_x_` = ale_values$level, + `_times_` = ale_values$time, + `_yhat_` = ale_values$ale, + `_ids_` = 0, + check.names = FALSE + ) + ) + + } else { + # Number of quantile points for determined by grid length + quantile_vals <- as.numeric(quantile( + variable_values, + seq(0.01, 1, length.out = grid_points), + type = 1 + )) + + # Quantile points vector + quantile_vec <- c(min(variable_values), quantile_vals) + quantile_vec <- unique(quantile_vec) + + quantile_df <- + data.frame(id = 1:length(quantile_vec), + value = quantile_vec) + + # Match feature instances to quantile intervals + interval_index <- + findInterval(variable_values, quantile_vec, left.open = TRUE) + + # Points in interval 0 should be in interval 1 + interval_index[interval_index == 0] <- 1 + + # Prepare datasets with upper and lower interval limits replacing original feature values + X_lower[, variable] <- quantile_vec[interval_index] + X_upper[, variable] <- quantile_vec[interval_index + 1] + # Get survival predictions for instances of upper and lower interval limits + predictions_lower <- + predict_survival_function(model = model, + newdata = X_lower, + times = times) + predictions_upper <- + predict_survival_function(model = model, + newdata = X_upper, + times = times) + + # First order finite differences + prediction_deltas <- + predictions_upper - predictions_lower + # Rename columns to timepoints for which predictions were made + 
colnames(prediction_deltas) <- times + + deltas <- data.frame( + x = rep(X_lower[, variable], each = length(times)), + interval = rep(interval_index, each = length(times)), + time = rep(times, times = nrow(data)), + yhat = c(t(prediction_deltas)) + ) + + deltas <- + aggregate(yhat ~ interval + time, + data = deltas, + FUN = mean) + deltas$yhat_cumsum <- + ave(deltas$yhat, deltas$time, FUN = cumsum) + interval_n <- as.numeric(table(interval_index)) + n <- sum(interval_n) + + ale_means <- + aggregate( + yhat_cumsum ~ time, + data = deltas, + FUN = function(x) { + sum(((c( + 0, x[1:(length(x) - 1)] + ) + x) / 2) * interval_n / n) + } + ) + colnames(ale_means)[2] <- "ale0" + + # Centering the ALEs to obtain final ALE values + ale_values <- merge(deltas, + ale_means, + all.x = TRUE, + by = "time") + + ale_values$ale <- + ale_values$yhat_cumsum - ale_values$ale0 + ale_values$interval <- ale_values$interval + 1 + ale_values1 <- + ale_values[seq(1, nrow(ale_values), length(quantile_vec) - 1), ] + ale_values1$interval <- 1 + ale_values <- rbind(ale_values, ale_values1) + + ale_values <- merge(ale_values, + quantile_df, + by.x = "interval", + by.y = "id") + ale_values <- + ale_values[order(ale_values$interval, ale_values$time),] + ale_values$ale <- ale_values$ale + mean_pred + + + return( + data.frame( + `_vname_` = variable, + `_vtype_` = "numerical", + `_label_` = label, + `_x_` = ale_values$value, + `_times_` = ale_values$time, + `_yhat_` = ale_values$ale, + `_ids_` = 0, + check.names = FALSE + ) + ) + } + + }) - deltas <- aggregate(yhat ~ interval + time, data = deltas, FUN = mean) - deltas$yhat_cumsum <- ave(deltas$yhat, deltas$time, FUN = cumsum) - interval_n <- as.numeric(table(interval_index)) - n <- sum(interval_n) - - ale_means <- aggregate(yhat_cumsum ~ time, data = deltas, FUN = function(x) { - sum(((c(0, x[1:(length(x)-1)]) + x) / 2) * interval_n / n) - }) - colnames(ale_means)[2] <- "ale0" - - # Centering the ALEs to obtain final ALE values - ale_values <- 
merge(deltas, - ale_means, - all.x = TRUE, - by = "time") - - ale_values$ale <- ale_values$yhat_cumsum - ale_values$ale0 - ale_values$interval <- ale_values$interval + 1 - ale_values1 <- ale_values[seq(1, nrow(ale_values), length(quantile_vec)-1),] - ale_values1$interval <- 1 - ale_values <- rbind(ale_values, ale_values1) - - ale_values <- merge(ale_values, - quantile_df, - by.x = "interval", - by.y = "id") - ale_values <- ale_values[order(ale_values$interval, ale_values$time), ] - ale_values$ale <- ale_values$ale + mean_pred - - - data.frame(`_vname_` = variable, - `_vtype_` = "numerical", - `_label_` = label, - `_x_` = ale_values$value, - `_times_` = ale_values$time, - `_yhat_` = ale_values$ale, - `_ids_` = 0, - check.names = FALSE - ) - } - ) profiles <- do.call(rbind, profiles) - class(profiles) <- c("aggregated_survival_profiles_explainer", - "accumulated_local_effects_survival_explainer", - "data.frame") + class(profiles) <- c( + "aggregated_survival_profiles_explainer", + "accumulated_local_effects_survival_explainer", + "data.frame" + ) return(profiles) } From 9b39ae86d159587fb6f34a1c5a64a68526e1786f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:48:53 +0200 Subject: [PATCH 034/207] change default value for center param --- R/model_profile.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/model_profile.R b/R/model_profile.R index d73edca2..e97e5401 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -13,7 +13,7 @@ #' @param groups if `output_type == "risk"` a variable name that will be used for grouping. By default `NULL`, so no groups are calculated. If `output_type == "survival"` then ignored #' @param k passed to `DALEX::model_profile` if `output_type == "risk"`, otherwise ignored #' @param center logical, should profiles be centered before clustering -#' @param type the type of variable profile. 
If `output_type == "survival"` then only `"partial"` is implemented, otherwise passed to `DALEX::model_profile`. +#' @param type the type of variable profile, `"partial"` for Partial Dependence, `"accumulated"` for Accumulated Local Effects, or `"conditional"` (available only for `output_type == "risk"`) #' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the `DALEX::model_profile` function. #' #' @return An object of class `model_profile_survival`. It is a list with the element `result` containing the results of the calculation. @@ -54,7 +54,7 @@ model_profile <- function(explainer, ..., groups = NULL, k = NULL, - center = TRUE, + center = FALSE, type = "partial", output_type = "survival") UseMethod("model_profile", explainer) From f496db491c824b3ff75dcea2535499f328be6a41 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:49:21 +0200 Subject: [PATCH 035/207] add optional centering and docs --- R/surv_model_profiles.R | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index de1b2cbd..810f03d1 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -116,14 +116,24 @@ surv_aggregate_profiles_partial <- function(all_profiles) { } - +#' @param x an explainer object - model preprocessed by the `explain()` function +#' @param ... other parameters, ignored +#' @param data data used to create explanations +#' @param variables a character vector containing names of variables to be explained +#' @param categorical_variables character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). 
If it contains variable names not present in the `variables` argument, they will be added at the end. +#' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default `51`. +#' @param center logical, if the profiles should be centered before aggregations +#' +#' @return A data.frame with calculated results. +#' #' @keywords internal surv_ale <- function(x, ..., data, variables, categorical_variables, - grid_points) { + grid_points, + center = FALSE) { test_explainer( x, has_data = TRUE, @@ -153,7 +163,6 @@ surv_ale <- function(x, times = times) mean_pred <- colMeans(predictions_original) - profiles <- lapply(variables, function(variable) { X_lower <- X_upper <- data variable_values <- data[, variable] @@ -251,8 +260,9 @@ surv_ale <- function(x, ale_values <- ale_values[order(ale_values$interval, ale_values$time),] - ale_values$ale <- ale_values$ale + mean_pred - + if (!center){ + ale_values$ale <- ale_values$ale + mean_pred + } return( data.frame( `_vname_` = variable, @@ -357,7 +367,9 @@ surv_ale <- function(x, ale_values <- ale_values[order(ale_values$interval, ale_values$time),] ale_values$ale <- ale_values$ale + mean_pred - + if (!center){ + ale_values$ale <- ale_values$ale + mean_pred + } return( data.frame( From 05b6cacac2be6fc2e987a9cd16b9e655e8cdf1b8 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:49:36 +0200 Subject: [PATCH 036/207] remove detailed description --- R/utils.R | 44 +------------------------------------------- 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/R/utils.R b/R/utils.R index 63a4a2b3..c6b203c7 100644 --- a/R/utils.R +++ b/R/utils.R @@ -186,49 +186,7 @@ add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ } -#' Order levels of a categorical features -#' -#' @description -#' Orders the levels by their similarity in other features. 
Computes per feature -#' the distance, sums up all distances and does multi-dimensional scaling -#' -#' @details -#' Goal: Compute the distances between two categories. -#' Input: Instances from category 1 and 2 -#' -#' 1. For all features, do (excluding the categorical feature for which we are computing the order): -#' - If the feature is numerical: Take instances from category 1, calculate the -#' empirical cumulative probability distribution function (ecdf) of the -#' feature. The ecdf is a function that tells us for a given feature value, how -#' many values are smaller. Do the same for category 2. The distance is the -#' absolute maximum point-wise distance of the two ecdf. Practically, this -#' value is high when the distribution from one category is strongly shifted -#' far away from the other. This measure is also known as the -#' Kolmogorov-Smirnov distance -#' (\url{https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test}). -#' - If the feature is categorical: Take instances from category 1 and -#' calculate a table with the relative frequency of each category of the other -#' feature. Do the same for instances from category 2. The distance is the sum -#' of the absolute difference of both relative frequency tables. -#' 2. Sum up the distances over all features -#' -#' This algorithm we run for all pairs of categories. -#' Then we have a k times k matrix, when k is the number of categories, where -#' each entry is the distance between two categories. Still not enough to have a -#' single order, because, a (dis)similarity tells you the pair-wise distances, -#' but does not give you a one-dimensional ordering of the classes. To kind of -#' force this thing into a single dimension, we have to use a dimension -#' reduction trick called multi-dimensional scaling. This can be solved using -#' multi-dimensional scaling, which takes in a distance matrix and returns a -#' distance matrix with reduced dimension. 
In our case, we only want 1 dimension -#' left, so that we have a single ordering of the categories and can compute the -#' accumulated local effects. After reducing it to a single ordering, we are -#' done and can use this ordering to compute ALE. This is not the Holy Grail how -#' to order the factors, but one possibility. -#' -#' @param data_dt data.frame with the training data -#' @param feature_name the name of the categorical feature -#' @return the order of the levels (not levels itself) +# based on iml::order_levels #' @keywords internal order_levels <- function(data, variable) { data[, variable] <- droplevels(data[, variable]) From 86e4b80407bd7a764664329667022c0114fabd1e Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:49:46 +0200 Subject: [PATCH 037/207] update tests --- tests/testthat/test-model_profile.R | 48 +++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 1aca8119..5e902445 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -1,4 +1,4 @@ -test_that("model_profile works", { +test_that("model_profile with type = 'partial' works", { veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] @@ -53,12 +53,54 @@ test_that("model_profile works", { model_profile(rsf_ranger_exp, varaibles = "trt", grid_points = 6) - - expect_error(model_profile(rsf_ranger_exp, type = "accumulated")) expect_error(model_profile(rsf_ranger_exp, type = "conditional")) }) +test_that("model_profile with type = 'accumulated' works", { + + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + type <- 'accumulated' + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = 
TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_src_exp <- explain(rsf_src, verbose = FALSE) + + + mp_cph_cat <- model_profile(cph_exp, + output_type = "survival", + variable_type = "categorical", + grid_points = 6, + type = type) + plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") + + expect_s3_class(mp_cph_cat, "model_profile_survival") + expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) + expect_equal(ncol(mp_cph_cat$result), 7) + expect_true(all(unique(mp_cph_cat$result$`_vname_`) %in% colnames(cph_exp$data))) + + + mp_cph_num <- model_profile(cph_exp, + output_type = "survival", + variable_type = "numerical", + grid_points = 6, + type = type) + plot(mp_cph_num, variable_type = "numerical") + plot(mp_cph_num, numerical_plot_type = "contours") + + expect_s3_class(mp_cph_num, "model_profile_survival") + expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) + expect_equal(ncol(mp_cph_num$result), 7) + expect_true(all(unique(mp_cph_num$result$`_vname_`) %in% colnames(cph_exp$data))) + + expect_output(print(mp_cph_num)) + expect_error(plot(mp_rsf_num, variables = "nonexistent", grid_points = 6)) +}) + test_that("default DALEX::model_profile is ok", { veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] From 341c7ef2a3aecdbf699ca76ae6227b9326081a8c Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 12:49:52 +0200 Subject: [PATCH 038/207] add docs --- man/model_profile.surv_explainer.Rd | 4 ++-- man/plot.model_profile_survival.Rd | 2 +- man/surv_aggregate_profiles.Rd | 7 +------ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/man/model_profile.surv_explainer.Rd b/man/model_profile.surv_explainer.Rd index 
f9750ff6..41f1a487 100644 --- a/man/model_profile.surv_explainer.Rd +++ b/man/model_profile.surv_explainer.Rd @@ -12,7 +12,7 @@ model_profile( ..., groups = NULL, k = NULL, - center = TRUE, + center = FALSE, type = "partial", output_type = "survival" ) @@ -46,7 +46,7 @@ model_profile( \item{center}{logical, should profiles be centered before clustering} -\item{type}{the type of variable profile. If \code{output_type == "survival"} then only \code{"partial"} is implemented, otherwise passed to \code{DALEX::model_profile}.} +\item{type}{the type of variable profile, \code{"partial"} for Partial Dependence, \code{"accumulated"} for Accumulated Local Effects, or \code{"conditional"} (available only for \code{output_type == "risk"})} \item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the \code{DALEX::model_profile} function.} diff --git a/man/plot.model_profile_survival.Rd b/man/plot.model_profile_survival.Rd index eb575806..338f3cff 100644 --- a/man/plot.model_profile_survival.Rd +++ b/man/plot.model_profile_survival.Rd @@ -11,7 +11,7 @@ variable_type = NULL, facet_ncol = NULL, numerical_plot_type = "lines", - title = "Partial dependence survival profile", + title = "default", subtitle = "default", colors = NULL, rug = "all", diff --git a/man/surv_aggregate_profiles.Rd b/man/surv_aggregate_profiles.Rd index 6f7b9032..89570cf3 100644 --- a/man/surv_aggregate_profiles.Rd +++ b/man/surv_aggregate_profiles.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/surv_aggregate_profiles.R +% Please edit documentation in R/surv_model_profiles.R \name{surv_aggregate_profiles} \alias{surv_aggregate_profiles} \title{Helper functions for \code{model_profile.R}} @@ -9,7 +9,6 @@ surv_aggregate_profiles( ..., variable_type = NULL, 
groups = NULL, - type = "partial", variables = NULL, center = FALSE ) @@ -21,10 +20,6 @@ surv_aggregate_profiles( \item{variable_type}{character, either \code{"numerical"} or \code{"categorical"}, the type of variable to be calculated, if left \code{NULL} (default), both are calculated} -\item{groups}{unused, left for compatibility} - -\item{type}{character, only \code{"partial"} is implemented} - \item{variables}{a character vector containing names of variables to be explained} \item{center}{logical, if the profiles should be centered before aggregations} From c0f41029dce7816f53908b21b039bf35387e2959 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 13:09:21 +0200 Subject: [PATCH 039/207] bump roxygen version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 93e5581c..0be37ae5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,7 +18,7 @@ Description: Survival analysis models are commonly used in medicine and other ar License: GPL (>= 3) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.1 +RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) Imports: DALEX (>= 2.2.1), From a227bbb4619a88f5871162dd1c4336d9c6c8599c Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 13:09:32 +0200 Subject: [PATCH 040/207] add imports from stats --- NAMESPACE | 4 ++++ R/surv_model_profiles.R | 2 ++ R/utils.R | 1 + 3 files changed, 7 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index d055c56b..2dca1757 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -78,6 +78,9 @@ importFrom(DALEX,theme_ema) importFrom(DALEX,theme_ema_vertical) importFrom(stats,aggregate) importFrom(stats,as.formula) +importFrom(stats,ave) +importFrom(stats,cmdscale) +importFrom(stats,ecdf) importFrom(stats,median) importFrom(stats,model.frame) importFrom(stats,model.matrix) @@ -88,6 +91,7 @@ importFrom(stats,quantile) importFrom(stats,reorder) importFrom(stats,rnorm) importFrom(stats,stepfun) +importFrom(stats,xtabs) 
importFrom(utils,head) importFrom(utils,stack) importFrom(utils,tail) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index 810f03d1..883503f1 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -124,6 +124,8 @@ surv_aggregate_profiles_partial <- function(all_profiles) { #' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default `51`. #' @param center logical, if the profiles should be centered before aggregations #' +#' @importFrom stats ave +#' #' @return A data.frame with calculated results. #' #' @keywords internal diff --git a/R/utils.R b/R/utils.R index c6b203c7..a4e1f8bc 100644 --- a/R/utils.R +++ b/R/utils.R @@ -187,6 +187,7 @@ add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ # based on iml::order_levels +#' @importFrom stats ecdf xtabs cmdscale #' @keywords internal order_levels <- function(data, variable) { data[, variable] <- droplevels(data[, variable]) From b329afa878bb6956947834d38c217fa1e32b7f9b Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 13:45:57 +0200 Subject: [PATCH 041/207] fix centering --- R/surv_model_profiles.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index 883503f1..c8298990 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -368,7 +368,7 @@ surv_ale <- function(x, by.y = "id") ale_values <- ale_values[order(ale_values$interval, ale_values$time),] - ale_values$ale <- ale_values$ale + mean_pred + if (!center){ ale_values$ale <- ale_values$ale + mean_pred } From e1cb3015aad2ed6d643d85fe05b8c0a8c42a0d6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 31 Jul 2023 14:24:22 +0200 Subject: [PATCH 042/207] Fix ids column in pdp explanation results --- R/surv_ceteris_paribus.R | 19 ++++++++----------- 1 file changed, 8 insertions(+), 
11 deletions(-) diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index 77e93cd1..f7db18a1 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -169,18 +169,15 @@ calculate_variable_survival_profile <- function(data, variable_splits, model, pr calculate_variable_survival_profile.default <- function(data, variable_splits, model, predict_survival_function = NULL, times = NULL, ...) { variables <- names(variable_splits) prog <- progressr::progressor(along = 1:(length(variables))) - profiles <- lapply(variables, function(variable) { - split_points <- variable_splits[[variable]] - - if (is.null(rownames(data))) { - ids <- rep(1:nrow(data), each = length(times)) # it never goes here, because null rownames are automatically setted to 1:n - } else { - ids <- rep(rownames(data), each = length(times)) - } - - print(head(data)) + if (is.null(rownames(data))) { + ids <- 1:nrow(data) # it never goes here, because null rownames are automatically set to 1:n + } else { + ids <- rownames(data) + } + profiles <- lapply(variables, function(variable) { + split_points <- variable_splits[[variable]] new_data <- data[rep(1:nrow(data), each = length(split_points)), , drop = FALSE] new_data[, variable] <- rep(split_points, nrow(data)) @@ -191,7 +188,7 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m `_times_` = rep(times, times = nrow(new_data)), `_yhat_` = yhat, `_vname_` = variable, - `_ids_` = ids, + `_ids_` = rep(ids, each = length(times) * length(split_points)), check.names = FALSE ) # print(table(ids)) From 10d2101b72a3eb76b11b331a60e137d0baee8f9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 31 Jul 2023 14:24:40 +0200 Subject: [PATCH 043/207] Working implementation of other pdp/ale plots.
--- R/plot_model_profile_survival.R | 107 ++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 27 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 62f235b4..30b0d433 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -131,19 +131,35 @@ plot2.model_profile_survival <- function(x, subtitle <- paste0("created for the ", unique(x$result$`_label_`), " model") } - # Select relevant information from the pdp result - pdp_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c("_x_", "_yhat_")] - colnames(pdp_df) <- c(variable, "pd") + single_timepoint <- ((length(times) == 1) || marginalize_over_time) + is_categorical <- (unique(x$result[x$result$`_vname_` == variable, "_vtype_"]) == "categorical") - # Select relevant information from the ceteris paribus profiles - # TODO: REMOVE THIS WHEN ID IS FIXED - # for (i in 1:length(x$cp_profiles)) { - # return(unique(x$cp_profiles$result$`_vname_`)) - # } + if (single_timepoint) { + pdp_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c("_x_", "_yhat_")] + colnames(pdp_df) <- c(variable, "pd") + } else { + pdp_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c("_x_", "_times_", "_yhat_")] + colnames(pdp_df) <- c(variable, "time", "pd") + pdp_df$time <- as.factor(pdp_df$time) + } ice_df <- x$cp_profiles$result[(x$cp_profiles$result$`_vname_` == variable) & (x$cp_profiles$result$`_times_` %in% times), ] - ice_df$`_times_` <- NULL + + if (single_timepoint) { + ice_df$`_times_` <- NULL + } else { + colnames(ice_df)[colnames(ice_df) == "_times_"] <- "time" + ice_df$time <- as.factor(ice_df$time) + } + + + if (is_categorical) { + pdp_df[, variable] <- as.factor(pdp_df[, variable]) + ice_df[, variable] <- as.factor(ice_df[, variable]) + } + + ice_df$`_vname_` <- NULL ice_df$`_vtype_` <- NULL ice_df$`_label_` <- NULL @@ -151,23 +167,25 @@ 
plot2.model_profile_survival <- function(x, colnames(ice_df)[colnames(ice_df) == "_ids_"] <- "id" colnames(ice_df)[colnames(ice_df) == "_yhat_"] <- "predictions" - feature_name_sym <- sym(variable) + + data_df <- x$cp_profiles$variable_values + # print(ice_df) + y_floor_pd <- floor(min(pdp_df[, "pd"]) * 10) / 10 y_ceiling_pd <- ceiling(max(pdp_df[, "pd"]) * 10) / 10 + y_floor_ice <- floor(min(ice_df[, "predictions"]) * 10) / 10 + y_ceiling_ice <- ceiling(max(ice_df[, "predictions"]) * 10) / 10 - single_timepoint <- ((length(times) == 1) || marginalize_over_time) - return(ice_df) - - if (unique(x$result[x$result$`_vname_` == variable, "_vtype_"]) == "categorical") { + if (is_categorical) { pl <- plot_pdp_cat( pdp_dt = pdp_df, ice_dt = ice_df, - data_dt = NULL, - feature_name_sym, - y_floor_ice = NULL, - y_ceiling_ice = NULL, + data_dt = data_df, + feature_name_count_sym = feature_name_sym, + y_floor_ice = y_floor_ice, + y_ceiling_ice = y_ceiling_ice, y_floor_pd = y_floor_pd, y_ceiling_pd = y_ceiling_pd, plot_type = plot_type, @@ -178,10 +196,10 @@ plot2.model_profile_survival <- function(x, pl <- plot_pdp_num( pdp_dt = pdp_df, ice_dt = ice_df, - data_dt = NULL, - feature_name_sym, - y_floor_ice = NULL, - y_ceiling_ice = NULL, + data_dt = data_df, + feature_name_sym = feature_name_sym, + y_floor_ice = y_floor_ice, + y_ceiling_ice = y_ceiling_ice, y_floor_pd = y_floor_pd, y_ceiling_pd = y_ceiling_pd, plot_type = plot_type, @@ -205,14 +223,16 @@ plot_pdp_num <- function(pdp_dt, if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(alpha = 0.2, mapping = aes(group = id)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_ice, y_ceiling_ice) } # PDP + ICE else if (plot_type == "pdp+ice") { ggplot(data = ice_dt, 
aes(x = !!feature_name_sym, y = predictions)) + geom_line(mapping = aes(group = id), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = "gold") #+ - # geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = "gold") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_ice, y_ceiling_ice) } # PDP else if (plot_type == "pdp") { @@ -247,8 +267,41 @@ plot_pdp_num <- function(pdp_dt, } } -plot_pdp_cat <- function() { - +plot_pdp_cat <- function(pdp_dt, + ice_dt, + data_dt, + feature_name_count_sym, + y_floor_ice, + y_ceiling_ice, + y_floor_pd, + y_ceiling_pd, + plot_type, + single_timepoint) { + if (single_timepoint == TRUE) { ## single timepoint + if (plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + + geom_boxplot(alpha = 0.2) + } else if (plot_type == "pdp+ice") { + ggplot() + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = "gold") + } else if (plot_type == "pdp") { + ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd), ) + + geom_bar(stat = "identity", width = 0.5) + } + } else { + if (plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + + geom_boxplot(alpha = 0.2, mapping = aes(color = time)) + } else if (plot_type == "pdp+ice") { + ggplot(mapping = aes(color = time)) + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 0.6) + } else if (plot_type == "pdp") { + ggplot(data = pdp_dt, 
aes(x = !!feature_name_count_sym, y = pd, fill = time)) + + geom_bar(stat = "identity", width = 0.5, position = "dodge") + } + } } From f9d23c902a9de914588bc5fd375ab01ad389a62b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 31 Jul 2023 15:32:47 +0200 Subject: [PATCH 044/207] Add documentation and theming of plots --- R/plot_model_profile_survival.R | 85 ++++++++++++++++++++++++----- man/plot2.model_profile_survival.Rd | 62 +++++++++++++++++++++ 2 files changed, 134 insertions(+), 13 deletions(-) create mode 100644 man/plot2.model_profile_survival.Rd diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 30b0d433..262df349 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -93,6 +93,40 @@ plot.model_profile_survival <- function(x, #' @export plot2 <- function(x, ...) UseMethod("plot2") +#' Plot Model Profile for Survival Models (without continuous time aspect) +#' +#' This function plots objects of class `"model_profile_survival"` created +#' using the `model_profile()` function. +#' +#' @param x an object of class `model_profile_survival` to be plotted +#' @param variable character, name of a single variable to be plotted +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median time from the explainer object is used. +#' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately +#' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"` selects the type of plot to be drawn +#' @param ... other parameters. Currently ignored. 
+#' @param title character, title of the plot +#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX variable", where XXX is the name of the plotted variable +#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") +#' +#' @return A `ggplot` object. +#' +#' @examples +#' \donttest{ +#' library(survival) +#' library(survex) +#' +#' model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) +#' exp <- explain(model) +#' +#' m_prof <- model_profile(exp, categorical_variables = "trt") +#' +#' plot2(m_prof, variable = "karno", plot_type = "pdp+ice") +#' +#' plot2(m_prof, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") +#' +#' plot2(m_prof, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") +#' } +#' #' @export plot2.model_profile_survival <- function(x, variable, @@ -128,7 +162,7 @@ plot2.model_profile_survival <- function(x, } if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0("created for the ", unique(x$result$`_label_`), " model") + subtitle <- paste0("created for the ", unique(variable), " variable") } single_timepoint <- ((length(times) == 1) || marginalize_over_time) @@ -170,13 +204,19 @@ plot2.model_profile_survival <- function(x, feature_name_sym <- sym(variable) data_df <- x$cp_profiles$variable_values - # print(ice_df) y_floor_pd <- floor(min(pdp_df[, "pd"]) * 10) / 10 y_ceiling_pd <- ceiling(max(pdp_df[, "pd"]) * 10) / 10 y_floor_ice <- floor(min(ice_df[, "predictions"]) * 10) / 10 y_ceiling_ice <- ceiling(max(ice_df[, "predictions"]) * 10) / 10 + if (marginalize_over_time) { + color_scale <- generate_discrete_color_scale(1, colors) + } else { + color_scale <- generate_discrete_color_scale(length(times), colors) + } + + if (is_categorical) { pl <- plot_pdp_cat( @@ -189,7 +229,8 @@ plot2.model_profile_survival <- function(x, y_floor_pd =
y_floor_pd, y_ceiling_pd = y_ceiling_pd, plot_type = plot_type, - single_timepoint = single_timepoint + single_timepoint = single_timepoint, + colors = color_scale ) } else { pdp_df[, 1] <- as.numeric(as.character(pdp_df[, 1])) @@ -203,10 +244,17 @@ plot2.model_profile_survival <- function(x, y_floor_pd = y_floor_pd, y_ceiling_pd = y_ceiling_pd, plot_type = plot_type, - single_timepoint = single_timepoint + single_timepoint = single_timepoint, + colors = color_scale ) } - pl + + pl + + labs( + title = title, + subtitle = subtitle + ) + + theme_default_survex() } plot_pdp_num <- function(pdp_dt, @@ -218,7 +266,8 @@ plot_pdp_num <- function(pdp_dt, y_floor_pd, y_ceiling_pd, plot_type, - single_timepoint) { + single_timepoint, + colors) { if (single_timepoint == TRUE) { ## single timepoint if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + @@ -246,6 +295,7 @@ plot_pdp_num <- function(pdp_dt, ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(alpha = 0.2, mapping = aes(group = interaction(id, time), color = time)) + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + scale_color_manual(name = "time", values = colors) + ylim(y_floor_ice, y_ceiling_ice) } # PDP + ICE @@ -255,6 +305,7 @@ plot_pdp_num <- function(pdp_dt, geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time), linewidth = 1.5, lineend = "round", linejoin = "round") + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, group = time), color = "black", linewidth = 0.5, linetype = "dashed", lineend = "round", linejoin = "round") + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + scale_color_manual(name = "time", values = colors) + ylim(y_floor_ice, y_ceiling_ice) } # PDP @@ -262,6 +313,7 @@ plot_pdp_num <- function(pdp_dt, ggplot(data = pdp_dt, aes(x = 
!!feature_name_sym, y = pd)) + geom_line(aes(color = time)) + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + scale_color_manual(name = "time", values = colors) + ylim(y_floor_pd, y_ceiling_pd) } } @@ -276,30 +328,37 @@ plot_pdp_cat <- function(pdp_dt, y_floor_pd, y_ceiling_pd, plot_type, - single_timepoint) { + single_timepoint, + colors) { if (single_timepoint == TRUE) { ## single timepoint if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + - geom_boxplot(alpha = 0.2) + geom_boxplot(alpha = 0.2) + + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp+ice") { ggplot() + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = "gold") + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = "gold") + + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp") { ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd), ) + - geom_bar(stat = "identity", width = 0.5) + geom_bar(stat = "identity", width = 0.5) + + scale_fill_manual(name = "time", values = colors) } } else { if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + - geom_boxplot(alpha = 0.2, mapping = aes(color = time)) + geom_boxplot(alpha = 0.2, mapping = aes(color = time)) + + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp+ice") { ggplot(mapping = aes(color = time)) + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 0.6) + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 
0.6) + + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp") { ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, fill = time)) + - geom_bar(stat = "identity", width = 0.5, position = "dodge") + geom_bar(stat = "identity", width = 0.5, position = "dodge") + + scale_fill_manual(name = "time", values = colors) } } } diff --git a/man/plot2.model_profile_survival.Rd b/man/plot2.model_profile_survival.Rd new file mode 100644 index 00000000..939e5309 --- /dev/null +++ b/man/plot2.model_profile_survival.Rd @@ -0,0 +1,62 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot_model_profile_survival.R +\name{plot2.model_profile_survival} +\alias{plot2.model_profile_survival} +\title{Plot Model Profile for Survival Models (without continuous time aspect)} +\usage{ +\method{plot2}{model_profile_survival}( + x, + variable, + times = NULL, + marginalize_over_time = FALSE, + plot_type = "pdp+ice", + ..., + title = "Partial dependence profile", + subtitle = "default", + colors = NULL +) +} +\arguments{ +\item{x}{an object of class \code{model_profile_survival} to be plotted} + +\item{variable}{character, name of a single variable to be plotted} + +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used.} + +\item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately} + +\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"} selects the type of plot to be drawn} + +\item{...}{other parameters. 
Currently ignored.} + +\item{title}{character, title of the plot} + +\item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} + +\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} +} +\value{ +A \code{ggplot} object. +} +\description{ +This function plots objects of class \code{"model_profile_survival"} created +using the \code{model_profile()} function. +} +\examples{ +\donttest{ +library(survival) +library(survex) + +model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) +exp <- explain(model) + +m_prof <- model_profile(exp, categorical_variables = "trt") + +plot2(m_prof_cph, variable = "karno", plot_type = "pdp+ice") + +plot2(m_prof_cph, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") + +plot2(m_prof_cph, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") +} + +} From 994210ae81b026eb9666405d8895aa4b6327affb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 31 Jul 2023 15:39:05 +0200 Subject: [PATCH 045/207] Add plot type check --- R/plot_model_profile_survival.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 262df349..f82be00b 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -137,6 +137,10 @@ plot2.model_profile_survival <- function(x, title = "Partial dependence profile", subtitle = "default", colors = NULL) { + if (!plot_type %in% c("pdp", "ice", "pdp+ice")) { + stop("plot_type must be one of 'pdp', 'ice', 'pdp+ice'") + } + if (is.null(variable) || !is.character(variable)) { stop("A variable must be specified by name") } From 5dd663dd4964eab4f8884497cc777b52bb5bb6cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 31 Jul 2023 15:54:14 +0200 Subject: 
[PATCH 046/207] Add Sophie to authors --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index 0be37ae5..e50477b3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -5,6 +5,7 @@ Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")), + person("Sophie", "Langbein", role = c("aut")), person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")), person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823")) ) From 5a135a0bf7ba600ce1e2a95b549aac3796af4be7 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 16:05:53 +0200 Subject: [PATCH 047/207] remove unused file --- R/zzz.R | 1 - 1 file changed, 1 deletion(-) delete mode 100644 R/zzz.R diff --git a/R/zzz.R b/R/zzz.R deleted file mode 100644 index 7951defe..00000000 --- a/R/zzz.R +++ /dev/null @@ -1 +0,0 @@ -NULL From b8043a6caf7463dd9c07addec6cf3c55e9314a81 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 16:54:00 +0200 Subject: [PATCH 048/207] add model_profile_2d pdp version --- R/model_profile_2d.R | 124 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 R/model_profile_2d.R diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R new file mode 100644 index 00000000..f1e2bbeb --- /dev/null +++ b/R/model_profile_2d.R @@ -0,0 +1,124 @@ +#' @return An object of class `model_profile_2d_survival`. It is a list with the element `result` containing the results of the calculation. 
+#' +#' @export +model_profile_2d <- function(explainer, + variables = NULL, + N = 100, + ..., + categorical_variables = NULL, + grid_points = 51, + variable_splits_type = "uniform", + type = "partial", + output_type = "survival") + UseMethod("model_profile_2d", explainer) + +#' @rdname model_profile.surv_explainer +#' @export +model_profile_2d.surv_explainer <- function(explainer, + variables = NULL, + N = 100, + ..., + categorical_variables = NULL, + grid_points = 51, + variable_splits_type = "uniform", + type = "partial", + output_type = "survival" + ) { + variables <- unique(variables, categorical_variables) + if (output_type != "survival") { + stop("Currently only `survival` output type is implemented") + } + test_explainer(explainer, + "model_profile", + has_data = TRUE, + has_survival = TRUE) + + data <- explainer$data + if (!is.null(N) && N < nrow(data)) { + ndata <- data[sample(1:nrow(data), N), ] + } else { + ndata <- data + } + + if (type == "partial") { + result <- surv_pdp_2d( + explainer, + data = ndata, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + variable_splits_type = variable_splits_type, + ... + ) + } else if (type == "accumulated") { + result <- surv_ale_2d( + explainer, + data = ndata, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + ... + ) + } else { + stop("Currently only `partial` and `accumulated` types are implemented") + } + + ret <- list( + eval_times = unique(result$`_times_`), + result = result, + type = type + ) + class(ret) <- c("model_profile_2d_survival", "list") + ret$event_times <- + explainer$y[explainer$y[, 1] <= max(explainer$times), 1] + ret$event_statuses <- + explainer$y[explainer$y[, 1] <= max(explainer$times), 2] + ret + +} + + +surv_pdp_2d <- function(explainer, + data, + variables, + categorical_variables, + grid_points, + variable_splits_type, + ...) 
{ + data <- x$data + model <- x$model + label <- x$label + predict_survival_function <- x$predict_survival_function + times <- x$times + + # change categorical_features to column names + if (is.numeric(categorical_variables)) categorical_variables <- colnames(data)[categorical_variables] + additional_categorical_variables <- categorical_variables + factor_variables <- colnames(data)[sapply(data, is.factor)] + categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) + + if (is.null(variables) | !is.list(variables) | !all(sapply(variables, length) == 2)) + stop("'variables' must be specified as a list of pairs (two-element vectors)") + + unique_variables <- unlist(variables) + variable_splits <- calculate_variable_split(data, + variables = unique_variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + variable_splits_type = variable_splits_type) + + lapply(variables, FUN = function(variables_pair){ + var1 <- variables_pair[1] + var2 <- variables_pair[2] + expanded_data <- merge(variable_splits[[var1]], data[,!colnames(data) %in% variables_pair]) + names(expanded_data)[colnames(expanded_data) == "x"] <- var1 + expanded_data <- merge(variable_splits[[var2]], expanded_data) + names(expanded_data)[colnames(expanded_data) == "x"] <- var2 + + predictions <- predict_function(model = model, + newdata = expanded_data, + times = times) + }) + +} + From a50105bb0275202a9ad9a4a2e4cedcd73799c711 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 31 Jul 2023 16:54:19 +0200 Subject: [PATCH 049/207] add calculate_variable_split as internal function --- R/surv_ceteris_paribus.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index 66707f90..34bc9ec8 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -124,6 +124,7 @@ calculate_variable_split <- function(data, variables = colnames(data), categoric } #' @importFrom stats na.omit quantile +#' 
@keywords internal calculate_variable_split.default <- function(data, variables = colnames(data), categorical_variables = NULL, grid_points = 101, variable_splits_type = "quantiles", new_observation = NA) { variable_splits <- lapply(variables, function(var) { selected_column <- na.omit(data[, var]) From 3d419061b43440956d55b9d954b46eca2d6cff70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 1 Aug 2023 13:01:56 +0200 Subject: [PATCH 050/207] Fix R CMD CHECK notes --- R/plot_model_profile_survival.R | 169 ++++++++++++++-------------- man/plot2.model_profile_survival.Rd | 15 ++- 2 files changed, 96 insertions(+), 88 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index f82be00b..2ac7c1f0 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -89,7 +89,7 @@ plot.model_profile_survival <- function(x, return(return_plot) } - +#' @rdname plot2.model_profile_survival #' @export plot2 <- function(x, ...) UseMethod("plot2") @@ -110,6 +110,7 @@ plot2 <- function(x, ...) UseMethod("plot2") #' #' @return A `ggplot` object. #' +#' @rdname plot2.model_profile_survival #' @examples #' \donttest{ #' library(survival) @@ -120,11 +121,11 @@ plot2 <- function(x, ...) 
UseMethod("plot2") #' #' m_prof <- model_profile(exp, categorical_variables = "trt") #' -#' plot2(m_prof_cph, variable = "karno", plot_type = "pdp+ice") +#' plot2(m_prof, variable = "karno", plot_type = "pdp+ice") #' -#' plot2(m_prof_cph, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") +#' plot2(m_prof, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") #' -#' plot2(m_prof_cph, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") +#' plot2(m_prof, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") #' } #' #' @export @@ -272,55 +273,57 @@ plot_pdp_num <- function(pdp_dt, plot_type, single_timepoint, colors) { - if (single_timepoint == TRUE) { ## single timepoint - if (plot_type == "ice") { - ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + - geom_line(alpha = 0.2, mapping = aes(group = id)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + - ylim(y_floor_ice, y_ceiling_ice) - } - # PDP + ICE - else if (plot_type == "pdp+ice") { - ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + - geom_line(mapping = aes(group = id), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = "gold") + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + - ylim(y_floor_ice, y_ceiling_ice) - } - # PDP - else if (plot_type == "pdp") { - ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + - geom_line() + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + - ylim(y_floor_pd, y_ceiling_pd) - } - } else { ## multiple timepoints - if (plot_type == "ice") { - ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + - geom_line(alpha = 0.2, mapping = aes(group = interaction(id, time), color = time)) + - geom_rug(data = 
data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + - scale_color_manual(name = "time", values = colors) + - ylim(y_floor_ice, y_ceiling_ice) - } - # PDP + ICE - else if (plot_type == "pdp+ice") { - ggplot() + - geom_line(data = ice_dt, aes(x = !!feature_name_sym, y = predictions, group = interaction(id, time), color = time), alpha = 0.1) + - geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time), linewidth = 1.5, lineend = "round", linejoin = "round") + - geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, group = time), color = "black", linewidth = 0.5, linetype = "dashed", lineend = "round", linejoin = "round") + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + - scale_color_manual(name = "time", values = colors) + - ylim(y_floor_ice, y_ceiling_ice) + with(pdp_dt, { # to get rid of Note: no visible binding for global variable ... 
+ if (single_timepoint == TRUE) { ## single timepoint + if (plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + + geom_line(alpha = 0.2, mapping = aes(group = id)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_ice, y_ceiling_ice) + } + # PDP + ICE + else if (plot_type == "pdp+ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + + geom_line(mapping = aes(group = id), alpha = 0.2) + + geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = "gold") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_ice, y_ceiling_ice) + } + # PDP + else if (plot_type == "pdp") { + ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + + geom_line() + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + ylim(y_floor_pd, y_ceiling_pd) + } + } else { ## multiple timepoints + if (plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + + geom_line(alpha = 0.2, mapping = aes(group = interaction(id, time), color = time)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + scale_color_manual(name = "time", values = colors) + + ylim(y_floor_ice, y_ceiling_ice) + } + # PDP + ICE + else if (plot_type == "pdp+ice") { + ggplot() + + geom_line(data = ice_dt, aes(x = !!feature_name_sym, y = predictions, group = interaction(id, time), color = time), alpha = 0.1) + + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time), linewidth = 1.5, lineend = "round", linejoin = "round") + + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, group = time), color = "black", linewidth = 0.5, linetype = "dashed", lineend 
= "round", linejoin = "round") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + scale_color_manual(name = "time", values = colors) + + ylim(y_floor_ice, y_ceiling_ice) + } + # PDP + else if (plot_type == "pdp") { + ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + + geom_line(aes(color = time)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + scale_color_manual(name = "time", values = colors) + + ylim(y_floor_pd, y_ceiling_pd) + } } - # PDP - else if (plot_type == "pdp") { - ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + - geom_line(aes(color = time)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + - scale_color_manual(name = "time", values = colors) + - ylim(y_floor_pd, y_ceiling_pd) - } - } + }) } plot_pdp_cat <- function(pdp_dt, @@ -334,37 +337,39 @@ plot_pdp_cat <- function(pdp_dt, plot_type, single_timepoint, colors) { - if (single_timepoint == TRUE) { ## single timepoint - if (plot_type == "ice") { - ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + - geom_boxplot(alpha = 0.2) + - scale_color_manual(name = "time", values = colors) - } else if (plot_type == "pdp+ice") { - ggplot() + - geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = "gold") + - scale_color_manual(name = "time", values = colors) - } else if (plot_type == "pdp") { - ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd), ) + - geom_bar(stat = "identity", width = 0.5) + - scale_fill_manual(name = "time", values = colors) + with(pdp_dt, { # to get rid of Note: no visible binding for global variable ... 
+ if (single_timepoint == TRUE) { ## single timepoint + if (plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + + geom_boxplot(alpha = 0.2) + + scale_color_manual(name = "time", values = colors) + } else if (plot_type == "pdp+ice") { + ggplot() + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = "gold") + + scale_color_manual(name = "time", values = colors) + } else if (plot_type == "pdp") { + ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd), ) + + geom_bar(stat = "identity", width = 0.5) + + scale_fill_manual(name = "time", values = colors) + } + } else { + if (plot_type == "ice") { + ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + + geom_boxplot(alpha = 0.2, mapping = aes(color = time)) + + scale_color_manual(name = "time", values = colors) + } else if (plot_type == "pdp+ice") { + ggplot(mapping = aes(color = time)) + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 0.6) + + scale_color_manual(name = "time", values = colors) + } else if (plot_type == "pdp") { + ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, fill = time)) + + geom_bar(stat = "identity", width = 0.5, position = "dodge") + + scale_fill_manual(name = "time", values = colors) + } } - } else { - if (plot_type == "ice") { - ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + - geom_boxplot(alpha = 0.2, mapping = aes(color = time)) + - scale_color_manual(name = "time", values = colors) - } else if (plot_type == "pdp+ice") { - ggplot(mapping = aes(color = time)) + - geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + - geom_line(data = pdp_dt, 
aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 0.6) + - scale_color_manual(name = "time", values = colors) - } else if (plot_type == "pdp") { - ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, fill = time)) + - geom_bar(stat = "identity", width = 0.5, position = "dodge") + - scale_fill_manual(name = "time", values = colors) - } - } + }) } diff --git a/man/plot2.model_profile_survival.Rd b/man/plot2.model_profile_survival.Rd index 939e5309..9d345e45 100644 --- a/man/plot2.model_profile_survival.Rd +++ b/man/plot2.model_profile_survival.Rd @@ -1,9 +1,12 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/plot_model_profile_survival.R -\name{plot2.model_profile_survival} +\name{plot2} +\alias{plot2} \alias{plot2.model_profile_survival} \title{Plot Model Profile for Survival Models (without continuous time aspect)} \usage{ +plot2(x, ...) + \method{plot2}{model_profile_survival}( x, variable, @@ -19,6 +22,8 @@ \arguments{ \item{x}{an object of class \code{model_profile_survival} to be plotted} +\item{...}{other parameters. Currently ignored.} + \item{variable}{character, name of a single variable to be plotted} \item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used.} @@ -27,8 +32,6 @@ \item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"} selects the type of plot to be drawn} -\item{...}{other parameters. 
Currently ignored.} - \item{title}{character, title of the plot} \item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} @@ -52,11 +55,11 @@ exp <- explain(model) m_prof <- model_profile(exp, categorical_variables = "trt") -plot2(m_prof_cph, variable = "karno", plot_type = "pdp+ice") +plot2(m_prof, variable = "karno", plot_type = "pdp+ice") -plot2(m_prof_cph, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") +plot2(m_prof, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") -plot2(m_prof_cph, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") +plot2(m_prof, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") } } From a2144a6bcae6f0e6f52be6ecc81add09a807e38e Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 1 Aug 2023 13:18:37 +0200 Subject: [PATCH 051/207] 2d profiles calculation --- R/model_profile_2d.R | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index f1e2bbeb..071aa35a 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -12,7 +12,6 @@ model_profile_2d <- function(explainer, output_type = "survival") UseMethod("model_profile_2d", explainer) -#' @rdname model_profile.surv_explainer #' @export model_profile_2d.surv_explainer <- function(explainer, variables = NULL, @@ -64,8 +63,9 @@ model_profile_2d.surv_explainer <- function(explainer, } ret <- list( - eval_times = unique(result$`_times_`), result = result, + eval_times = unique(result$`_times_`), + variables = variables, type = type ) class(ret) <- c("model_profile_2d_survival", "list") @@ -73,12 +73,11 @@ model_profile_2d.surv_explainer <- function(explainer, explainer$y[explainer$y[, 1] <= max(explainer$times), 1] ret$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] - ret - + return(ret) } - -surv_pdp_2d <- 
function(explainer, +#' @keywords internal +surv_pdp_2d <- function(x, data, variables, categorical_variables, @@ -104,21 +103,37 @@ surv_pdp_2d <- function(explainer, variable_splits <- calculate_variable_split(data, variables = unique_variables, categorical_variables = categorical_variables, - grid_points = grid_points, - variable_splits_type = variable_splits_type) + grid_points = grid_points, + variable_splits_type = variable_splits_type) - lapply(variables, FUN = function(variables_pair){ + profiles <- lapply(variables, FUN = function(variables_pair){ var1 <- variables_pair[1] var2 <- variables_pair[2] expanded_data <- merge(variable_splits[[var1]], data[,!colnames(data) %in% variables_pair]) names(expanded_data)[colnames(expanded_data) == "x"] <- var1 expanded_data <- merge(variable_splits[[var2]], expanded_data) names(expanded_data)[colnames(expanded_data) == "x"] <- var2 + expanded_data <- expanded_data[,colnames(data)] - predictions <- predict_function(model = model, + predictions <- predict_survival_function(model = model, newdata = expanded_data, times = times) - }) + res <- data.frame( + "_v1name_" = var1, + "_v2name_" = var2, + "_v1type_" = ifelse(var1 %in% categorical_variables, "categorical", "numerical"), + "_v2type_" = ifelse(var2 %in% categorical_variables, "categorical", "numerical"), + "_v1value_" = rep(expanded_data[,var1], each=length(times)), + "_v2value_" = rep(expanded_data[,var2], each=length(times)), + "_times_" = rep(times, nrow(expanded_data)), + "_yhat_" = c(t(predictions)), + "_label_" = label, + check.names = FALSE + ) + return(aggregate(`_yhat_`~., data = res, FUN=mean)) + }) + profiles <- do.call(rbind, profiles) + profiles } From 88d1b23f0e4f07d6fd70414268db462323bfc67e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 1 Aug 2023 13:28:29 +0200 Subject: [PATCH 052/207] Add tests for plot2 plots for PDP --- tests/testthat/test-model_profile.R | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff 
--git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 1aca8119..32e49920 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -24,6 +24,16 @@ test_that("model_profile works", { plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") + ### Add tests for plot2 for numerical PDP + # single timepoint + plot2(mp_cph_num, variable = "karno", plot_type = "pdp+ice") + plot2(mp_cph_num, variable = "karno", plot_type = "pdp") + plot2(mp_cph_num, variable = "karno", plot_type = "ice") + # multiple timepoints + plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "pdp+ice") + plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "pdp") + plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "ice") + expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) expect_equal(ncol(mp_cph_num$result), 7) @@ -33,6 +43,17 @@ test_that("model_profile works", { mp_rsf_cat <- model_profile(rsf_ranger_exp, output_type = "survival", variable_splits_type = "uniform", variable_type = "categorical", grid_points = 6) plot(mp_rsf_cat, variable_type = "categorical") + ### Add tests for plot2 for categorical PDP + # single timepoint + plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp+ice") + plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp") + plot2(mp_rsf_cat, variable = "celltype", plot_type = "ice") + # multiple timepoints + plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp+ice") + plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp") + plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "ice") + + expect_s3_class(mp_rsf_cat, "model_profile_survival") expect_true(all(mp_rsf_cat$eval_times == cph_exp$times)) expect_equal(ncol(mp_rsf_cat$result), 7) From 
3d0e5ceb920ea990816bb240e7502c52e14d36ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 1 Aug 2023 13:30:21 +0200 Subject: [PATCH 053/207] Get rid of the deprecated size aesthetic for line plots. Add ggplot2 version dependency (>=3.4.0) --- DESCRIPTION | 2 +- R/plot_feature_importance.R | 2 +- R/plot_surv_ceteris_paribus.R | 7 +++---- R/plot_surv_feature_importance.R | 2 +- R/plot_surv_lime.R | 2 +- R/plot_surv_model_performance.R | 2 +- R/plot_surv_model_performance_rocs.R | 6 +++--- R/plot_surv_shap.R | 2 +- 8 files changed, 12 insertions(+), 13 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index e50477b3..b394b836 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,7 +23,7 @@ RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) Imports: DALEX (>= 2.2.1), - ggplot2, + ggplot2 (>= 3.4.0), kernelshap, pec, survival, diff --git a/R/plot_feature_importance.R b/R/plot_feature_importance.R index d74664a7..0cd4c4a1 100644 --- a/R/plot_feature_importance.R +++ b/R/plot_feature_importance.R @@ -111,7 +111,7 @@ plot.feature_importance_explainer <- function(x, ..., max_vars = NULL, show_boxp # plot it pl <- ggplot(ext_expl_df, aes(variable, ymin = dropout_loss.y, ymax = dropout_loss.x, color = label)) + geom_hline(data = bestFits, aes(yintercept = dropout_loss, color = label), lty= 3) + - geom_linerange(size = bar_width) + geom_linerange(linewidth = bar_width) if (show_boxplots) { pl <- pl + diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index 3dfb332b..b28f625a 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -226,7 +226,7 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, midpoint = median(as.numeric(as.character(df$`_x_`))) ) + geom_line(data = df[df$`_real_point_`, ], color = - "red", linewidth = 0.8, size = 0.8) + + "red", linewidth = 0.8) + xlab("") + ylab("survival function value") + ylim(c(0, 1)) + xlim(c(0,NA))+ theme_default_survex() + facet_wrap(~`_vname_`) @@ 
-270,10 +270,9 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, ) ) + geom_line(data = df[!df$`_real_point_`, ], - linewidth = 0.8, - size = 0.8) + + linewidth = 0.8) + geom_line(data = df[df$`_real_point_`, ], - size = 0.8, linewidth = 0.8, linetype = "longdash") + + linewidth = 0.8, linetype = "longdash") + scale_color_manual(name = paste0(unique(df$`_vname_`), " value"), values = generate_discrete_color_scale(n_colors, colors)) + theme_default_survex() + diff --git a/R/plot_surv_feature_importance.R b/R/plot_surv_feature_importance.R index 2c0ac1dd..c0ec5a6b 100644 --- a/R/plot_surv_feature_importance.R +++ b/R/plot_surv_feature_importance.R @@ -86,7 +86,7 @@ plot.surv_feature_importance <- function(x, ..., base_plot <- with(plotting_df, { ggplot(data = plotting_df, aes(x = `_times_`, y = values, color = ind, label = ind)) + - geom_line(linewidth = 0.8, size = 0.8) + + geom_line(linewidth = 0.8) + theme_default_survex() + xlab("") + ylab(y_lab) + diff --git a/R/plot_surv_lime.R b/R/plot_surv_lime.R index db3be280..520be24f 100644 --- a/R/plot_surv_lime.R +++ b/R/plot_surv_lime.R @@ -88,7 +88,7 @@ plot.surv_lime <- function(x, pl2 <- with(sf_df,{ ggplot(data = sf_df, aes(x = times, y = sfs, group = type, color = type)) + - geom_line(linewidth = 0.8, size = 0.8) + + geom_line(linewidth = 0.8) + theme_default_survex() + xlab("") + xlim(c(0,NA))+ diff --git a/R/plot_surv_model_performance.R b/R/plot_surv_model_performance.R index dfffdabe..7ddc6f22 100644 --- a/R/plot_surv_model_performance.R +++ b/R/plot_surv_model_performance.R @@ -71,7 +71,7 @@ plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, base_plot <- with(df,{ ggplot(data = df[df$ind %in% metrics, ], aes(x = times, y = values, group = label, color = label)) + - geom_line(linewidth = 0.8, size = 0.8) + + geom_line(linewidth = 0.8) + theme_default_survex() + xlab("") + ylab("metric value") + diff --git a/R/plot_surv_model_performance_rocs.R 
b/R/plot_surv_model_performance_rocs.R index 36e9f425..6752c3ac 100644 --- a/R/plot_surv_model_performance_rocs.R +++ b/R/plot_surv_model_performance_rocs.R @@ -52,13 +52,13 @@ plot.surv_model_performance_rocs <- function(x, num_colors <- length(unique(df$label)) base_plot <- with(df, {ggplot(data = df, aes(x = FPR, y = TPR, group = label, color = label)) + - geom_line(linewidth = 0.8, size = 0.8) + + geom_line(linewidth = 0.8) + theme_default_survex() + xlab("1 - specificity (FPR)") + ylab("sensitivity (TPR)") + coord_fixed() + - theme(panel.grid.major.x = element_line(color = "grey90", linewidth = 0.5, size = 0.5, linetype = 1), - panel.grid.minor.x = element_line(color = "grey90", linewidth = 0.5, size = 0.5, linetype = 1)) + + theme(panel.grid.major.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1), + panel.grid.minor.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1)) + labs(title = title, subtitle = subtitle) + scale_color_manual("", values = generate_discrete_color_scale(num_colors, colors)) + facet_wrap(~time, ncol = facet_ncol, labeller = function(x) lapply(x, function(x) paste0("t=", x))) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index b418b808..8768d664 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -75,7 +75,7 @@ plot.surv_shap <- function(x, base_plot <- with(long_df, { ggplot(data = long_df, aes(x = times, y = values, color = ind)) + - geom_line(linewidth = 0.8, size = 0.8) + + geom_line(linewidth = 0.8) + ylab(y_lab) + xlab("") + xlim(c(0,NA))+ labs(title = title, subtitle = subtitle) + From a0fa1f36f1ec289765708e3e22437f1700a6b675 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 1 Aug 2023 15:13:05 +0200 Subject: [PATCH 054/207] remove unused event times and statuses --- R/model_profile_2d.R | 4 ---- 1 file changed, 4 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index 071aa35a..801ffe4c 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -69,10 +69,6 @@ 
model_profile_2d.surv_explainer <- function(explainer, type = type ) class(ret) <- c("model_profile_2d_survival", "list") - ret$event_times <- - explainer$y[explainer$y[, 1] <= max(explainer$times), 1] - ret$event_statuses <- - explainer$y[explainer$y[, 1] <= max(explainer$times), 2] return(ret) } From 34b08ef9576fdeeb08688564a3c303e1b43a8497 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 1 Aug 2023 15:13:16 +0200 Subject: [PATCH 055/207] add plot model_profile_2d --- R/plot_model_profile_2d.R | 141 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 R/plot_model_profile_2d.R diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R new file mode 100644 index 00000000..e392c1f7 --- /dev/null +++ b/R/plot_model_profile_2d.R @@ -0,0 +1,141 @@ +#' @export +plot.model_profile_2d_survival <- function(x, + ..., + variables = NULL, + times = NULL, + marginalize_over_time = FALSE, + facet_ncol = NULL, + title = "default", + subtitle = "default", + colors = NULL){ + + explanations_list <- c(list(x), list(...)) + num_models <- length(explanations_list) + if (title == "default"){ + if (x$type == "partial") + title <- "2D partial dependence survival profiles" + if (x$type == "accumulated") + title <- "2D accumulated local effects survival profiles" + } + + all_variables <- x$variables + if (!is.null(variables)) { + all_variables <- intersect(all_variables, variables) + if (length(all_variables) == 0) + stop(paste0( + "variables do not overlap with ", + paste(all_variables, collapse = ", ") + )) + } + + if (is.null(colors)) + colors <- c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3") + + if (num_models == 1){ + result <- prepare_model_profile_2d_plots(x, + variables = variables, + times = times, + marginalize_over_time = marginalize_over_time, + facet_ncol = facet_ncol, + title = title, + subtitle = subtitle, + colors = colors + ) + return(result) + } + + return_list <- list() + labels <- list() + for (i in 
1:num_models){ + this_title <- unique(explanations_list[[i]]$result$`_label_`) + return_list[[i]] <- prepare_model_profile_2d_plots(explanations_list[[i]], + variables = variables, + times = times, + marginalize_over_time = marginalize_over_time, + facet_ncol = 1, + title = this_title, + subtitle = subtitle, + colors = colors) + labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches)-2)) + } + + labels <- unlist(labels) + patchwork::wrap_plots(return_list, nrow = 1, tag_level="keep") + + patchwork::plot_annotation(title, tag_levels = list(labels)) & theme_default_survex() +} + + +prepare_model_profile_2d_plots <- function(x, + variables, + times, + marginalize_over_time, + facet_ncol, + title, + subtitle, + colors +){ + if (is.null(times)) { + times <- quantile(x$eval_times, p = 0.5, type = 1) + } + + if (!marginalize_over_time) { + times <- times[1] + } + + if (!all(times %in% x$eval_times)) { + stop(paste0( + "For one of the provided times the explanations has not been calculated not found. 
+ Please modify the times argument in your explainer or use only values from the following: ", + paste(x$eval_times, collapse = ", ") + )) + } + + all_profiles <- x$result + df_time <- all_profiles[all_profiles$`_times_` %in% times, ] + + sf_range <- range(df_time$`_yhat_`) + + pl <- lapply(seq_along(x$variables), function(i){ + variable_pair <- x$variables[[i]] + df <- df_time[df_time$`_v1name_` == variable_pair[1] & + df_time$`_v2name_` == variable_pair[2],] + if (any(df$`_v1type_` == "numerical")) + df$`_v1value_` <- as.numeric(as.character(df$`_v1value_`)) + if (any(df$`_v2type_` == "numerical")) + df$`_v2value_` <- as.numeric(as.character(df$`_v2value_`)) + xlabel <- unique(df$`_v1name_`) + ylabel <- unique(df$`_v2name_`) + + p <- with(df, { + ggplot(df, + aes(x = `_v1value_`, y = `_v2value_`, fill = `_yhat_`)) + + geom_tile() + + scale_fill_gradientn(name = "SF value", + colors = rev(grDevices::colorRampPalette(colors)(10)), + limits = sf_range) + + labs(x = xlabel, y = ylabel) + + theme(legend.position = "top") + + facet_wrap(~paste(`_v1name_`, `_v2name_`, sep = " : ")) + }) + + if (i != length(x$variables)) + p <- p + guides(fill = "none") + return(p) + }) + if (!is.null(subtitle) && subtitle == "default") { + labels <- + paste0(unique(all_profiles$`_label_`), collapse = ", ") + subtitle <- paste0("created for the ", labels, " model") + if (!marginalize_over_time) + subtitle <- paste0(subtitle, " and t = ", times) + } + + patchwork::wrap_plots(pl, ncol = facet_ncol) & + patchwork::plot_annotation(title = title, + subtitle = subtitle) & theme_default_survex() & + plot_layout(guides = "collect") +} + + + + From 2dbff6f7c4ea28c821efc2b39966a63cef759803 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 1 Aug 2023 15:13:40 +0200 Subject: [PATCH 056/207] add new methods --- NAMESPACE | 3 +++ 1 file changed, 3 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 2dca1757..62f7bc7e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,9 +12,11 @@ 
S3method(model_performance,default) S3method(model_performance,surv_explainer) S3method(model_profile,default) S3method(model_profile,surv_explainer) +S3method(model_profile_2d,surv_explainer) S3method(plot,feature_importance_explainer) S3method(plot,model_parts_survival) S3method(plot,model_performance_survival) +S3method(plot,model_profile_2d_survival) S3method(plot,model_profile_survival) S3method(plot,predict_parts_survival) S3method(plot,predict_profile_survival) @@ -59,6 +61,7 @@ export(loss_one_minus_integrated_cd_auc) export(model_parts) export(model_performance) export(model_profile) +export(model_profile_2d) export(predict_parts) export(predict_profile) export(risk_from_chf) From f91cb288eacf68092bd50d89f9cc4fdc4ef1a797 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 2 Aug 2023 17:33:33 +0200 Subject: [PATCH 057/207] add ale 2d working version --- R/model_profile_2d.R | 162 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 150 insertions(+), 12 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index 801ffe4c..1060ada0 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -23,7 +23,10 @@ model_profile_2d.surv_explainer <- function(explainer, type = "partial", output_type = "survival" ) { - variables <- unique(variables, categorical_variables) + + if (is.null(variables) | !is.list(variables) | !all(sapply(variables, length) == 2)) + stop("'variables' must be specified as a list of pairs (two-element vectors)") + if (output_type != "survival") { stop("Currently only `survival` output type is implemented") } @@ -39,6 +42,12 @@ model_profile_2d.surv_explainer <- function(explainer, ndata <- data } + # change categorical_features to column names + if (is.numeric(categorical_variables)) categorical_variables <- colnames(data)[categorical_variables] + additional_categorical_variables <- categorical_variables + factor_variables <- colnames(data)[sapply(data, is.factor)] + categorical_variables <- 
unique(c(additional_categorical_variables, factor_variables)) + if (type == "partial") { result <- surv_pdp_2d( explainer, @@ -72,7 +81,6 @@ model_profile_2d.surv_explainer <- function(explainer, return(ret) } -#' @keywords internal surv_pdp_2d <- function(x, data, variables, @@ -80,21 +88,11 @@ surv_pdp_2d <- function(x, grid_points, variable_splits_type, ...) { - data <- x$data model <- x$model label <- x$label predict_survival_function <- x$predict_survival_function times <- x$times - # change categorical_features to column names - if (is.numeric(categorical_variables)) categorical_variables <- colnames(data)[categorical_variables] - additional_categorical_variables <- categorical_variables - factor_variables <- colnames(data)[sapply(data, is.factor)] - categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) - - if (is.null(variables) | !is.list(variables) | !all(sapply(variables, length) == 2)) - stop("'variables' must be specified as a list of pairs (two-element vectors)") - unique_variables <- unlist(variables) variable_splits <- calculate_variable_split(data, variables = unique_variables, @@ -133,3 +131,143 @@ surv_pdp_2d <- function(x, profiles } +surv_ale_2d <- function(x, + data, + variables, + categorical_variables, + grid_points, + ...){ + model <- x$model + label <- x$label + predict_survival_function <- x$predict_survival_function + times <- x$times + + predictions_original <- predict_survival_function(model = model, + newdata = data, + times = times) + mean_pred <- colMeans(predictions_original) + + + profiles <- lapply(variables, FUN = function(variables_pair){ + var1 <- variables_pair[1] + var2 <- variables_pair[2] + + if (all(!variables_pair %in% categorical_variables)){ + surv_ale_2d_num_num( + model, + data, + predict_survival_function, + times, + grid_points, + var1, + var2, + ) + } + + }) + + +} + + +surv_ale_2d_num_num <- function(model, + data, + predict_survival_function, + times, + grid_points, + var1, + 
var2){ + + # Number of quantile points for determined by grid length + quantile_vals1 <- as.numeric(quantile(data[, var1], + seq(0.01, 1, length.out = grid_points), + type = 1)) + quantile_vals2 <- as.numeric(quantile(data[, var2], + seq(0.01, 1, length.out = grid_points), + type = 1)) + + quantile_vec1 <- unique(c(min(data[, var1]), quantile_vals1)) + quantile_vec2 <- unique(c(min(data[, var2]), quantile_vals2)) + + data <- data[(data[, var1] <= max(quantile_vec1)) & + (data[, var1] >= min(quantile_vec1)) & + (data[, var2] <= max(quantile_vec2)) & + (data[, var2] >= min(quantile_vec2)), ] + + # Matching instances to the grids of both features + interval_index1 <- findInterval(data[, var1], quantile_vec1, left.open = TRUE) + interval_index2 <- findInterval(data[, var2], quantile_vec2, left.open = TRUE) + + interval_index1[interval_index1 == 0] <- 1 + interval_index2[interval_index2 == 0] <- 1 + + X_low1_low2 <- X_up1_low2 <- X_low1_up2 <- X_up1_up2 <- data + X_low1_low2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1], + quantile_vec2[interval_index2]) + X_up1_low2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1 + 1], + quantile_vec2[interval_index2]) + X_low1_up2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1], + quantile_vec2[interval_index2 + 1]) + X_up1_up2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1 + 1], + quantile_vec2[interval_index2 + 1]) + + y_hat_11 <- predict_survival_function(model = model, newdata = X_low1_low2, times = times) + y_hat_21 <- predict_survival_function(model = model, newdata = X_up1_low2, times = times) + y_hat_12 <- predict_survival_function(model = model, newdata = X_low1_up2, times = times) + y_hat_22 <- predict_survival_function(model = model, newdata = X_up1_up2, times = times) + + prediction_deltas <- (y_hat_22 - y_hat_21) - (y_hat_12 - y_hat_11) + colnames(prediction_deltas) <- times + + deltas <- data.frame( + interval1 = rep(interval_index1, each = length(times)), + interval2 = 
rep(interval_index2, each = length(times)), + time = rep(times, times = nrow(data)), + yhat = c(t(prediction_deltas)) + ) + + deltas <- aggregate(`yhat`~., data = deltas, FUN=mean) + interval_grid <- expand.grid( + interval1 = unique(deltas$interval1), + interval2 = unique(deltas$interval2), + time = times + ) + deltas <- merge(deltas, + interval_grid, + on = c("interval1", "interval2"), + all.y = TRUE) + + deltas$yhat_cumsum <- ave(deltas$yhat, deltas$time, deltas$interval1, FUN = function(x) cumsum(ifelse(is.na(x), 0, x))) + deltas$yhat_cumsum <- ave(deltas$yhat_cumsum, deltas$time, deltas$interval2, FUN = function(x) cumsum(ifelse(is.na(x), 0, x))) + + cell_counts <- as.matrix(table(interval_index1, interval_index2)) + cell_counts_df <- as.data.frame(as.table(cell_counts)) + colnames(cell_counts_df) <- c("interval1", "interval2", "count") + cell_counts_df$interval1 <- as.numeric(as.character(cell_counts_df$interval1)) + cell_counts_df$interval2 <- as.numeric(as.character(cell_counts_df$interval2)) + ale <- merge(deltas, cell_counts_df, on = c("interval1", "interval2"), all.x = TRUE) + + + # Computing the first-order effect of feature 1 + ale1 <- data.frame() # Initialize an empty data frame + + res <- copy(ale) + res$yhat_diff <- ave(ale$yhat_cumsum, + list(ale$interval2, ale$time), + FUN = function(x) c(x[1], x[-1] - x[-length(x)])) + + res$ale1 + + res$ale1 <- ave(res$yhat_diff, + list(res$interval1, res$time), + FUN = function(x) c((x[-length(x)] + x[-1]) / 2, 0)) + + # sub_res <- res[, c("time", "interval1", "count", "yhat_diff")] + # + # # Step 2: Calculate the numerator and denominator for each group using ave + # sub_res$numerator <- with(sub_res, ave(count[-1] * (yhat_diff[-nrow(sub_res)] + yhat_diff[-1]) / 2, + # time, interval1, FUN = sum)) + # sub_res$denominator <- with(sub_res, ave(count[-1], time, interval1, FUN = sum)) + + +} From a3c3275e8abc0f7a0da2d1d9b5d9d821d84d9e28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Thu, 3 
Aug 2023 12:27:15 +0200 Subject: [PATCH 058/207] Add ALE plot tests --- tests/testthat/test-model_profile.R | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index b3ad8447..9170d9b9 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -99,6 +99,12 @@ test_that("model_profile with type = 'accumulated' works", { type = type) plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") + ### Add tests for plot2 for categorical ALE + # single timepoint + # plot2(mp_cph_cat, variable = "celltype", plot_type = "pdp") + # multiple timepoints + # plot2(mp_cph_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp") + expect_s3_class(mp_cph_cat, "model_profile_survival") expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) expect_equal(ncol(mp_cph_cat$result), 7) @@ -113,6 +119,12 @@ test_that("model_profile with type = 'accumulated' works", { plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") + ### Add tests for plot2 for numerical ALE + # single timepoint + # plot2(mp_cph_num, variable = "karno", plot_type = "pdp") + # multiple timepoints + # plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "pdp") + expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) expect_equal(ncol(mp_cph_num$result), 7) From 0fea81f766636e1e7c9d8d0f4bf23f540232bc35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Thu, 3 Aug 2023 12:27:32 +0200 Subject: [PATCH 059/207] Fix missing brace --- R/surv_ceteris_paribus.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index 12ffe862..a50ffa44 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -84,7 +84,6 @@ surv_ceteris_paribus.default <- function(x, if (is.null(variable_splits)) 
{ if (is.null(variables)) variables <- colnames(data) - } variable_splits <- calculate_variable_split(data, variables = variables, categorical_variables = categorical_variables, From 2fe91513749b9ea91cdb65ef3f8236870301f955 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Thu, 3 Aug 2023 12:27:48 +0200 Subject: [PATCH 060/207] Fix plot2 for ALE. --- R/plot_model_profile_survival.R | 76 +++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index d46b26be..6647f23c 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -108,9 +108,9 @@ plot2 <- function(x, ...) UseMethod("plot2") #' @param variable character, name of a single variable to be plotted #' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median time from the explainer object is used. #' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately -#' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"` selects the type of plot to be drawn +#' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"`, or `"ale"` selects the type of plot to be drawn #' @param ... other parameters. Currently ignored. -#' @param title character, title of the plot +#' @param title character, title of the plot. `'default'` automatically generates either "Partial dependence survival profiles" or "Accumulated local effects survival profiles" depending on the explanation type. 
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels #' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") #' @@ -139,12 +139,23 @@ plot2.model_profile_survival <- function(x, variable, times = NULL, marginalize_over_time = FALSE, - plot_type = "pdp+ice", + plot_type = NULL, ..., - title = "Partial dependence profile", + title = "default", subtitle = "default", colors = NULL) { - if (!plot_type %in% c("pdp", "ice", "pdp+ice")) { + + + if (is.null(plot_type)) { + if (x$type == "accumulated") plot_type = "ale" + else if (x$type == "partial") plot_type = "pdp+ice" + } + + if (x$type == "accumulated" && plot_type != "ale") { + stop("For accumulated local effects explanations only plot_type = 'ale' is available") + } + + if (!plot_type %in% c("pdp", "ice", "pdp+ice", "ale")) { stop("plot_type must be one of 'pdp', 'ice', 'pdp+ice'") } @@ -172,12 +183,20 @@ plot2.model_profile_survival <- function(x, )) } + if (title == "default") { + if (x$type == "partial") + title <- "Partial dependence survival profiles" + if (x$type == "accumulated") + title <- "Accumulated local effects survival profiles" + } + if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") } single_timepoint <- ((length(times) == 1) || marginalize_over_time) is_categorical <- (unique(x$result[x$result$`_vname_` == variable, "_vtype_"]) == "categorical") + ice_needed <- plot_type %in% c("pdp+ice", "ice") if (single_timepoint) { pdp_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c("_x_", "_yhat_")] @@ -188,29 +207,39 @@ plot2.model_profile_survival <- function(x, pdp_df$time <- as.factor(pdp_df$time) } - ice_df <- x$cp_profiles$result[(x$cp_profiles$result$`_vname_` == variable) & - 
(x$cp_profiles$result$`_times_` %in% times), ] + if (ice_needed){ + ice_df <- x$cp_profiles$result[(x$cp_profiles$result$`_vname_` == variable) & + (x$cp_profiles$result$`_times_` %in% times), ] - if (single_timepoint) { - ice_df$`_times_` <- NULL - } else { - colnames(ice_df)[colnames(ice_df) == "_times_"] <- "time" - ice_df$time <- as.factor(ice_df$time) + if (single_timepoint) { + ice_df$`_times_` <- NULL + } else { + colnames(ice_df)[colnames(ice_df) == "_times_"] <- "time" + ice_df$time <- as.factor(ice_df$time) + + } + + if (is_categorical) { + ice_df[, variable] <- as.factor(ice_df[, variable]) + } + + ice_df$`_vname_` <- NULL + ice_df$`_vtype_` <- NULL + ice_df$`_label_` <- NULL + + colnames(ice_df)[colnames(ice_df) == "_ids_"] <- "id" + colnames(ice_df)[colnames(ice_df) == "_yhat_"] <- "predictions" + + y_floor_ice <- floor(min(ice_df[, "predictions"]) * 10) / 10 + y_ceiling_ice <- ceiling(max(ice_df[, "predictions"]) * 10) / 10 } if (is_categorical) { pdp_df[, variable] <- as.factor(pdp_df[, variable]) - ice_df[, variable] <- as.factor(ice_df[, variable]) } - ice_df$`_vname_` <- NULL - ice_df$`_vtype_` <- NULL - ice_df$`_label_` <- NULL - - colnames(ice_df)[colnames(ice_df) == "_ids_"] <- "id" - colnames(ice_df)[colnames(ice_df) == "_yhat_"] <- "predictions" feature_name_sym <- sym(variable) @@ -218,10 +247,11 @@ plot2.model_profile_survival <- function(x, y_floor_pd <- floor(min(pdp_df[, "pd"]) * 10) / 10 y_ceiling_pd <- ceiling(max(pdp_df[, "pd"]) * 10) / 10 - y_floor_ice <- floor(min(ice_df[, "predictions"]) * 10) / 10 - y_ceiling_ice <- ceiling(max(ice_df[, "predictions"]) * 10) / 10 if (marginalize_over_time) { + pdp_df <- aggregate(pd ~., data = pdp_df, mean) + ice_df <- aggregate(predictions ~ ., data = ice_df, mean) + color_scale <- generate_discrete_color_scale(1, colors) } else { color_scale <- generate_discrete_color_scale(length(times), colors) @@ -296,7 +326,7 @@ plot_pdp_num <- function(pdp_dt, ylim(y_floor_ice, y_ceiling_ice) } # PDP - else 
if (plot_type == "pdp") { + else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + geom_line() + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + @@ -321,7 +351,7 @@ plot_pdp_num <- function(pdp_dt, ylim(y_floor_ice, y_ceiling_ice) } # PDP - else if (plot_type == "pdp") { + else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + geom_line(aes(color = time)) + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + From e2738c9c04812032cf8340898a87835473275e08 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 12:39:34 +0200 Subject: [PATCH 061/207] add 2d ale for num+num --- R/model_profile_2d.R | 132 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 113 insertions(+), 19 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index 1060ada0..8096ca81 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -7,6 +7,7 @@ model_profile_2d <- function(explainer, ..., categorical_variables = NULL, grid_points = 51, + center = TRUE, variable_splits_type = "uniform", type = "partial", output_type = "survival") @@ -19,6 +20,7 @@ model_profile_2d.surv_explainer <- function(explainer, ..., categorical_variables = NULL, grid_points = 51, + center = TRUE, variable_splits_type = "uniform", type = "partial", output_type = "survival" @@ -65,6 +67,7 @@ model_profile_2d.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, + center = center, ... 
) } else { @@ -136,6 +139,7 @@ surv_ale_2d <- function(x, variables, categorical_variables, grid_points, + center, ...){ model <- x$model label <- x$label @@ -161,12 +165,17 @@ surv_ale_2d <- function(x, grid_points, var1, var2, + mean_pred, + center ) + } else { + stop("Currently 2D ALE are implemented only for pairs of numerical variables") } }) - + profiles <- do.call(rbind, profiles) + profiles } @@ -176,7 +185,9 @@ surv_ale_2d_num_num <- function(model, times, grid_points, var1, - var2){ + var2, + mean_pred, + center){ # Number of quantile points for determined by grid length quantile_vals1 <- as.numeric(quantile(data[, var1], @@ -228,8 +239,8 @@ surv_ale_2d_num_num <- function(model, deltas <- aggregate(`yhat`~., data = deltas, FUN=mean) interval_grid <- expand.grid( - interval1 = unique(deltas$interval1), - interval2 = unique(deltas$interval2), + interval1 = c(0, sort(unique(deltas$interval1))), + interval2 = c(0, sort(unique(deltas$interval2))), time = times ) deltas <- merge(deltas, @@ -240,34 +251,117 @@ surv_ale_2d_num_num <- function(model, deltas$yhat_cumsum <- ave(deltas$yhat, deltas$time, deltas$interval1, FUN = function(x) cumsum(ifelse(is.na(x), 0, x))) deltas$yhat_cumsum <- ave(deltas$yhat_cumsum, deltas$time, deltas$interval2, FUN = function(x) cumsum(ifelse(is.na(x), 0, x))) + interval_index1 <- factor(interval_index1, sort(unique(interval_index1))) + interval_index2 <- factor(interval_index2, sort(unique(interval_index2))) cell_counts <- as.matrix(table(interval_index1, interval_index2)) cell_counts_df <- as.data.frame(as.table(cell_counts)) colnames(cell_counts_df) <- c("interval1", "interval2", "count") cell_counts_df$interval1 <- as.numeric(as.character(cell_counts_df$interval1)) cell_counts_df$interval2 <- as.numeric(as.character(cell_counts_df$interval2)) ale <- merge(deltas, cell_counts_df, on = c("interval1", "interval2"), all.x = TRUE) - + ale <- ale[order(ale$interval1, ale$interval2), ] # Computing the first-order effect of feature 1 
- ale1 <- data.frame() # Initialize an empty data frame - - res <- copy(ale) + res <- ale res$yhat_diff <- ave(ale$yhat_cumsum, list(ale$interval2, ale$time), - FUN = function(x) c(x[1], x[-1] - x[-length(x)])) + FUN = function(x) c(x[1], diff(x))) - res$ale1 + ale1 <- do.call("rbind", lapply(sort(unique(res$interval1)), function(x){ + counts <- res[res$interval1 == x & res$time == times[1], "count"] + aggregate(yhat_diff~time, data=res[res$interval1==x,], + FUN = function(vals){ + sum(counts[-1] * (vals[-length(vals)] + vals[-1]) /2 / sum(counts[-1])) + }) + })) - res$ale1 <- ave(res$yhat_diff, - list(res$interval1, res$time), - FUN = function(x) c((x[-length(x)] + x[-1]) / 2, 0)) + ale1$interval1 <- rep(sort(unique(res$interval1)), each=length(times)) + ale1$yhat_diff[is.na(ale1$yhat_diff)] <- 0 + ale1$ale1 <- ave(ale1$yhat_diff, ale1$time, FUN = cumsum) - # sub_res <- res[, c("time", "interval1", "count", "yhat_diff")] - # - # # Step 2: Calculate the numerator and denominator for each group using ave - # sub_res$numerator <- with(sub_res, ave(count[-1] * (yhat_diff[-nrow(sub_res)] + yhat_diff[-1]) / 2, - # time, interval1, FUN = sum)) - # sub_res$denominator <- with(sub_res, ave(count[-1], time, interval1, FUN = sum)) + # Computing the first-order effect of feature 2 + res <- ale + res$yhat_diff <- ave(ale$yhat_cumsum, + list(ale$interval1, ale$time), + FUN = function(x) c(x[1], diff(x))) + + ale2 <- do.call("rbind", lapply(sort(unique(res$interval2)), function(x){ + counts <- res[res$interval2 == x & res$time == times[1], "count"] + aggregate(yhat_diff~time, data=res[res$interval2==x,], + FUN = function(vals){ + sum(counts[-1] * (vals[-length(vals)] + vals[-1]) /2 / sum(counts[-1])) + }) + })) + + ale2$interval2 <- rep(sort(unique(res$interval2)), each=length(times)) + ale2$yhat_diff[is.na(ale2$yhat_diff)] <- 0 + ale2$ale2 <- ave(ale2$yhat_diff, ale2$time, FUN = cumsum) + + + fJ0 <- unlist(lapply(times, function(time) { + ale_time <- ale[ale$time == time,] + 
ale1_time <- ale1[ale1$time == time,] + ale2_time <- ale2[ale2$time == time,] + + ale_time <- ale_time[c("interval1", "interval2", "yhat_cumsum")] + dd <- reshape(ale_time, + idvar = "interval1", + timevar = "interval2", + direction = "wide")[,-1] + rownames(dd) <- unique(ale_time$interval1) + colnames(dd) <- unique(ale_time$interval2) + + dd <- dd - outer(ale1_time$ale1, rep(1, nrow(ale2_time))) - + outer(rep(1, nrow(ale1_time)), ale2_time$ale2) + sum(cell_counts * (dd[1:(nrow(dd) - 1), 1:(ncol(dd) - 1)] + + dd[1:(nrow(dd) - 1), 2:ncol(dd)] + + dd[2:nrow(dd), 1:(ncol(dd) - 1)] + + dd[2:nrow(dd), 2:ncol(dd)]) / 4, na.rm = TRUE) / sum(cell_counts) + })) + + fJ0 <- data.frame("fJ0" = fJ0, time = times) + ale <- merge(ale, fJ0, by = c("time")) + ale <- merge(ale, ale1, by = c("time", "interval1")) + ale <- merge(ale, ale2, by = c("time", "interval2")) + ale$ale <- ale$yhat_cumsum - ale$ale1 - ale$ale2 - ale$fJ0 + ale <- ale[order(ale$time, ale$interval1, ale$interval2),] + + if (!center){ + ale$ale <- ale$ale + mean_pred + } + interval_dists <- diff(quantile_vec1[c(1, 1:length(quantile_vec1), length(quantile_vec1))]) + interval_dists <- 0.5 * interval_dists + + ale$right <- quantile_vec1[ale$interval1 + 1] + interval_dists[ale$interval1 + 2] + ale$left <- quantile_vec1[ale$interval1 + 1] - interval_dists[ale$interval1 + 1] + + interval_dists2 <- diff(quantile_vec2[c(1, 1:length(quantile_vec2), length(quantile_vec2))]) + interval_dists2 <- 0.5 * interval_dists2 + + ale$bottom <- quantile_vec2[ale$interval2 + 1] + interval_dists2[ale$interval2 + 2] + ale$top <- quantile_vec2[ale$interval2 + 1] - interval_dists2[ale$interval2 + 1] + + ale[, "_v1value_"] <- quantile_vec1[ale$interval1 + 1] + ale[, "_v2value_"] <- quantile_vec2[ale$interval2 + 1] + + data.frame( + "_v1name_" = var1, + "_v2name_" = var2, + "_v1type_" = "numerical", + "_v2type_" = "numerical", + "_v1value_" = ale$`_v1value_`, + "_v2value_" = ale$`_v2value_`, + "_times_" = ale$time, + "_yhat_" = ale$ale, + 
"_right_" = ale$right, + "_left_" = ale$left, + "_top_" = ale$top, + "_bottom_" = ale$bottom, + "_count_" = ifelse(is.na(ale$count), 0, ale$count), + "_label_" = label, + check.names = FALSE) } + + From fd89941e99a60eb5b6b300329353574e7ee04ae9 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 12:39:44 +0200 Subject: [PATCH 062/207] add ale 2d plotting --- R/plot_model_profile_2d.R | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index e392c1f7..7b93f61f 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -84,7 +84,7 @@ prepare_model_profile_2d_plots <- function(x, if (!all(times %in% x$eval_times)) { stop(paste0( - "For one of the provided times the explanations has not been calculated not found. + "For one of the provided times the explanations has not been calculated or found. Please modify the times argument in your explainer or use only values from the following: ", paste(x$eval_times, collapse = ", ") )) @@ -106,17 +106,31 @@ prepare_model_profile_2d_plots <- function(x, xlabel <- unique(df$`_v1name_`) ylabel <- unique(df$`_v2name_`) - p <- with(df, { - ggplot(df, - aes(x = `_v1value_`, y = `_v2value_`, fill = `_yhat_`)) + - geom_tile() + - scale_fill_gradientn(name = "SF value", - colors = rev(grDevices::colorRampPalette(colors)(10)), - limits = sf_range) + - labs(x = xlabel, y = ylabel) + - theme(legend.position = "top") + - facet_wrap(~paste(`_v1name_`, `_v2name_`, sep = " : ")) - }) + if (x$type == "partial"){ + p <- with(df, { + ggplot(df, + aes(x = `_v1value_`, y = `_v2value_`, fill = `_yhat_`)) + + geom_tile() + + scale_fill_gradientn(name = "PDP value", + colors = rev(grDevices::colorRampPalette(colors)(10)), + limits = sf_range) + + labs(x = xlabel, y = ylabel) + + theme(legend.position = "top") + + facet_wrap(~paste(`_v1name_`, `_v2name_`, sep = " : ")) + }) + } else { + p <- with(df, { + ggplot(df, 
aes(x = `_v1value_`, y = `_v2value_`, fill = `_yhat_`)) + + geom_rect(aes(ymin = `_bottom_`, ymax = `_top_`, + xmin = `_left_`, xmax = `_right_`)) + + scale_fill_gradientn(name = "ALE value", + colors = rev(grDevices::colorRampPalette(colors)(10)), + limits = sf_range) + + labs(x = xlabel, y = ylabel) + + theme(legend.position = "top") + + facet_wrap(~paste(`_v1name_`, `_v2name_`, sep = " : ")) + }) + } if (i != length(x$variables)) p <- p + guides(fill = "none") From 83b82a41897e73ee8f025e68bb4e10571e5b48ae Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 12:39:53 +0200 Subject: [PATCH 063/207] fix condition --- R/surv_model_profiles.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index 883503f1..c8298990 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -368,7 +368,7 @@ surv_ale <- function(x, by.y = "id") ale_values <- ale_values[order(ale_values$interval, ale_values$time),] - ale_values$ale <- ale_values$ale + mean_pred + if (!center){ ale_values$ale <- ale_values$ale + mean_pred } From ec2052f7d0d1a290f27f37ec590d45f4d02b8ae0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Thu, 3 Aug 2023 12:58:28 +0200 Subject: [PATCH 064/207] Add final working tests --- tests/testthat/test-model_parts.R | 1 + tests/testthat/test-model_profile.R | 5 +++-- tests/testthat/test-predict_parts.R | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-model_parts.R b/tests/testthat/test-model_parts.R index 984978b6..49eb32d7 100644 --- a/tests/testthat/test-model_parts.R +++ b/tests/testthat/test-model_parts.R @@ -86,6 +86,7 @@ test_that("Brier score fpi works", { # specifying loss function brier rsf_src_model_parts_brier <- model_parts(rsf_src_exp, loss_function = loss_brier_score, output_type = "survival") plot(rsf_src_model_parts_brier) + expect_error(plot(rsf_src_model_parts_brier, desc_sorting = "non-logical")) 
expect_s3_class(rsf_src_model_parts_brier, "model_parts_survival") expect_equal(ncol(rsf_src_model_parts_brier$result), ncol(cph_exp$data) + 5) # times, full_model, permutation, baseline, label diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 9170d9b9..6fb7abbd 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -50,6 +50,7 @@ test_that("model_profile with type = 'partial' works", { plot2(mp_rsf_cat, variable = "celltype", plot_type = "ice") # multiple timepoints plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp+ice") + plot2(mp_rsf_cat, times = c(4, 5.84), marginalilze_over_time = T, variable = "celltype", plot_type = "pdp+ice") plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp") plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "ice") @@ -101,9 +102,9 @@ test_that("model_profile with type = 'accumulated' works", { ### Add tests for plot2 for categorical ALE # single timepoint - # plot2(mp_cph_cat, variable = "celltype", plot_type = "pdp") + # plot2(mp_cph_cat, variable = "celltype", plot_type = "ale") # multiple timepoints - # plot2(mp_cph_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp") + # plot2(mp_cph_cat, times = c(4, 5.84), variable = "celltype", plot_type = "ale") expect_s3_class(mp_cph_cat, "model_profile_survival") expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 20539a9f..ecd7fe59 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -10,6 +10,7 @@ test_that("survshap explanations work", { rsf_src_exp <- explain(rsf_src, verbose = FALSE) parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares") + parts_cph <- 
predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), calculation_method = "exact_kernel", aggregation_method = "max_absolute") plot(parts_cph) plot(parts_cph, rug = "events") plot(parts_cph, rug = "censors") From 41d055ed85fe8457cd63772598de90782c53a354 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:06:02 +0200 Subject: [PATCH 065/207] fix & add description --- R/model_profile_2d.R | 49 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index 8096ca81..5a9aa662 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -1,25 +1,68 @@ +#' Dataset Level 2-Dimensional Variable Profile for Survival Models +#' +#' This function calculates explanations on a dataset level that help explore model response as a function of selected pairs of variables. +#' The explanations are calculated as an extension of Partial Dependence Profiles or Accumulated Local Effects with the inclusion of the time dimension. +#' +#' +#' @param explainer an explainer object - model preprocessed by the `explain()` function +#' @param variables list of character vectors of length 2, names of pairs of variables to be explained +#' @param N number of observations used for the calculation of aggregated profiles. By default `100`. If `NULL` all observations are used. +#' @param categorical_variables character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the `variables` argument, they will be added at the end. +#' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default `25`. 
+#' @param center logical, should profiles be centered at 0 +#' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for quantiles or `"uniform"` (default) to get uniform grid of points. Used only if `type = "partial"`. +#' @param type the type of variable profile, `"partial"` for Partial Dependence or `"accumulated"` for Accumulated Local Effects +#' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. Currently only `"survival"` is available. +#' #' @return An object of class `model_profile_2d_survival`. It is a list with the element `result` containing the results of the calculation. #' +#' +#' @examples +#' \donttest{ +#' library(survival) +#' library(survex) +#' +#' cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) +#' rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) +#' +#' cph_exp <- explain(cph) +#' rsf_src_exp <- explain(rsf_src) +#' +#' cph_model_profile_2d <- model_profile_2d(cph_exp, +#' variables = list(c("age", "celltype"))) +#' head(cph_model_profile_2d$result) +#' plot(cph_model_profile_2d) +#' +#' rsf_model_profile_2d <- model_profile_2d(rsf_src_exp, +#' variables = list(c("age", "karno")), +#' type = "accumulated") +#' head(rsf_model_profile_2d$result) +#' plot(rsf_model_profile_2d) +#' } +#' +#' @rdname model_profile_2d.surv_explainer #' @export model_profile_2d <- function(explainer, variables = NULL, N = 100, ..., categorical_variables = NULL, - grid_points = 51, + grid_points = 25, center = TRUE, variable_splits_type = "uniform", type = "partial", output_type = "survival") UseMethod("model_profile_2d", explainer) + +#' @rdname model_profile_2d.surv_explainer #' @export model_profile_2d.surv_explainer <- function(explainer, variables = NULL, N = 100, ..., categorical_variables = NULL, - grid_points = 51, + grid_points = 25, center = TRUE, 
variable_splits_type = "uniform", type = "partial", @@ -160,6 +203,7 @@ surv_ale_2d <- function(x, surv_ale_2d_num_num( model, data, + label, predict_survival_function, times, grid_points, @@ -181,6 +225,7 @@ surv_ale_2d <- function(x, surv_ale_2d_num_num <- function(model, data, + label, predict_survival_function, times, grid_points, From 7cbb351dd074f33cb10fe74651d767cb40ce9226 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:06:13 +0200 Subject: [PATCH 066/207] fix plots --- R/plot_model_profile_2d.R | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index 7b93f61f..7ec1000a 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -18,14 +18,15 @@ plot.model_profile_2d_survival <- function(x, title <- "2D accumulated local effects survival profiles" } - all_variables <- x$variables if (!is.null(variables)) { - all_variables <- intersect(all_variables, variables) - if (length(all_variables) == 0) + variables <- intersect(x$variables, variables) + if (length(variables) == 0) stop(paste0( "variables do not overlap with ", - paste(all_variables, collapse = ", ") + paste(x$variables, collapse = ", ") )) + } else { + variables <- x$variables } if (is.null(colors)) @@ -56,7 +57,7 @@ plot.model_profile_2d_survival <- function(x, title = this_title, subtitle = subtitle, colors = colors) - labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches)-2)) + labels[[i]] <- c(this_title, rep("", length(variables)-1)) } labels <- unlist(labels) @@ -95,14 +96,18 @@ prepare_model_profile_2d_plots <- function(x, sf_range <- range(df_time$`_yhat_`) - pl <- lapply(seq_along(x$variables), function(i){ - variable_pair <- x$variables[[i]] + pl <- lapply(seq_along(variables), function(i){ + variable_pair <- variables[[i]] df <- df_time[df_time$`_v1name_` == variable_pair[1] & df_time$`_v2name_` == variable_pair[2],] if (any(df$`_v1type_` == 
"numerical")) df$`_v1value_` <- as.numeric(as.character(df$`_v1value_`)) + else if (any(df$`_v1type_` == "categorical")) + df$`_v1value_` <- as.character(df$`_v1value_`) if (any(df$`_v2type_` == "numerical")) df$`_v2value_` <- as.numeric(as.character(df$`_v2value_`)) + else if (any(df$`_v2type_` == "categorical")) + df$`_v2value_` <- as.character(df$`_v2value_`) xlabel <- unique(df$`_v1name_`) ylabel <- unique(df$`_v2name_`) From 889248cbaebd7e1716acf6596498dfc96b762ef6 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:06:19 +0200 Subject: [PATCH 067/207] add docs --- man/model_profile_2d.surv_explainer.Rd | 83 ++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 man/model_profile_2d.surv_explainer.Rd diff --git a/man/model_profile_2d.surv_explainer.Rd b/man/model_profile_2d.surv_explainer.Rd new file mode 100644 index 00000000..82123d81 --- /dev/null +++ b/man/model_profile_2d.surv_explainer.Rd @@ -0,0 +1,83 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model_profile_2d.R +\name{model_profile_2d} +\alias{model_profile_2d} +\alias{model_profile_2d.surv_explainer} +\title{Dataset Level 2-Dimensional Variable Profile for Survival Models} +\usage{ +model_profile_2d( + explainer, + variables = NULL, + N = 100, + ..., + categorical_variables = NULL, + grid_points = 25, + center = TRUE, + variable_splits_type = "uniform", + type = "partial", + output_type = "survival" +) + +\method{model_profile_2d}{surv_explainer}( + explainer, + variables = NULL, + N = 100, + ..., + categorical_variables = NULL, + grid_points = 25, + center = TRUE, + variable_splits_type = "uniform", + type = "partial", + output_type = "survival" +) +} +\arguments{ +\item{explainer}{an explainer object - model preprocessed by the \code{explain()} function} + +\item{variables}{list of character vectors of length 2, names of pairs of variables to be explained} + +\item{N}{number of observations used for the calculation of 
aggregated profiles. By default \code{100}. If \code{NULL} all observations are used.} + +\item{categorical_variables}{character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the \code{variables} argument, they will be added at the end.} + +\item{grid_points}{maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default \code{25}.} + +\item{center}{logical, should profiles be centered at 0} + +\item{variable_splits_type}{character, decides how variable grids should be calculated. Use \code{"quantiles"} for quantiles or \code{"uniform"} (default) to get uniform grid of points. Used only if \code{type = "partial"}.} + +\item{type}{the type of variable profile, \code{"partial"} for Partial Dependence or \code{"accumulated"} for Accumulated Local Effects} + +\item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. Currently only \code{"survival"} is available.} +} +\value{ +An object of class \code{model_profile_2d_survival}. It is a list with the element \code{result} containing the results of the calculation. +} +\description{ +This function calculates explanations on a dataset level that help explore model response as a function of selected pairs of variables. +The explanations are calculated as an extension of Partial Dependence Profiles or Accumulated Local Effects with the inclusion of the time dimension. 
+} +\examples{ +\donttest{ +library(survival) +library(survex) + +cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) +rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) + +cph_exp <- explain(cph) +rsf_src_exp <- explain(rsf_src) + +cph_model_profile_2d <- model_profile_2d(cph_exp, + variables = list(c("age", "celltype"))) +head(cph_model_profile_2d$result) +plot(cph_model_profile_2d) + +rsf_model_profile_2d <- model_profile_2d(rsf_src_exp, + variables = list(c("age", "karno")), + type = "accumulated") +head(rsf_model_profile_2d$result) +plot(rsf_model_profile_2d) +} + +} From 1e45399e184bd375846a33dd007b94ae17a23923 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:06:25 +0200 Subject: [PATCH 068/207] add tests --- tests/testthat/test-model_profile_2d.R | 80 ++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/testthat/test-model_profile_2d.R diff --git a/tests/testthat/test-model_profile_2d.R b/tests/testthat/test-model_profile_2d.R new file mode 100644 index 00000000..b25c6bce --- /dev/null +++ b/tests/testthat/test-model_profile_2d.R @@ -0,0 +1,80 @@ +test_that("model_profile_2d with type = 'partial' works", { + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_exp <- explain(rsf, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + mp_cph_pdp <- model_profile_2d(cph_exp, + variable_splits_type = "uniform", + variables = list(c("trt", "age"), + c("karno", "trt"), + c("karno", "age")), + categorical_variables = "trt", + grid_points = 6) + plot(mp_cph_pdp) + + mp_rsf_pdp <- 
model_profile_2d(rsf_exp, + variables = list(c("karno", "age")), + grid_points = 6, + output_type = "survival", + N = 25) + plot(mp_cph_pdp, mp_rsf_pdp, variables = list(c("karno", "age"))) + + expect_output(print(mp_cph_pdp)) + expect_s3_class(mp_cph_pdp, "model_profile_2d_survival") + expect_true(all(mp_cph_pdp$eval_times == cph_exp$times)) + expect_equal(ncol(mp_cph_pdp$result), 9) + expect_true(all(unique(c(mp_cph_pdp$result$`_v1name_`, mp_cph_pdp$result$`_v2name_`)) + %in% colnames(cph_exp$data))) + + expect_error(model_profile_2d(rsf_exp)) + expect_error(model_profile_2d(rsf_exp, type = "conditional", + variables = list(c("karno", "age")))) + expect_error(model_profile_2d(rsf_exp, output_type = "risk", + variables = list(c("karno", "age")))) + } +) + +test_that("model_profile_2d with type = 'accumulated' works", { + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + + rsf <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_exp <- explain(rsf, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + mp_rsf_ale <- model_profile_2d(rsf_exp, + variable_splits_type = "quantiles", + variables = list(c("karno", "age")), + grid_points = 6, + output_type = "survival", + type = "accumulated") + + mp_rsf_ale_noncentered <- model_profile_2d(rsf_exp, + variable_splits_type = "quantiles", + variables = list(c("karno", "age")), + grid_points = 6, + output_type = "survival", + type = "accumulated", + center = FALSE) + plot(mp_rsf_ale) + plot(mp_rsf_ale, times=rsf_exp$times[1]) + + expect_output(print(mp_rsf_ale)) + expect_s3_class(mp_rsf_ale, "model_profile_2d_survival") + expect_true(all(unique(mp_rsf_ale$eval_times) == rsf_exp$times)) + expect_equal(ncol(mp_rsf_ale$result), 14) + expect_true(all(unique(c(mp_rsf_ale$result$`_v1name_`, mp_rsf_ale$result$`_v2name_`)) + %in% colnames(rsf_exp$data))) + + 
expect_error(plot(mp_rsf_ale, variables = "nonexistent")) + expect_error(plot(mp_rsf_ale, + variables = list(c("karno", "trt")), + categorical_variables="trt")) + expect_error(model_profile_2d(mp_rsf_ale, + variables = list(c("karno", "trt")), + categorical_variables="trt", + type = "accumulated")) + expect_error(plot(mp_rsf_ale, times = -1)) +}) From 193ea0ac8da9e993fc485f131c4a1e43fbf0b391 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:36:49 +0200 Subject: [PATCH 069/207] remove unnecessary docs --- R/surv_model_profiles.R | 22 ---------------------- man/surv_aggregate_profiles.Rd | 33 --------------------------------- 2 files changed, 55 deletions(-) delete mode 100644 man/surv_aggregate_profiles.Rd diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index c8298990..fb9bd4b3 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -1,13 +1,3 @@ -#' Helper functions for `model_profile.R` -#' -#' @param x an object containing calculated ceteris_paribus profiles -#' @param ... other parameters, ignored -#' @param variable_type character, either `"numerical"` or `"categorical"`, the type of variable to be calculated, if left `NULL` (default), both are calculated -#' @param variables a character vector containing names of variables to be explained -#' @param center logical, if the profiles should be centered before aggregations -#' -#' @return A data.frame with calculated results. -#' #' @keywords internal surv_aggregate_profiles <- function(x, ..., @@ -115,19 +105,7 @@ surv_aggregate_profiles_partial <- function(all_profiles) { aggregated_profiles } - -#' @param x an explainer object - model preprocessed by the `explain()` function -#' @param ... 
other parameters, ignored -#' @param data data used to create explanations -#' @param variables a character vector containing names of variables to be explained -#' @param categorical_variables character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the `variables` argument, they will be added at the end. -#' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default `51`. -#' @param center logical, if the profiles should be centered before aggregations -#' #' @importFrom stats ave -#' -#' @return A data.frame with calculated results. -#' #' @keywords internal surv_ale <- function(x, ..., diff --git a/man/surv_aggregate_profiles.Rd b/man/surv_aggregate_profiles.Rd deleted file mode 100644 index 89570cf3..00000000 --- a/man/surv_aggregate_profiles.Rd +++ /dev/null @@ -1,33 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/surv_model_profiles.R -\name{surv_aggregate_profiles} -\alias{surv_aggregate_profiles} -\title{Helper functions for \code{model_profile.R}} -\usage{ -surv_aggregate_profiles( - x, - ..., - variable_type = NULL, - groups = NULL, - variables = NULL, - center = FALSE -) -} -\arguments{ -\item{x}{an object containing calculated ceteris_paribus profiles} - -\item{...}{other parameters, ignored} - -\item{variable_type}{character, either \code{"numerical"} or \code{"categorical"}, the type of variable to be calculated, if left \code{NULL} (default), both are calculated} - -\item{variables}{a character vector containing names of variables to be explained} - -\item{center}{logical, if the profiles should be centered before aggregations} -} -\value{ -A data.frame with calculated results. 
-} -\description{ -Helper functions for \code{model_profile.R} -} -\keyword{internal} From 87f820f2644f90269767a8d54cd837e4244682f3 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:36:58 +0200 Subject: [PATCH 070/207] update documentation --- man/plot2.model_profile_survival.Rd | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/man/plot2.model_profile_survival.Rd b/man/plot2.model_profile_survival.Rd index 9d345e45..172048fa 100644 --- a/man/plot2.model_profile_survival.Rd +++ b/man/plot2.model_profile_survival.Rd @@ -12,9 +12,9 @@ plot2(x, ...) variable, times = NULL, marginalize_over_time = FALSE, - plot_type = "pdp+ice", + plot_type = NULL, ..., - title = "Partial dependence profile", + title = "default", subtitle = "default", colors = NULL ) @@ -30,9 +30,9 @@ plot2(x, ...) \item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately} -\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"} selects the type of plot to be drawn} +\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"}, or \code{"ale"} selects the type of plot to be drawn} -\item{title}{character, title of the plot} +\item{title}{character, title of the plot. 
\code{'default'} automatically generates either "Partial dependence survival profiles" or "Accumulated local effects survival profiles" depending on the explanation type.} \item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} From 2b4d4031042fdd3eb37df417c2b7e4893e95adaf Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:42:56 +0200 Subject: [PATCH 071/207] add variable_splits_type to docs --- R/model_profile.R | 3 +++ man/model_profile.surv_explainer.Rd | 3 +++ 2 files changed, 6 insertions(+) diff --git a/R/model_profile.R b/R/model_profile.R index e97e5401..350f4817 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -10,6 +10,7 @@ #' @param ... other parameters passed to `DALEX::model_profile` if `output_type == "risk"`, otherwise ignored #' @param categorical_variables character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the `variables` argument, they will be added at the end. #' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default `51`. +#' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for percentiles or `"uniform"` (default) to get uniform grid of points. #' @param groups if `output_type == "risk"` a variable name that will be used for grouping. By default `NULL`, so no groups are calculated. 
If `output_type == "survival"` then ignored #' @param k passed to `DALEX::model_profile` if `output_type == "risk"`, otherwise ignored #' @param center logical, should profiles be centered before clustering @@ -66,6 +67,7 @@ model_profile.surv_explainer <- function(explainer, ..., categorical_variables = NULL, grid_points = 51, + variable_splits_type = "uniform", groups = NULL, k = NULL, center = TRUE, @@ -98,6 +100,7 @@ model_profile.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, + variable_splits_type = variable_splits_type, ...) result <- surv_aggregate_profiles(cp_profiles, ..., diff --git a/man/model_profile.surv_explainer.Rd b/man/model_profile.surv_explainer.Rd index 41f1a487..2c39b243 100644 --- a/man/model_profile.surv_explainer.Rd +++ b/man/model_profile.surv_explainer.Rd @@ -24,6 +24,7 @@ model_profile( ..., categorical_variables = NULL, grid_points = 51, + variable_splits_type = "uniform", groups = NULL, k = NULL, center = TRUE, @@ -53,6 +54,8 @@ model_profile( \item{categorical_variables}{character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the \code{variables} argument, they will be added at the end.} \item{grid_points}{maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default \code{51}.} + +\item{variable_splits_type}{character, decides how variable grids should be calculated. Use \code{"quantiles"} for percentiles or \code{"uniform"} (default) to get uniform grid of points.} } \value{ An object of class \code{model_profile_survival}. It is a list with the element \code{result} containing the results of the calculation. 
From 3c61103910a36aa30e3f79181fcae1362c2ada71 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 15:48:20 +0200 Subject: [PATCH 072/207] fix variable handling --- R/plot_surv_ceteris_paribus.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index b28f625a..4b180202 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -132,12 +132,13 @@ prepare_ceteris_paribus_plots <- function(x, all_variables <- na.omit(as.character(unique(all_profiles$`_vname_`))) if (!is.null(variables)) { - all_variables <- intersect(all_variables, variables) - if (length(all_variables) == 0) + variables <- intersect(all_variables, variables) + if (length(variables) == 0) stop(paste0( "variables do not overlap with ", paste(all_variables, collapse = ", ") )) + all_variables <- variables } From 9021c8cd25a343e8f26777c905de90e354c4137a Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 16:26:44 +0200 Subject: [PATCH 073/207] add missing import --- NAMESPACE | 1 + 1 file changed, 1 insertion(+) diff --git a/NAMESPACE b/NAMESPACE index b0ece826..b0ee1154 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -94,6 +94,7 @@ importFrom(stats,optim) importFrom(stats,predict) importFrom(stats,quantile) importFrom(stats,reorder) +importFrom(stats,reshape) importFrom(stats,rnorm) importFrom(stats,stepfun) importFrom(stats,xtabs) From 1f53f37d6d985a51f205e75255dc823a16b2a467 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 16:26:59 +0200 Subject: [PATCH 074/207] add missing import & remove ... 
parameters --- R/model_profile_2d.R | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index 5a9aa662..ab7d159c 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -45,7 +45,6 @@ model_profile_2d <- function(explainer, variables = NULL, N = 100, - ..., categorical_variables = NULL, grid_points = 25, center = TRUE, @@ -60,7 +59,6 @@ model_profile_2d <- function(explainer, model_profile_2d.surv_explainer <- function(explainer, variables = NULL, N = 100, - ..., categorical_variables = NULL, grid_points = 25, center = TRUE, @@ -100,8 +98,7 @@ model_profile_2d.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, - variable_splits_type = variable_splits_type, - ... + variable_splits_type = variable_splits_type ) } else if (type == "accumulated") { result <- surv_ale_2d( @@ -110,8 +107,7 @@ model_profile_2d.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, - center = center, - ... + center = center ) } else { stop("Currently only `partial` and `accumulated` types are implemented") @@ -132,8 +128,8 @@ surv_pdp_2d <- function(x, variables, categorical_variables, grid_points, - variable_splits_type, - ...) 
{ + variable_splits_type + ) { model <- x$model label <- x$label predict_survival_function <- x$predict_survival_function @@ -182,8 +178,8 @@ surv_ale_2d <- function(x, variables, categorical_variables, grid_points, - center, - ...){ + center + ){ model <- x$model label <- x$label predict_survival_function <- x$predict_survival_function @@ -222,7 +218,7 @@ surv_ale_2d <- function(x, profiles } - +#' @importFrom stats reshape surv_ale_2d_num_num <- function(model, data, label, From 894d97abd4e685b1aef0cc8fc9e294a604f271c8 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 16:27:12 +0200 Subject: [PATCH 075/207] remove ... parameters --- man/model_profile_2d.surv_explainer.Rd | 2 -- 1 file changed, 2 deletions(-) diff --git a/man/model_profile_2d.surv_explainer.Rd b/man/model_profile_2d.surv_explainer.Rd index 82123d81..8a6b3250 100644 --- a/man/model_profile_2d.surv_explainer.Rd +++ b/man/model_profile_2d.surv_explainer.Rd @@ -9,7 +9,6 @@ model_profile_2d( explainer, variables = NULL, N = 100, - ..., categorical_variables = NULL, grid_points = 25, center = TRUE, @@ -22,7 +21,6 @@ model_profile_2d( explainer, variables = NULL, N = 100, - ..., categorical_variables = NULL, grid_points = 25, center = TRUE, From 59b6749f06c5be68fb5e6359665be9edf21909a4 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 16:27:25 +0200 Subject: [PATCH 076/207] edit tests --- tests/testthat/test-model_profile.R | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 6fb7abbd..c24a22e0 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -73,16 +73,14 @@ test_that("model_profile with type = 'partial' works", { expect_output(print(mp_cph_num)) expect_error(plot(mp_rsf_num, variables = "nonexistent", grid_points = 6)) - - model_profile(rsf_ranger_exp, varaibles = "trt", grid_points = 6) 
expect_error(model_profile(rsf_ranger_exp, type = "conditional")) - + expect_error(plot2(mp_rsf_num, variable = "nonexistent")) + expect_error(plot2(mp_rsf_num, variable = "age", times = -1)) }) test_that("model_profile with type = 'accumulated' works", { veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] - type <- 'accumulated' cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) @@ -97,7 +95,7 @@ test_that("model_profile with type = 'accumulated' works", { output_type = "survival", variable_type = "categorical", grid_points = 6, - type = type) + type = 'accumulated') plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") ### Add tests for plot2 for categorical ALE @@ -116,7 +114,7 @@ test_that("model_profile with type = 'accumulated' works", { output_type = "survival", variable_type = "numerical", grid_points = 6, - type = type) + type = 'accumulated') plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") From 8b607981c8f63976e6aa6158770c5b831d0e12e8 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 16:37:54 +0200 Subject: [PATCH 077/207] add test for rms::cph model --- tests/testthat/test-explain.R | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index e17cc1be..6cc6149e 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -306,6 +306,8 @@ test_that("default methods for creating explainers work correctly", { expect_error(explain(cph_wihtout_params_2, verbose = FALSE)) expect_output(explain(cph)) + + ### ranger::ranger ### rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, 
num.trees = 100, mtry = 3, max.depth = 5) @@ -314,6 +316,7 @@ test_that("default methods for creating explainers work correctly", { expect_s3_class(rsf_ranger_exp, c("surv_explainer", "explainer")) expect_equal(rsf_ranger_exp$label, "ranger") + ### randomForestSRC::rfsrc ### rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) @@ -322,6 +325,16 @@ test_that("default methods for creating explainers work correctly", { expect_s3_class(rsf_src_exp, c("surv_explainer", "explainer")) expect_equal(rsf_src_exp$label, "rfsrc", ignore_attr = TRUE) + + ### rms::cph ### + surv <- survival::Surv(veteran$time, veteran$status) + cph <- rms::cph(surv ~ trt + celltype + karno + diagtime + age + prior, + data = veteran, surv=TRUE, model=TRUE, x=TRUE, y=TRUE) + cph_rms_exp <- explain(cph) + expect_s3_class(cph_rms_exp, c("surv_explainer", "explainer")) + expect_equal(cph_rms_exp$label, "coxph", ignore_attr = TRUE) + + ### parsnip::boost_tree ### library(censored, quietly = TRUE) From 30ca56498fa7fc2bfdefa11c633550c157360ae6 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 3 Aug 2023 16:52:25 +0200 Subject: [PATCH 078/207] add test for exact_kernel calculation method --- tests/testthat/test-predict_parts.R | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index c47c8a77..45d00f75 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -10,6 +10,7 @@ test_that("survshap explanations work", { rsf_src_exp <- explain(rsf_src, verbose = FALSE) parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares") + parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), calculation_method = "exact_kernel") plot(parts_cph) plot(parts_cph, rug = "events") plot(parts_cph, rug = 
"censors") From cb02f16a7124f619cd51aff19f1d82541f6ac4c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Thu, 3 Aug 2023 17:23:47 +0200 Subject: [PATCH 079/207] Fix ale plots and uncomment tests --- R/plot_model_profile_survival.R | 32 ++++++++++++++------------- man/plot2.model_profile_survival.Rd | 8 +++---- tests/testthat/test-model_profile.R | 34 ++++++++++++++--------------- 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 6647f23c..2993dbad 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -48,11 +48,13 @@ plot.model_profile_survival <- function(x, rug_colors = c("#dd0000", "#222222")) { explanations_list <- c(list(x), list(...)) num_models <- length(explanations_list) - if (title == "default"){ - if (x$type == "partial") + if (title == "default") { + if (x$type == "partial") { title <- "Partial dependence survival profiles" - if (x$type == "accumulated") + } + if (x$type == "accumulated") { title <- "Accumulated local effects survival profiles" + } } if (num_models == 1) { @@ -144,11 +146,10 @@ plot2.model_profile_survival <- function(x, title = "default", subtitle = "default", colors = NULL) { - - if (is.null(plot_type)) { - if (x$type == "accumulated") plot_type = "ale" - else if (x$type == "partial") plot_type = "pdp+ice" + if (x$type == "accumulated") { + plot_type <- "ale" + } else if (x$type == "partial") plot_type <- "pdp+ice" } if (x$type == "accumulated" && plot_type != "ale") { @@ -184,10 +185,12 @@ plot2.model_profile_survival <- function(x, } if (title == "default") { - if (x$type == "partial") + if (x$type == "partial") { title <- "Partial dependence survival profiles" - if (x$type == "accumulated") + } + if (x$type == "accumulated") { title <- "Accumulated local effects survival profiles" + } } if (!is.null(subtitle) && subtitle == "default") { @@ -196,7 +199,7 @@ plot2.model_profile_survival <- 
function(x, single_timepoint <- ((length(times) == 1) || marginalize_over_time) is_categorical <- (unique(x$result[x$result$`_vname_` == variable, "_vtype_"]) == "categorical") - ice_needed <- plot_type %in% c("pdp+ice", "ice") + ice_needed <- plot_type %in% c("pdp+ice", "ice") if (single_timepoint) { pdp_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c("_x_", "_yhat_")] @@ -207,7 +210,7 @@ plot2.model_profile_survival <- function(x, pdp_df$time <- as.factor(pdp_df$time) } - if (ice_needed){ + if (ice_needed) { ice_df <- x$cp_profiles$result[(x$cp_profiles$result$`_vname_` == variable) & (x$cp_profiles$result$`_times_` %in% times), ] @@ -216,7 +219,6 @@ plot2.model_profile_survival <- function(x, } else { colnames(ice_df)[colnames(ice_df) == "_times_"] <- "time" ice_df$time <- as.factor(ice_df$time) - } if (is_categorical) { @@ -249,7 +251,7 @@ plot2.model_profile_survival <- function(x, y_ceiling_pd <- ceiling(max(pdp_df[, "pd"]) * 10) / 10 if (marginalize_over_time) { - pdp_df <- aggregate(pd ~., data = pdp_df, mean) + pdp_df <- aggregate(pd ~ ., data = pdp_df, mean) ice_df <- aggregate(predictions ~ ., data = ice_df, mean) color_scale <- generate_discrete_color_scale(1, colors) @@ -384,7 +386,7 @@ plot_pdp_cat <- function(pdp_dt, geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = "gold") + scale_color_manual(name = "time", values = colors) - } else if (plot_type == "pdp") { + } else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd), ) + geom_bar(stat = "identity", width = 0.5) + scale_fill_manual(name = "time", values = colors) @@ -399,7 +401,7 @@ plot_pdp_cat <- function(pdp_dt, geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + geom_line(data = pdp_dt, aes(x = 
!!feature_name_count_sym, y = pd, group = time), linewidth = 0.6) + scale_color_manual(name = "time", values = colors) - } else if (plot_type == "pdp") { + } else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, fill = time)) + geom_bar(stat = "identity", width = 0.5, position = "dodge") + scale_fill_manual(name = "time", values = colors) diff --git a/man/plot2.model_profile_survival.Rd b/man/plot2.model_profile_survival.Rd index 9d345e45..172048fa 100644 --- a/man/plot2.model_profile_survival.Rd +++ b/man/plot2.model_profile_survival.Rd @@ -12,9 +12,9 @@ plot2(x, ...) variable, times = NULL, marginalize_over_time = FALSE, - plot_type = "pdp+ice", + plot_type = NULL, ..., - title = "Partial dependence profile", + title = "default", subtitle = "default", colors = NULL ) @@ -30,9 +30,9 @@ plot2(x, ...) \item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately} -\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"} selects the type of plot to be drawn} +\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"}, or \code{"ale"} selects the type of plot to be drawn} -\item{title}{character, title of the plot} +\item{title}{character, title of the plot. 
\code{'default'} automatically generates either "Partial dependence survival profiles" or "Accumulated local effects survival profiles" depending on the explanation type.} \item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 6fb7abbd..902a9f37 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -1,5 +1,4 @@ test_that("model_profile with type = 'partial' works", { - veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) @@ -76,13 +75,11 @@ test_that("model_profile with type = 'partial' works", { model_profile(rsf_ranger_exp, varaibles = "trt", grid_points = 6) expect_error(model_profile(rsf_ranger_exp, type = "conditional")) - - }) +}) test_that("model_profile with type = 'accumulated' works", { - veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] - type <- 'accumulated' + type <- "accumulated" cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) @@ -94,17 +91,18 @@ test_that("model_profile with type = 'accumulated' works", { mp_cph_cat <- model_profile(cph_exp, - output_type = "survival", - variable_type = "categorical", - grid_points = 6, - type = type) + output_type = "survival", + variable_type = "categorical", + grid_points = 6, + type = type + ) plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") ### Add tests for plot2 for categorical ALE # single timepoint - # plot2(mp_cph_cat, variable = "celltype", plot_type 
= "ale") + plot2(mp_cph_cat, variable = "celltype", plot_type = "ale") # multiple timepoints - # plot2(mp_cph_cat, times = c(4, 5.84), variable = "celltype", plot_type = "ale") + plot2(mp_cph_cat, times = c(4, 5.84), variable = "celltype", plot_type = "ale") expect_s3_class(mp_cph_cat, "model_profile_survival") expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) @@ -113,18 +111,19 @@ test_that("model_profile with type = 'accumulated' works", { mp_cph_num <- model_profile(cph_exp, - output_type = "survival", - variable_type = "numerical", - grid_points = 6, - type = type) + output_type = "survival", + variable_type = "numerical", + grid_points = 6, + type = type + ) plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") ### Add tests for plot2 for numerical ALE # single timepoint - # plot2(mp_cph_num, variable = "karno", plot_type = "pdp") + plot2(mp_cph_num, variable = "karno", plot_type = "ale") # multiple timepoints - # plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "pdp") + plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "ale") expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) @@ -136,7 +135,6 @@ test_that("model_profile with type = 'accumulated' works", { }) test_that("default DALEX::model_profile is ok", { - veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) From deaad20bd7cf57780c83e3b4ad10432255aadeef Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 09:59:01 +0200 Subject: [PATCH 080/207] change times in test --- tests/testthat/test-model_profile.R | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 526790d9..3eeef633 
100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -29,9 +29,9 @@ test_that("model_profile with type = 'partial' works", { plot2(mp_cph_num, variable = "karno", plot_type = "pdp") plot2(mp_cph_num, variable = "karno", plot_type = "ice") # multiple timepoints - plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "pdp+ice") - plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "pdp") - plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "ice") + plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "pdp+ice") + plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "pdp") + plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "ice") expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) @@ -48,10 +48,10 @@ test_that("model_profile with type = 'partial' works", { plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp") plot2(mp_rsf_cat, variable = "celltype", plot_type = "ice") # multiple timepoints - plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp+ice") - plot2(mp_rsf_cat, times = c(4, 5.84), marginalilze_over_time = T, variable = "celltype", plot_type = "pdp+ice") - plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "pdp") - plot2(mp_rsf_cat, times = c(4, 5.84), variable = "celltype", plot_type = "ice") + plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "pdp+ice") + plot2(mp_rsf_cat, times = c(4, 80.7), marginalilze_over_time = T, variable = "celltype", plot_type = "pdp+ice") + plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "pdp") + plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "ice") expect_s3_class(mp_rsf_cat, "model_profile_survival") @@ -91,16 +91,16 @@ test_that("model_profile with type = 'accumulated' works", { mp_cph_cat <- 
model_profile(cph_exp, output_type = "survival", - variable_type = "categorical", grid_points = 6, - type = 'accumulated') + type = 'accumulated', + categorical_variables = "trt") plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") ### Add tests for plot2 for categorical ALE # single timepoint plot2(mp_cph_cat, variable = "celltype", plot_type = "ale") # multiple timepoints - plot2(mp_cph_cat, times = c(4, 5.84), variable = "celltype", plot_type = "ale") + plot2(mp_cph_cat, times = c(4, 80.7), variable = "celltype", plot_type = "ale") expect_s3_class(mp_cph_cat, "model_profile_survival") expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) @@ -120,7 +120,7 @@ test_that("model_profile with type = 'accumulated' works", { # single timepoint plot2(mp_cph_num, variable = "karno", plot_type = "ale") # multiple timepoints - plot2(mp_cph_num, times = c(4, 5.84), variable = "karno", plot_type = "ale") + plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "ale") expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) From 693f74503b36100ecd60e834d2ec52ce61d3e067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 10:10:47 +0200 Subject: [PATCH 081/207] Fix colors in plots --- R/plot_model_profile_survival.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 2993dbad..c22fba04 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -315,7 +315,7 @@ plot_pdp_num <- function(pdp_dt, if (single_timepoint == TRUE) { ## single timepoint if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + - geom_line(alpha = 0.2, mapping = aes(group = id)) + + geom_line(alpha = 0.2, mapping = aes(group = id), color = colors_discrete_drwhy(1)) + geom_rug(data = data_dt, aes(x = 
!!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + ylim(y_floor_ice, y_ceiling_ice) } @@ -323,14 +323,14 @@ plot_pdp_num <- function(pdp_dt, else if (plot_type == "pdp+ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(mapping = aes(group = id), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = "gold") + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = colors_discrete_drwhy(1)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter", color = colors_discrete_drwhy(1)) + ylim(y_floor_ice, y_ceiling_ice) } # PDP else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + - geom_line() + + geom_line(color = colors_discrete_drwhy(1)) + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + ylim(y_floor_pd, y_ceiling_pd) } @@ -379,16 +379,16 @@ plot_pdp_cat <- function(pdp_dt, if (single_timepoint == TRUE) { ## single timepoint if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + - geom_boxplot(alpha = 0.2) + + geom_boxplot(alpha = 0.2, color = colors_discrete_drwhy(1)) + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp+ice") { ggplot() + - geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = "gold") + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2, color = colors_discrete_drwhy(1)) + + geom_line(data = pdp_dt, aes(x = 
!!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = colors_discrete_drwhy(1)) + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd), ) + - geom_bar(stat = "identity", width = 0.5) + + geom_bar(stat = "identity", width = 0.5, fill = colors_discrete_drwhy(1)) + scale_fill_manual(name = "time", values = colors) } } else { From 9015d4555c8062aa29cbff6b0916e9af542c3be3 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 10:11:02 +0200 Subject: [PATCH 082/207] remove variable_type from tests (unused) --- tests/testthat/test-model_profile.R | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 3eeef633..73534a14 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -10,7 +10,7 @@ test_that("model_profile with type = 'partial' works", { rsf_src_exp <- explain(rsf_src, verbose = FALSE) - mp_cph_cat <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", variable_type = "categorical", grid_points = 6) + mp_cph_cat <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", grid_points = 6) plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") expect_s3_class(mp_cph_cat, "model_profile_survival") @@ -19,7 +19,7 @@ test_that("model_profile with type = 'partial' works", { expect_true(all(unique(mp_cph_cat$result$`_vname_`) %in% colnames(cph_exp$data))) - mp_cph_num <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", variable_type = "numerical", grid_points = 6) + mp_cph_num <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", grid_points = 6) plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") @@ 
-39,7 +39,7 @@ test_that("model_profile with type = 'partial' works", { expect_true(all(unique(mp_cph_num$result$`_vname_`) %in% colnames(cph_exp$data))) - mp_rsf_cat <- model_profile(rsf_ranger_exp, output_type = "survival", variable_splits_type = "uniform", variable_type = "categorical", grid_points = 6) + mp_rsf_cat <- model_profile(rsf_ranger_exp, output_type = "survival", variable_splits_type = "uniform", grid_points = 6) plot(mp_rsf_cat, variable_type = "categorical") ### Add tests for plot2 for categorical PDP @@ -60,7 +60,7 @@ test_that("model_profile with type = 'partial' works", { expect_true(all(unique(mp_rsf_cat$result$`_vname_`) %in% colnames(rsf_ranger_exp$data))) - mp_rsf_num <- model_profile(rsf_ranger_exp, output_type = "survival", variable_splits_type = "uniform", variable_type = "numerical", grid_points = 6) + mp_rsf_num <- model_profile(rsf_ranger_exp, output_type = "survival", variable_splits_type = "uniform", grid_points = 6) plot(mp_rsf_num, variable_type = "numerical") plot(mp_rsf_num, variable_type = "numerical", numerical_plot_type = "contours") @@ -110,9 +110,9 @@ test_that("model_profile with type = 'accumulated' works", { mp_cph_num <- model_profile(cph_exp, output_type = "survival", - variable_type = "numerical", grid_points = 6, - type = 'accumulated') + type = 'accumulated', + categorical_variables = "trt") plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") From ca9b6df886ee4684986b83c47f36e274a10396f4 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 10:11:32 +0200 Subject: [PATCH 083/207] fix factorizing not factor variables --- R/surv_model_profiles.R | 26 +++++++++++++------------- R/utils.R | 11 +++++------ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index fb9bd4b3..d2a18875 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -1,7 +1,6 @@ #' @keywords internal 
surv_aggregate_profiles <- function(x, ..., - variable_type = NULL, groups = NULL, variables = NULL, center = FALSE) { @@ -24,16 +23,6 @@ surv_aggregate_profiles <- function(x, all_variables <- all_variables_intersect } - if (!is.null(variable_type) && variable_type == "numerical") { - all_profiles <- - all_profiles[all_profiles$`_vtype_` == "numerical",] - } - - if (!is.null(variable_type) && variable_type == "categorical") { - all_profiles <- - all_profiles[all_profiles$`_vtype_` == "categorical",] - } - all_variables <- intersect(all_variables, unique(all_profiles$`_vname_`)) @@ -149,15 +138,18 @@ surv_ale <- function(x, if (variable %in% categorical_variables) { if (!is.factor(variable_values)){ - data[, variable] <- as.factor(data[, variable]) + is_numeric <- is.numeric(variable_values) + is_factorized <- TRUE variable_values <- as.factor(variable_values) + } else { + is_factorized <- FALSE } levels_original <- levels(droplevels(variable_values)) levels_n <- nlevels(droplevels(variable_values)) if (inherits(variable_values, "ordered")) { level_order <- 1:levels_n } else { - level_order <- order_levels(data, variable) + level_order <- order_levels(data, variable_values, variable) } # The new order of the levels @@ -170,6 +162,14 @@ surv_ale <- function(x, # Filter rows which are not already at maximum or minimum level values row_ind_increase <- (1:nrow(data))[x_ordered < levels_n] row_ind_decrease <- (1:nrow(data))[x_ordered > 1] + + if (is_factorized){ + levels_ordered <- as.character(levels_ordered) + if (is_numeric){ + levels_ordered <- as.numeric(levels_ordered) + } + } + X_lower[row_ind_decrease, variable] <- levels_ordered[x_ordered[row_ind_decrease] - 1] X_upper[row_ind_increase, variable] <- diff --git a/R/utils.R b/R/utils.R index a4e1f8bc..10afc5af 100644 --- a/R/utils.R +++ b/R/utils.R @@ -189,14 +189,13 @@ add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ # based on iml::order_levels #' @importFrom stats ecdf xtabs cmdscale #' 
@keywords internal -order_levels <- function(data, variable) { - data[, variable] <- droplevels(data[, variable]) - feature <- data[, variable] - x.count <- as.numeric(table(data[, variable])) +order_levels <- function(data, variable_values, variable_name) { + feature <- droplevels(variable_values) + x.count <- as.numeric(table(feature)) x.prob <- x.count / sum(x.count) - K <- nlevels(data[, variable]) + K <- nlevels(feature) - dists <- lapply(setdiff(colnames(data), variable), function(x) { + dists <- lapply(setdiff(colnames(data), variable_name), function(x) { feature.x <- data[, x] dists <- expand.grid(levels(feature), levels(feature)) colnames(dists) <- c("from.level", "to.level") From d0ed5a1b7bf845cb5d24a25aad74366351c312af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 10:24:19 +0200 Subject: [PATCH 084/207] Fix dodge on categorical pdp plots --- R/plot_model_profile_survival.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index c22fba04..734e908b 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -383,12 +383,13 @@ plot_pdp_cat <- function(pdp_dt, scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp+ice") { ggplot() + - geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2, color = colors_discrete_drwhy(1)) + - geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = colors_discrete_drwhy(1)) + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2, color = colors_discrete_drwhy(1), width = 0.7) + + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = 1), linewidth = 2, color = colors_discrete_drwhy(1), position = position_dodge(0.7)) + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp" || 
plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd), ) + geom_bar(stat = "identity", width = 0.5, fill = colors_discrete_drwhy(1)) + + scale_y_continuous(expand = c(0, NA)) + scale_fill_manual(name = "time", values = colors) } } else { @@ -398,12 +399,13 @@ plot_pdp_cat <- function(pdp_dt, scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp+ice") { ggplot(mapping = aes(color = time)) + - geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2) + - geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 0.6) + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2, width = 0.7) + + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 0.6, position = position_dodge(0.7)) + scale_color_manual(name = "time", values = colors) } else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, fill = time)) + geom_bar(stat = "identity", width = 0.5, position = "dodge") + + scale_y_continuous(expand = c(0, NA)) + scale_fill_manual(name = "time", values = colors) } } From 57c30699223f84a46bfc0a9bcc146a7e5e1751c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 10:27:24 +0200 Subject: [PATCH 085/207] Fix verbose explainer output in test --- tests/testthat/test-explain.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index 6cc6149e..68e1962d 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -330,7 +330,7 @@ test_that("default methods for creating explainers work correctly", { surv <- survival::Surv(veteran$time, veteran$status) cph <- rms::cph(surv ~ trt + celltype + karno + diagtime + age + prior, data = veteran, surv=TRUE, model=TRUE, x=TRUE, y=TRUE) - 
cph_rms_exp <- explain(cph) + cph_rms_exp <- explain(cph, verbose = FALSE) expect_s3_class(cph_rms_exp, c("surv_explainer", "explainer")) expect_equal(cph_rms_exp$label, "coxph", ignore_attr = TRUE) From c7d34afd56299b702d0f1d7395c8f125c29cb681 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 11:19:14 +0200 Subject: [PATCH 086/207] add cp_profiles for ale --- R/model_profile.R | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/R/model_profile.R b/R/model_profile.R index 350f4817..cd04019b 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -86,12 +86,11 @@ model_profile.surv_explainer <- function(explainer, type = type), "survival" = { test_explainer(explainer, "model_profile", has_data = TRUE, has_survival = TRUE) - data <- explainer$data if (!is.null(N) && N < nrow(data)) { - ndata <- data[sample(1:nrow(data), N), ] + ndata <- data[sample(1:nrow(data), N), , drop = FALSE] } else { - ndata <- data + ndata <- data[1:nrow(data), , drop = FALSE] } if (type == "partial"){ @@ -107,8 +106,7 @@ model_profile.surv_explainer <- function(explainer, variables = variables, center = center) } else if (type == "accumulated"){ - cp_profiles <- NULL - + cp_profiles <- list(variable_values = data.frame(ndata)) result <- surv_ale(explainer, data = ndata, variables = variables, From 41cc6846304bd5088a6b4dbe06a4ad5f702c6269 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 11:19:28 +0200 Subject: [PATCH 087/207] fix rug jitter --- R/plot_model_profile_survival.R | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 734e908b..fefed694 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -253,14 +253,11 @@ plot2.model_profile_survival <- function(x, if (marginalize_over_time) { pdp_df <- aggregate(pd ~ ., data = pdp_df, mean) ice_df <- aggregate(predictions ~ ., data = ice_df, mean) - 
color_scale <- generate_discrete_color_scale(1, colors) } else { color_scale <- generate_discrete_color_scale(length(times), colors) } - - if (is_categorical) { pl <- plot_pdp_cat( pdp_dt = pdp_df, @@ -277,6 +274,7 @@ plot2.model_profile_survival <- function(x, ) } else { pdp_df[, 1] <- as.numeric(as.character(pdp_df[, 1])) + x_width <- diff(range(pdp_df[,variable])) pl <- plot_pdp_num( pdp_dt = pdp_df, ice_dt = ice_df, @@ -286,6 +284,7 @@ plot2.model_profile_survival <- function(x, y_ceiling_ice = y_ceiling_ice, y_floor_pd = y_floor_pd, y_ceiling_pd = y_ceiling_pd, + x_width = x_width, plot_type = plot_type, single_timepoint = single_timepoint, colors = color_scale @@ -308,6 +307,7 @@ plot_pdp_num <- function(pdp_dt, y_ceiling_ice, y_floor_pd, y_ceiling_pd, + x_width, plot_type, single_timepoint, colors) { @@ -316,7 +316,7 @@ plot_pdp_num <- function(pdp_dt, if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(alpha = 0.2, mapping = aes(group = id), color = colors_discrete_drwhy(1)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + ylim(y_floor_ice, y_ceiling_ice) } # PDP + ICE @@ -324,21 +324,21 @@ plot_pdp_num <- function(pdp_dt, ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(mapping = aes(group = id), alpha = 0.2) + geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = colors_discrete_drwhy(1)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter", color = colors_discrete_drwhy(1)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + ylim(y_floor_ice, 
y_ceiling_ice) } # PDP else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + geom_line(color = colors_discrete_drwhy(1)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + ylim(y_floor_pd, y_ceiling_pd) } } else { ## multiple timepoints if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(alpha = 0.2, mapping = aes(group = interaction(id, time), color = time)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + scale_color_manual(name = "time", values = colors) + ylim(y_floor_ice, y_ceiling_ice) } @@ -348,7 +348,7 @@ plot_pdp_num <- function(pdp_dt, geom_line(data = ice_dt, aes(x = !!feature_name_sym, y = predictions, group = interaction(id, time), color = time), alpha = 0.1) + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time), linewidth = 1.5, lineend = "round", linejoin = "round") + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, group = time), color = "black", linewidth = 0.5, linetype = "dashed", lineend = "round", linejoin = "round") + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = "jitter") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + scale_color_manual(name = "time", values = colors) + ylim(y_floor_ice, y_ceiling_ice) } @@ -356,7 +356,7 @@ plot_pdp_num <- function(pdp_dt, else if (plot_type == 
"pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + geom_line(aes(color = time)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = "jitter") + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + scale_color_manual(name = "time", values = colors) + ylim(y_floor_pd, y_ceiling_pd) } From bfe7de3eb02ea1e7a4a80d8811b3a7b62bc42a4f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 11:20:03 +0200 Subject: [PATCH 088/207] remove repeated test of explainer --- R/surv_model_profiles.R | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index d2a18875..a8659369 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -103,13 +103,6 @@ surv_ale <- function(x, categorical_variables, grid_points, center = FALSE) { - test_explainer( - x, - has_data = TRUE, - has_survival = TRUE, - function_name = "surv_ale" - ) - if (is.null(variables)) variables <- colnames(data) @@ -127,15 +120,12 @@ surv_ale <- function(x, times <- x$times # Make predictions for original levels - predictions_original <- predict_survival_function(model = model, - newdata = data, - times = times) + predictions_original <- predict_survival_function(model, data, times) mean_pred <- colMeans(predictions_original) profiles <- lapply(variables, function(variable) { X_lower <- X_upper <- data variable_values <- data[, variable] - if (variable %in% categorical_variables) { if (!is.factor(variable_values)){ is_numeric <- is.numeric(variable_values) @@ -146,6 +136,7 @@ surv_ale <- function(x, } levels_original <- levels(droplevels(variable_values)) levels_n <- nlevels(droplevels(variable_values)) + if (inherits(variable_values, "ordered")) { level_order <- 1:levels_n } else { @@ -243,6 +234,7 @@ surv_ale <- 
function(x, if (!center){ ale_values$ale <- ale_values$ale + mean_pred } + return( data.frame( `_vname_` = variable, From ed9efadb4df1f5c8b8be7a2982e254d89da53bba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 11:30:32 +0200 Subject: [PATCH 089/207] increase test coverage --- tests/testthat/test-model_profile.R | 10 ++++++++-- tests/testthat/test-model_profile_2d.R | 10 ++++++++++ tests/testthat/test-predict_profile.R | 6 ++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 73534a14..ec3be11e 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -42,6 +42,8 @@ test_that("model_profile with type = 'partial' works", { mp_rsf_cat <- model_profile(rsf_ranger_exp, output_type = "survival", variable_splits_type = "uniform", grid_points = 6) plot(mp_rsf_cat, variable_type = "categorical") + + plot(mp_cph_cat, mp_rsf_cat) ### Add tests for plot2 for categorical PDP # single timepoint plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp+ice") @@ -49,7 +51,7 @@ test_that("model_profile with type = 'partial' works", { plot2(mp_rsf_cat, variable = "celltype", plot_type = "ice") # multiple timepoints plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "pdp+ice") - plot2(mp_rsf_cat, times = c(4, 80.7), marginalilze_over_time = T, variable = "celltype", plot_type = "pdp+ice") + plot2(mp_rsf_cat, times = c(4, 80.7), marginalize_over_time = T, variable = "celltype", plot_type = "pdp+ice") plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "pdp") plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "ice") @@ -98,7 +100,7 @@ test_that("model_profile with type = 'accumulated' works", { ### Add tests for plot2 for categorical ALE # single timepoint - plot2(mp_cph_cat, variable = "celltype", plot_type = "ale") + plot2(mp_cph_cat, variable = 
"celltype") # multiple timepoints plot2(mp_cph_cat, times = c(4, 80.7), variable = "celltype", plot_type = "ale") @@ -106,6 +108,10 @@ test_that("model_profile with type = 'accumulated' works", { expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) expect_equal(ncol(mp_cph_cat$result), 7) expect_true(all(unique(mp_cph_cat$result$`_vname_`) %in% colnames(cph_exp$data))) + expect_error(plot2(mp_cph_cat, variable = "celltype", plot_type = "pdp")) + expect_error(plot2(mp_cph_cat, variable = "celltype", plot_type = "nonexistent")) + expect_error(plot2(mp_cph_cat, variable = 1, plot_type = "nonexistent")) + expect_error(plot2(mp_cph_cat, variable = c("celltype", "trt"), plot_type = "nonexistent")) mp_cph_num <- model_profile(cph_exp, diff --git a/tests/testthat/test-model_profile_2d.R b/tests/testthat/test-model_profile_2d.R index b25c6bce..71463aaa 100644 --- a/tests/testthat/test-model_profile_2d.R +++ b/tests/testthat/test-model_profile_2d.R @@ -14,6 +14,11 @@ test_that("model_profile_2d with type = 'partial' works", { c("karno", "age")), categorical_variables = "trt", grid_points = 6) + mp_small <- model_profile_2d(cph_exp, + variables = list(c("trt", "age")), + categorical_variables = 1, + grid_points = 2, + N = 2) plot(mp_cph_pdp) mp_rsf_pdp <- model_profile_2d(rsf_exp, @@ -58,6 +63,11 @@ test_that("model_profile_2d with type = 'accumulated' works", { output_type = "survival", type = "accumulated", center = FALSE) + + expect_error(model_profile_2d(rsf_exp, + type = "accumulated", + variables = list(c("karno", "celltype")))) + plot(mp_rsf_ale) plot(mp_rsf_ale, times=rsf_exp$times[1]) diff --git a/tests/testthat/test-predict_profile.R b/tests/testthat/test-predict_profile.R index 2e028ea2..a6884261 100644 --- a/tests/testthat/test-predict_profile.R +++ b/tests/testthat/test-predict_profile.R @@ -16,10 +16,12 @@ test_that("ceteris_paribus works", { plot(cph_pp, colors = c("#ff0000", "#00ff00", "#0000ff")) plot(cph_pp) plot(cph_pp, numerical_plot_type = "contours") 
- plot(cph_pp, ranger_pp, rug = "events") - plot(cph_pp, rug = "censors") + plot(cph_pp, ranger_pp, rug = "events", variables = c("karno", "age")) + plot(cph_pp, rug = "censors", variable_type = "numerical") plot(cph_pp, rug = "none") + expect_error(plot(cph_pp, variable_type = "nonexistent")) + expect_error(plot(cph_pp, numerical_plot_type = "nonexistent")) cph_pp_cat <- predict_profile(cph_exp, veteran[2, -c(3, 4)], variables = c("celltype")) plot(cph_pp_cat, variable_type = "categorical", colors = c("#ff0000", "#00ff00", "#0000ff")) From 33f7e940e53ed89cdd01c1cfb17a179165179945 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 12:07:36 +0200 Subject: [PATCH 090/207] add rms to suggests --- DESCRIPTION | 1 + tests/testthat/test-explain.R | 2 ++ 2 files changed, 3 insertions(+) diff --git a/DESCRIPTION b/DESCRIPTION index b394b836..69b28d39 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -42,6 +42,7 @@ Suggests: randomForestSRC, ranger, rmarkdown, + rms, testthat (>= 3.0.0), withr, xgboost diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index 68e1962d..f14fade3 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -327,6 +327,8 @@ test_that("default methods for creating explainers work correctly", { ### rms::cph ### + + library(rms, quietly = TRUE) surv <- survival::Surv(veteran$time, veteran$status) cph <- rms::cph(surv ~ trt + celltype + karno + diagtime + age + prior, data = veteran, surv=TRUE, model=TRUE, x=TRUE, y=TRUE) From fbd49e2b4c43651b1a7f945c84cf6dbc9c8cb8f2 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 12:07:56 +0200 Subject: [PATCH 091/207] remove plot.feature_importance_explainer --- NAMESPACE | 2 +- R/plot_feature_importance.R | 136 ----------------------- man/plot.feature_importance_explainer.Rd | 67 ----------- 3 files changed, 1 insertion(+), 204 deletions(-) delete mode 100644 R/plot_feature_importance.R delete mode 100644 
man/plot.feature_importance_explainer.Rd diff --git a/NAMESPACE b/NAMESPACE index b0ee1154..283de862 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -13,7 +13,6 @@ S3method(model_performance,surv_explainer) S3method(model_profile,default) S3method(model_profile,surv_explainer) S3method(model_profile_2d,surv_explainer) -S3method(plot,feature_importance_explainer) S3method(plot,model_parts_survival) S3method(plot,model_performance_survival) S3method(plot,model_profile_2d_survival) @@ -32,6 +31,7 @@ S3method(predict_parts,default) S3method(predict_parts,surv_explainer) S3method(predict_profile,default) S3method(predict_profile,surv_explainer) +S3method(print,model_profile_2d_survival) S3method(print,model_profile_survival) S3method(print,surv_ceteris_paribus) S3method(print,surv_feature_importance) diff --git a/R/plot_feature_importance.R b/R/plot_feature_importance.R deleted file mode 100644 index 0cd4c4a1..00000000 --- a/R/plot_feature_importance.R +++ /dev/null @@ -1,136 +0,0 @@ -#' Plots Feature Importance -#' -#' This function plots variable importance calculated as changes in the loss function after variable drops. -#' It uses output from \code{feature_importance} function that corresponds to -#' permutation based measure of variable importance. -#' Variables are sorted in the same order in all panels. -#' The order depends on the average drop out loss. -#' In different panels variable contributions may not look like sorted if variable -#' importance is different in different in different models. -#' -#' Find more details in the \href{https://ema.drwhy.ai/featureImportance.html}{Feature Importance Chapter}. -#' -#' @param x a feature importance explainer produced with the \code{feature_importance()} function -#' @param ... 
other explainers that shall be plotted together -#' @param max_vars maximum number of variables that shall be presented for for each model -#' By default \code{NULL} what means all variables -#' @param show_boxplots logical if \code{TRUE} (default) boxplot will be plotted to show permutation data. -#' @param bar_width width of bars. By default \code{10} -#' @param desc_sorting logical. Should the bars be sorted descending? By default TRUE -#' @param title the plot's title, by default \code{'Feature Importance'} -#' @param subtitle the plot's subtitle. By default - \code{NULL}, which means -#' the subtitle will be 'created for the XXX model', where XXX is the label of explainer(s) -#' -#' @importFrom stats model.frame reorder -#' @importFrom utils head tail -#' -#' @return a \code{ggplot2} object -#' -#' @references Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. \url{https://ema.drwhy.ai/} -#' -#' @examples -#' \donttest{ -#' library(survex) -#' library(randomForestSRC) -#' library(survival) -#' -#' model <- rfsrc(Surv(time, status) ~., data = veteran) -#' explainer <- explain(model) -#' -#' mp <- model_parts(explainer, loss = loss_one_minus_c_index, output_type = "risk") -#' plot(mp) -#' } -#' @export -plot.feature_importance_explainer <- function(x, ..., max_vars = NULL, show_boxplots = TRUE, bar_width = 10, - desc_sorting = TRUE, title = "Feature Importance", subtitle = "default") { - - if (!is.logical(desc_sorting)) { - stop("desc_sorting is not logical") - } - - dfl <- c(list(x), list(...)) - - # add boxplot data - if (show_boxplots) { - dfl <- lapply(dfl, function(x) { - result <- data.frame( - min = tapply(x$dropout_loss, x$variable, min, na.rm = TRUE), - q1 = tapply(x$dropout_loss, x$variable, quantile, 0.25, na.rm = TRUE), - median = tapply(x$dropout_loss, x$variable, median, na.rm = TRUE), - q3 = tapply(x$dropout_loss, x$variable, quantile, 0.75, na.rm = TRUE), - max = tapply(x$dropout_loss, x$variable, max, na.rm = TRUE) - ) 
- - result$min <- as.numeric(result$min) - result$q1 <- as.numeric(result$q1) - result$median <- as.numeric(result$median) - result$q3 <- as.numeric(result$q3) - result$max <- as.numeric(result$max) - - merge(x[x$permutation == 0,], cbind(rownames(result),result), by.x = "variable", by.y = "rownames(result)") - }) - } else { - dfl <- lapply(dfl, function(x) { - x[x$permutation == 0,] - }) - } - - # combine all explainers in a single frame - expl_df <- do.call(rbind, dfl) - - # add an additional column that serve as a baseline - bestFits <- expl_df[expl_df$variable == "_full_model_", ] - ext_expl_df <- merge(expl_df, bestFits[,c("label", "dropout_loss")], by = "label") - - # set the order of variables depending on their contribution - ext_expl_df$variable <- reorder(ext_expl_df$variable, - (ext_expl_df$dropout_loss.x - ext_expl_df$dropout_loss.y) * ifelse(desc_sorting, 1, -1), - mean) - - # remove rows that starts with _ - ext_expl_df <- ext_expl_df[!(substr(ext_expl_df$variable,1,1) == "_"),] - - # for each model leave only max_vars - if (!is.null(max_vars)) { - trimmed_parts <- lapply(unique(ext_expl_df$label), function(label) { - tmp <- ext_expl_df[ext_expl_df$label == label, ] - tmp[tail(order(tmp$dropout_loss.x), max_vars), ] - }) - ext_expl_df <- do.call(rbind, trimmed_parts) - } - - variable <- q1 <- q3 <- dropout_loss.x <- dropout_loss.y <- label <- dropout_loss <- NULL - nlabels <- length(unique(bestFits$label)) - - # extract labels for plot's subtitle - if (!is.null(subtitle) && subtitle == "default"){ - glm_labels <- paste0(unique(ext_expl_df$label), collapse = ", ") - subtitle <- paste0("created for the ", glm_labels, " model") - } - - # plot it - pl <- ggplot(ext_expl_df, aes(variable, ymin = dropout_loss.y, ymax = dropout_loss.x, color = label)) + - geom_hline(data = bestFits, aes(yintercept = dropout_loss, color = label), lty= 3) + - geom_linerange(linewidth = bar_width) - - if (show_boxplots) { - pl <- pl + - geom_boxplot(aes(ymin = min, lower = q1, 
middle = median, upper = q3, ymax = max), - stat = "identity", fill = "#371ea3", color = "#371ea3", width = 0.25) - } - - if (!is.null(attr(x, "loss_name"))) { - y_lab <- paste(attr(x, "loss_name")[1], "loss after permutations") - } else { - y_lab <- "Loss function after variable's permutations" - } - # facets have fixed space, can be resolved with ggforce https://github.com/tidyverse/ggplot2/issues/2933 - pl + coord_flip() + - scale_color_manual(values = DALEX::colors_discrete_drwhy(nlabels)) + - facet_wrap(~label, ncol = 1, scales = "free_y") + - theme_vertical_default_survex() + - ylab(y_lab) + xlab("") + - labs(title = title, subtitle = subtitle) + - theme(legend.position = "none") - -} diff --git a/man/plot.feature_importance_explainer.Rd b/man/plot.feature_importance_explainer.Rd deleted file mode 100644 index e01289af..00000000 --- a/man/plot.feature_importance_explainer.Rd +++ /dev/null @@ -1,67 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot_feature_importance.R -\name{plot.feature_importance_explainer} -\alias{plot.feature_importance_explainer} -\title{Plots Feature Importance} -\usage{ -\method{plot}{feature_importance_explainer}( - x, - ..., - max_vars = NULL, - show_boxplots = TRUE, - bar_width = 10, - desc_sorting = TRUE, - title = "Feature Importance", - subtitle = "default" -) -} -\arguments{ -\item{x}{a feature importance explainer produced with the \code{feature_importance()} function} - -\item{...}{other explainers that shall be plotted together} - -\item{max_vars}{maximum number of variables that shall be presented for for each model -By default \code{NULL} what means all variables} - -\item{show_boxplots}{logical if \code{TRUE} (default) boxplot will be plotted to show permutation data.} - -\item{bar_width}{width of bars. By default \code{10}} - -\item{desc_sorting}{logical. Should the bars be sorted descending? 
By default TRUE} - -\item{title}{the plot's title, by default \code{'Feature Importance'}} - -\item{subtitle}{the plot's subtitle. By default - \code{NULL}, which means -the subtitle will be 'created for the XXX model', where XXX is the label of explainer(s)} -} -\value{ -a \code{ggplot2} object -} -\description{ -This function plots variable importance calculated as changes in the loss function after variable drops. -It uses output from \code{feature_importance} function that corresponds to -permutation based measure of variable importance. -Variables are sorted in the same order in all panels. -The order depends on the average drop out loss. -In different panels variable contributions may not look like sorted if variable -importance is different in different in different models. -} -\details{ -Find more details in the \href{https://ema.drwhy.ai/featureImportance.html}{Feature Importance Chapter}. -} -\examples{ -\donttest{ -library(survex) -library(randomForestSRC) -library(survival) - -model <- rfsrc(Surv(time, status) ~., data = veteran) -explainer <- explain(model) - -mp <- model_parts(explainer, loss = loss_one_minus_c_index, output_type = "risk") -plot(mp) -} -} -\references{ -Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. \url{https://ema.drwhy.ai/} -} From b29a58250aa7ddb3764430b70578bcb706e95938 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 12:08:54 +0200 Subject: [PATCH 092/207] add print for model_profile_2d --- R/print.R | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/R/print.R b/R/print.R index 635c6c1e..c915b83c 100644 --- a/R/print.R +++ b/R/print.R @@ -9,6 +9,17 @@ print.model_profile_survival <- function(x, ...) { } +#' @export +print.model_profile_2d_survival <- function(x, ...) 
{ + res <- x$result + method_name <- ifelse(x$type == "partial", "Partial dependence", "Accumulated local effects") + res <- res[order(res$`_v1name_`), ] + text <- paste0(method_name, " survival profiles for the ", unique(res$`_label_`), " model:\n") + cat(text) + print.data.frame(res[, !colnames(res) %in% c("_label_", "_right_", "_left_", "_top_", "_bottom_")], ...) +} + + #' @export print.surv_ceteris_paribus <- function(x, ...) { res <- x$result From 9dae4f2c55077abc0be6a36acdf916ba050173d4 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 12:09:00 +0200 Subject: [PATCH 093/207] fix tests --- tests/testthat/test-model_parts.R | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-model_parts.R b/tests/testthat/test-model_parts.R index 49eb32d7..55a84ca7 100644 --- a/tests/testthat/test-model_parts.R +++ b/tests/testthat/test-model_parts.R @@ -41,7 +41,7 @@ test_that("C-index fpi works", { expect_s3_class(mp_cph_cind, "model_parts") - cph_model_parts_dalex <- model_parts(cph_exp, loss_function = loss_one_minus_c_index, + cph_model_parts_dalex <- model_parts(cph_exp, loss = loss_one_minus_c_index, output_type = "risk", type = "raw") expect_true(all(cph_model_parts_dalex$dropout_loss <= 1)) @@ -52,10 +52,8 @@ test_that("C-index fpi works", { expect_error(model_parts(coxph_explainer, output_type = "nonexistent")) plot(cph_model_parts_dalex) - expect_error(plot(cph_model_parts_dalex, desc_sorting = "non-logical")) plot(cph_model_parts_dalex, show_boxplots = FALSE, max_vars = 2) - - + expect_error(plot(cph_model_parts_dalex, desc_sorting = "non-logical")) }) From 636d450973ec8d66f70451cbde2863e1e1637df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 12:12:00 +0200 Subject: [PATCH 094/207] Improve test coverage --- R/model_profile.R | 1 - R/surv_lime.R | 1 - tests/testthat/test-model_profile.R | 4 ++++ tests/testthat/test-predict_parts.R | 7 +++++-- tests/testthat/test-utils.R | 
5 +++-- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/R/model_profile.R b/R/model_profile.R index cd04019b..ec60bff5 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -101,7 +101,6 @@ model_profile.surv_explainer <- function(explainer, grid_points = grid_points, variable_splits_type = variable_splits_type, ...) - result <- surv_aggregate_profiles(cp_profiles, ..., variables = variables, center = center) diff --git a/R/surv_lime.R b/R/surv_lime.R index e5c620b2..0d2fab36 100644 --- a/R/surv_lime.R +++ b/R/surv_lime.R @@ -117,7 +117,6 @@ generate_neighbourhood <- function(data_org, additional_categorical_variables <- categorical_variables factor_variables <- colnames(data_org)[sapply(data_org, is.factor)] categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) - if (is.null(ncol(data_row))) stop("The observation to be explained has to be data.frame") data_row <- data_row[colnames(data_org)] feature_frequencies <- list(length(categorical_variables)) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index ec3be11e..bd0557a3 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -18,6 +18,10 @@ test_that("model_profile with type = 'partial' works", { expect_equal(ncol(mp_cph_cat$result), 7) expect_true(all(unique(mp_cph_cat$result$`_vname_`) %in% colnames(cph_exp$data))) + mp_chosen_var <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", grid_points = 6, variables = "karno") + expect_s3_class(mp_chosen_var, "model_profile_survival") + expect_true(all(mp_chosen_var$eval_times == cph_exp$times)) + expect_equal(ncol(mp_chosen_var$result), 7) mp_cph_num <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", grid_points = 6) plot(mp_cph_num, variable_type = "numerical") diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 
ecd7fe59..3f19fef7 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -63,8 +63,11 @@ test_that("survlime explanations work", { rsf_src_exp <- explain(rsf_src, verbose = FALSE) cph_survlime <- predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime") - ranger_survlime <- predict_parts(rsf_ranger_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime") - rsf_survlime <- predict_parts(rsf_src_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime") + ranger_survlime <- predict_parts(rsf_ranger_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime", sample_around_instance = FALSE) + rsf_survlime <- predict_parts(rsf_src_exp, new_observation = veteran[1, -c(3, 4)], type = "survlime", categorical_variables = 1) + + # error on to few columns + expect_error(predict_parts(rsf_src_exp, new_observation = veteran[1, -c(1, 2 ,3, 4)], type = "survlime")) plot(cph_survlime, type = "coefficients") plot(cph_survlime, type = "local_importance") diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 8772be2d..17151b8a 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -98,10 +98,11 @@ test_that("setting theme works",{ old <- set_theme_survex("drwhy") plot(parts_cph) old <- set_theme_survex(ggplot2::theme_bw(), ggplot2::theme_bw()) - plot(parts_cph) + plot(parts_cph) + theme_vertical_default_survex() - expect_error(set_theme_survex(1,1)) + expect_error(set_theme_survex(1, 1)) expect_error(set_theme_survex("nonexistant")) + expect_error(set_theme_survex_vertical("ema", 5)) }) From e1c62ff32009153db0411f5808eb90ab02520983 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 12:15:12 +0200 Subject: [PATCH 095/207] fix order in namespace --- NAMESPACE | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index c2871f00..ecadd196 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,9 +12,9 @@ 
S3method(model_performance,default) S3method(model_performance,surv_explainer) S3method(model_profile,default) S3method(model_profile,surv_explainer) +S3method(model_profile_2d,surv_explainer) S3method(model_survshap,surv_explainer) S3method(plot,aggregated_surv_shap) -S3method(model_profile_2d,surv_explainer) S3method(plot,model_parts_survival) S3method(plot,model_performance_survival) S3method(plot,model_profile_2d_survival) @@ -64,8 +64,8 @@ export(loss_one_minus_integrated_cd_auc) export(model_parts) export(model_performance) export(model_profile) -export(model_survshap) export(model_profile_2d) +export(model_survshap) export(plot2) export(predict_parts) export(predict_profile) From 95a302ecb6ed0c7cdb93df1b12c8351660da0add Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 13:40:28 +0200 Subject: [PATCH 096/207] add description for plot aggregated survshap --- R/plot_surv_shap.R | 73 ++++++++++++++++++++++++++------ man/plot.aggregated_surv_shap.Rd | 65 ++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 12 deletions(-) create mode 100644 man/plot.aggregated_surv_shap.Rd diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index e0917d61..cfc901e8 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -3,7 +3,7 @@ #' This functions plots objects of class `surv_shap` - time-dependent explanations of #' survival models created using the `predict_parts(..., type="survshap")` function. #' -#' @param x an object of class `"surv_shap"` to be plotted +#' @param x an object of class `surv_shap` to be plotted #' @param ... 
additional objects of class `surv_shap` to be plotted together #' @param title character, title of the plot #' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels @@ -91,12 +91,63 @@ plot.surv_shap <- function(x, } - +#' Plot Aggregated SurvSHAP(t) Explanations for Survival Models +#' +#' This functions plots objects of class `aggregated_surv_shap` - aggregated time-dependent +#' explanations of survival models created using the `model_survshap()` function. +#' +#' @param x an object of class `aggregated_surv_shap` to be plotted +#' @param kind character, one of `"importance"`, `"swarm"`, or `"profile"`. Type of chart to be plotted: `"importance"` shows the importance of variables over time and aggregated, `"swarm"` shows the distribution of SurvSHAP(t) values for variables and observations, `"profile"` shows the dependence of SurvSHAP(t) values on variable values. +#' @param title character, title of the plot +#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations +#' @param max_vars maximum number of variables to be plotted (least important variables are ignored), by default 7 +#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") +#' +#' @return An object of the class `ggplot`. +#' +#' @section Plot options: +#' +#' ## `plot.aggregated_surv_shap(type = "importance")` +#' +#' * `rug` - character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. +#' * `rug_colors` - character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). 
The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. +#' * `xlab_left, ylab_right` - axis labels for left and right plots (due to different aggregation possibilities) +#' +#' +#' ## `plot.aggregated_surv_shap(type = "swarm")` +#' +#' * no additional options +#' +#' +#' ## `plot.aggregated_surv_shap(type = "swarm")` +#' +#' * `variable` - variable for which the profile is to be plotted, by default first from result data +#' * `color_variable` - variable used to denote the color, by default equal to `variable` +#' +#' +#' @examples +#' \donttest{ +#' library(survival) +#' library(survex) +#' +#' model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) +#' exp <- explain(model) +#' +#' p_parts_shap <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survshap") +#' plot(p_parts_shap) +#' } +#' #'@export plot.aggregated_surv_shap <- function(x, kind = "importance", ..., colors = NULL){ + if (is.null(colors)){ + colors <- c(low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3") + } + switch( kind, "importance" = plot_shap_global_importance(x = x, @@ -130,7 +181,7 @@ plot_shap_global_importance <- function(x, title = NULL, subtitle = NULL, max_vars = max_vars, - colors = colors, + colors = NULL, rug = rug, rug_colors = rug_colors) + labs(y = ylab_right) @@ -148,7 +199,7 @@ plot_shap_global_importance <- function(x, left_plot <- with(long_df, { ggplot(long_df, aes(x = values, y = reorder(ind, values))) + - geom_col(fill = "#46bac2") + + geom_col(fill = colors[2]) + theme_default_survex() + labs(x = xlab_left) + theme(axis.title.y = element_blank()) @@ -191,16 +242,15 @@ plot_shap_global_swarm <- function(x, "(n=", x$n_observations, ")" ) } - with(df, { ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + geom_vline(xintercept = 0, color = "#ceced9", linetype="solid") + geom_jitter(width=0) + scale_color_gradient2( name = "Variable value", - low = 
"#9fe5bd", - mid = "#46bac2", - high = "#371ea3", + low = colors[1], + mid = colors[2], + high = colors[3], midpoint = 0.5, limits=c(0,1), breaks = c(0, 1), @@ -265,9 +315,9 @@ plot_shap_global_profile <- function(x, if (!is.factor(df$color_variable_val)) { p + scale_color_gradient2( name = paste(color_variable, "value"), - low = "#9fe5bd", - mid = "#46bac2", - high = "#371ea3", + low = colors[1], + mid = colors[2], + high = colors[3], midpoint = median(df$color_variable_val)) } else { p + scale_color_manual(name = paste(color_variable, "value"), @@ -279,7 +329,6 @@ preprocess_values_to_common_scale <- function(data) { # Scale numerical columns to range [0, 1] num_cols <- sapply(data, is.numeric) data[num_cols] <- lapply(data[num_cols], function(x) (x - min(x)) / (max(x) - min(x))) - # Map categorical columns to integers with even differences cat_cols <- sapply(data, function(x) !is.numeric(x) & is.factor(x)) data[cat_cols] <- lapply(data[cat_cols], function(x) { diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd new file mode 100644 index 00000000..fabbc8d9 --- /dev/null +++ b/man/plot.aggregated_surv_shap.Rd @@ -0,0 +1,65 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot_surv_shap.R +\name{plot.aggregated_surv_shap} +\alias{plot.aggregated_surv_shap} +\title{Plot Aggregated SurvSHAP(t) Explanations for Survival Models} +\usage{ +\method{plot}{aggregated_surv_shap}(x, kind = "importance", ..., colors = NULL) +} +\arguments{ +\item{x}{an object of class \code{aggregated_surv_shap} to be plotted} + +\item{kind}{character, one of \code{"importance"}, \code{"swarm"}, or \code{"profile"}. 
Type of chart to be plotted: \code{"importance"} shows the importance of variables over time and aggregated, \code{"swarm"} shows the distribution of SurvSHAP(t) values for variables and observations, \code{"profile"} shows the dependence of SurvSHAP(t) values on variable values.} + +\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} + +\item{title}{character, title of the plot} + +\item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations} + +\item{max_vars}{maximum number of variables to be plotted (least important variables are ignored), by default 7} +} +\value{ +An object of the class \code{ggplot}. +} +\description{ +This functions plots objects of class \code{aggregated_surv_shap} - aggregated time-dependent +explanations of survival models created using the \code{model_survshap()} function. +} +\section{Plot options}{ + +\subsection{\code{plot.aggregated_surv_shap(type = "importance")}}{ +\itemize{ +\item \code{rug} - character, one of \code{"all"}, \code{"events"}, \code{"censors"}, \code{"none"} or \code{NULL}. Which times to mark on the x axis in \code{geom_rug()}. +\item \code{rug_colors} - character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. 
+\item \verb{xlab_left, ylab_right} - axis labels for left and right plots (due to different aggregation possibilities) +} +} + +\subsection{\code{plot.aggregated_surv_shap(kind = "swarm")}}{ +\itemize{ +\item no additional options +} +} + +\subsection{\code{plot.aggregated_surv_shap(kind = "profile")}}{ +\itemize{ +\item \code{variable} - variable for which the profile is to be plotted, by default first from result data +\item \code{color_variable} - variable used to denote the color, by default equal to \code{variable} +} +} +} + +\examples{ +\donttest{ +library(survival) +library(survex) + +model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) +exp <- explain(model) + +p_parts_shap <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survshap") +plot(p_parts_shap) +} + +} From 33ff6d56f48b96938481b63b46e23d768ae38b5c Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 13:41:15 +0200 Subject: [PATCH 097/207] fix description, remove quotation marks --- man/plot.surv_shap.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/plot.surv_shap.Rd b/man/plot.surv_shap.Rd index bc401943..78b55cc8 100644 --- a/man/plot.surv_shap.Rd +++ b/man/plot.surv_shap.Rd @@ -16,7 +16,7 @@ ) } \arguments{ -\item{x}{an object of class \code{"surv_shap"} to be plotted} +\item{x}{an object of class \code{surv_shap} to be plotted} \item{...}{additional objects of class \code{surv_shap} to be plotted together} From 9cc0581d39d28f014256d7d37e5bdc1757b5c722 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 13:41:32 +0200 Subject: [PATCH 098/207] change dataset for testing --- tests/testthat/test-model_survshap.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index c86f2163..7d74e164 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -16,8 +16,8 @@ test_that("global survshap explanations 
with kernelshap work for ranger, using n ranger_global_survshap <- model_survshap( explainer = rsf_ranger_exp, - new_observation = veteran[1:10, !colnames(veteran) %in% c("time", "status")], - y_true = Surv(veteran$time[1:10], veteran$status[1:10]), + new_observation = veteran[c(1:3, 16:18, 111:113, 126:128), !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[c(1:3, 16:18, 111:113, 126:128)], veteran$status[c(1:3, 16:18, 111:113, 126:128)]), aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) From da559c3bed2b953eb8e3a128658145fb79874c30 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 13:42:05 +0200 Subject: [PATCH 099/207] fix when too few colors in provided vector --- R/utils.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index 9a1bc46c..acb9f5b5 100644 --- a/R/utils.R +++ b/R/utils.R @@ -71,7 +71,7 @@ test_explainer <- function(explainer, #' @importFrom DALEX colors_discrete_drwhy generate_discrete_color_scale <- function(n, colors = NULL) { - if (is.null(colors)) return(colors_discrete_drwhy(n)) + if (is.null(colors) || length(colors) < n) return(colors_discrete_drwhy(n)) else return(colors[(0:(n - 1) %% length(colors)) + 1]) } From 1c7d84b8a8e563ce0faed6c0661cb6980812c902 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 13:42:30 +0200 Subject: [PATCH 100/207] new reference structure --- _pkgdown.yml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/_pkgdown.yml b/_pkgdown.yml index 89ea8bd6..1ae967af 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -7,7 +7,7 @@ reference: desc: Create a model agnostic explainer for further processing - contents: starts_with("explain") - title: Dataset level explanations -- contents: starts_with("model_p") +- contents: starts_with("model_") - title: Instance level explanations - contents: starts_with("predict") - title: Metrics and loss functions @@ -28,9 +28,13 @@ reference: - 
contents: - plot.model_parts_survival - plot.surv_feature_importance - - plot.feature_importance_explainer - subtitle: Model Profile -- contents: plot.model_profile_survival +- contents: + - plot.model_profile_survival + - plot2.model_profile_survival + - plot.model_profile_2d_survival +- subtitle: Model SurvSHAP(t) +- contents: plot.aggregated_surv_shap - subtitle: Predict Parts - contents: - plot.predict_parts_survival From ac9406416c153dfc87e35de980454c45473bb80c Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 13:42:37 +0200 Subject: [PATCH 101/207] bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index e524bc37..57ba6d7c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9002 +Version: 1.0.0.9100 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), From aa581d475f0c0ab1f5a19237cf610af95b1c3570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 13:49:08 +0200 Subject: [PATCH 102/207] Increase test coverage --- R/explain.R | 24 ++--- R/surv_feature_importance.R | 7 +- R/surv_integrated_feature_importance.R | 5 +- tests/testthat/test-explain.R | 137 +++++++++++++++++++++++- tests/testthat/test-model_parts.R | 17 ++- tests/testthat/test-model_performance.R | 2 + tests/testthat/test-model_profile.R | 2 +- tests/testthat/test-model_survshap.R | 2 +- tests/testthat/test-utils.R | 2 +- 9 files changed, 167 insertions(+), 31 deletions(-) diff --git a/R/explain.R b/R/explain.R index 3fdc7270..785ddc87 100644 --- a/R/explain.R +++ b/R/explain.R @@ -126,9 +126,7 @@ explain_survival <- if (is.null(predict_survival_function) && !is.null(predict_cumulative_hazard_function)) { - predict_survival_function <- function(model, newdata, times) { - 
cumulative_hazard_to_survival(predict_cumulative_hazard_function(model, newdata, times)) - } + predict_survival_function <- function(model, newdata, times) cumulative_hazard_to_survival(predict_cumulative_hazard_function(model, newdata, times)) attr(predict_survival_function, "verbose_info") <- "exp(-predict_cumulative_hazard_function) will be used" attr(predict_survival_function, "is.default") <- TRUE } @@ -136,9 +134,7 @@ explain_survival <- if (is.null(predict_cumulative_hazard_function) && !is.null(predict_survival_function)) { predict_cumulative_hazard_function <- - function(model, newdata, times) { - survival_to_cumulative_hazard(predict_survival_function(model, newdata, times)) - } + function(model, newdata, times) survival_to_cumulative_hazard(predict_survival_function(model, newdata, times)) attr(predict_cumulative_hazard_function, "verbose_info") <- "-log(predict_survival_function) will be used" attr(predict_cumulative_hazard_function, "is.default") <- TRUE } @@ -250,9 +246,7 @@ explain_survival <- # verbose predict function if (is.null(predict_function)) { if (!is.null(predict_cumulative_hazard_function)) { - predict_function <- function(model, newdata) { - risk_from_chf(predict_cumulative_hazard_function(model, newdata, times = times)) - } + predict_function <- function(model, newdata) risk_from_chf(predict_cumulative_hazard_function(model, newdata, times = times)) verbose_cat(" -> predict function : ", "sum over the predict_cumulative_hazard_function will be used", is.default = TRUE, verbose = verbose) } else { verbose_cat(" -> predict function : not specified! 
(", color_codes$red_start, "WARNING", color_codes$red_end, ")", verbose = verbose) @@ -726,9 +720,7 @@ explain.model_fit <- if (is.null(predict_cumulative_hazard_function)) { predict_cumulative_hazard_function <- - function(object, newdata, times) { - survival_to_cumulative_hazard(predict_survival_function(object, newdata, times)) - } + function(object, newdata, times) survival_to_cumulative_hazard(predict_survival_function(object, newdata, times)) attr(predict_cumulative_hazard_function, "verbose_info") <- "-log(predict_survival_function) will be used" attr(predict_cumulative_hazard_function, "is.default") <- TRUE } else { @@ -737,14 +729,10 @@ explain.model_fit <- if (is.null(predict_function)) { if (model$spec$engine %in% c("mboost", "survival", "glmnet", "flexsurv")){ - predict_function <- function(model, newdata, times) { - predict(model, new_data = newdata, type = "linear_pred")$.pred_linear_pred - } + predict_function <- function(model, newdata, times) predict(model, new_data = newdata, type = "linear_pred")$.pred_linear_pred attr(predict_function, "verbose_info") <- "predict.model_fit with type = 'linear_pred' will be used" } else { - predict_function <- function(model, newdata, times) { - rowSums(predict_cumulative_hazard_function(model, newdata, times = times)) - } + predict_function <- function(model, newdata, times) rowSums(predict_cumulative_hazard_function(model, newdata, times = times)) attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" } attr(predict_function, "use.times") <- TRUE diff --git a/R/surv_feature_importance.R b/R/surv_feature_importance.R index 2c00be11..9e156e71 100644 --- a/R/surv_feature_importance.R +++ b/R/surv_feature_importance.R @@ -35,9 +35,8 @@ surv_feature_importance.surv_explainer <- function(x, variable_groups = NULL, N = NULL, label = NULL) { - if (is.null(x$data)) stop("The feature_importance() function requires explainers created with specified 'data' parameter.") - if 
(is.null(x$y)) stop("The feature_importance() function requires explainers created with specified 'y' parameter.") - if (is.null(x$predict_survival_function)) stop("The feature_importance() function requires explainers created with specified 'predict_survival_function' parameter.") + + test_explainer(x, "feature_importance", has_data = TRUE, has_y = TRUE, has_survival = TRUE) model <- x$model data <- x$data @@ -94,8 +93,8 @@ surv_feature_importance.default <- function(x, all(variable_set %in% colnames(data)) })) - if (wrong_names) stop("You have passed wrong variables names in variable_groups argument") if (!all(sapply(variable_groups, class) == "character")) stop("Elements of variable_groups argument should be of class character") + if (wrong_names) stop("You have passed wrong variables names in variable_groups argument") if (is.null(names(variable_groups))) warning("You have passed an unnamed list. The names of variable groupings will be created from variables names.") } diff --git a/R/surv_integrated_feature_importance.R b/R/surv_integrated_feature_importance.R index 7e956d67..8630a371 100644 --- a/R/surv_integrated_feature_importance.R +++ b/R/surv_integrated_feature_importance.R @@ -31,8 +31,7 @@ surv_integrated_feature_importance <- function(x, N = NULL, label = NULL) { - if (is.null(x$data)) stop("The feature_importance() function requires explainers created with specified 'data' parameter.") - if (is.null(x$y)) stop("The feature_importance() function requires explainers created with specified 'y' parameter.") + test_explainer(x, "feature_importance", has_data = TRUE, has_y = TRUE) # extracts model, data and predict function from the explainer explainer <- x @@ -54,8 +53,8 @@ surv_integrated_feature_importance <- function(x, all(variable_set %in% colnames(data)) })) - if (wrong_names) stop("You have passed wrong variables names in variable_groups argument") if (!all(sapply(variable_groups, class) == "character")) stop("Elements of variable_groups argument 
should be of class character") + if (wrong_names) stop("You have passed wrong variables names in variable_groups argument") if (is.null(names(variable_groups))) warning("You have passed an unnamed list. The names of variable groupings will be created from variables names.") } type <- match.arg(type) diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index f14fade3..44438744 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -9,7 +9,23 @@ test_that("coxph prediction functions work correctly", { x = TRUE, y = TRUE ) + + cox_wrong <- survival::coxph( + survival::Surv(rtime, recur) ~ ., + data = rotterdam[, !colnames(rotterdam) %in% c("year", "dtime", "death")] + ) + + expect_error(explain(cox_wrong, verbose = FALSE)) + + cox_wrong_2 <- survival::coxph( + survival::Surv(rtime, recur) ~ ., + data = rotterdam[, !colnames(rotterdam) %in% c("year", "dtime", "death")], + model = TRUE, + x = TRUE, + y = FALSE + ) + expect_error(explain(cox_wrong_2, verbose = FALSE)) coxph_explainer <- explain(cox_rotterdam_rec, @@ -45,6 +61,20 @@ test_that("coxph prediction functions work correctly", { coxph_explainer$predict_function(cox_rotterdam_rec, rotterdam[4, ]), predict(cox_rotterdam_rec, rotterdam[4, ], type = "risk") ) + + # test manually setting predict survival function / chf / predict function + explain(cox_rotterdam_rec, + predict_survival_function = pec::predictSurvProb, + verbose = FALSE) + + explain(cox_rotterdam_rec, + predict_cumulative_hazard_function = pec::predictSurvProb, + verbose = FALSE) + + explain(cox_rotterdam_rec, + predict_function = predict, + verbose = FALSE) + }) test_that("ranger prediction functions work correctly", { @@ -146,6 +176,27 @@ test_that("ranger prediction functions work correctly", { rowSums(return_matrix) }) + # test manually setting predict survival function / chf / predict function + # the functions DO NOT WORK this is just a test if everything is set properly + 
explain(rsf_rotterdam_rec, + data = rotterdam[, !colnames(rotterdam) %in% c("year", "dtime", "death", "rtime", "recur")], + y = survival::Surv(rotterdam$rtime, rotterdam$recur), + predict_survival_function = pec::predictSurvProb, + verbose = FALSE) + + + explain(rsf_rotterdam_rec, + data = rotterdam[, !colnames(rotterdam) %in% c("year", "dtime", "death", "rtime", "recur")], + y = survival::Surv(rotterdam$rtime, rotterdam$recur), + predict_cumulative_hazard_function = pec::predictSurvProb, + verbose = FALSE) + + explain(rsf_rotterdam_rec, + data = rotterdam[, !colnames(rotterdam) %in% c("year", "dtime", "death", "rtime", "recur")], + y = survival::Surv(rotterdam$rtime, rotterdam$recur), + predict_function = predict, + verbose = FALSE) + }) @@ -239,6 +290,25 @@ test_that("rsfrc prediction functions work correctly", { rowSums(return_matrix) }) + + # test manually setting predict survival function / chf / predict function + # the functions DO NOT WORK this is just a test if everything is set properly + explain(rsf_colon_rec, + y = survival::Surv(colon$time, colon$status), + predict_survival_function = pec::predictSurvProb, + verbose = FALSE) + + explain(rsf_colon_rec, + y = survival::Surv(colon$time, colon$status), + predict_cumulative_hazard_function = pec::predictSurvProb, + verbose = FALSE) + + explain(rsf_colon_rec, + y = survival::Surv(colon$time, colon$status), + predict_function = predict, + verbose = FALSE) + + }) @@ -328,7 +398,7 @@ test_that("default methods for creating explainers work correctly", { ### rms::cph ### - library(rms, quietly = TRUE) + # library(rms, quietly = TRUE) surv <- survival::Surv(veteran$time, veteran$status) cph <- rms::cph(surv ~ trt + celltype + karno + diagtime + age + prior, data = veteran, surv=TRUE, model=TRUE, x=TRUE, y=TRUE) @@ -348,6 +418,29 @@ test_that("default methods for creating explainers work correctly", { bt_exp <- explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) 
expect_s3_class(bt_exp, c("surv_explainer", "explainer")) expect_equal(bt_exp$label, "model_fit_blackboost", ignore_attr = TRUE) + predict(bt_exp) + predict(bt_exp, output_type = "chf") + predict(bt_exp, output_type = "risk") + + + rf_spec <- parsnip::rand_forest(trees = 200) %>% + parsnip::set_engine("partykit") %>% + parsnip::set_mode("censored regression") %>% + generics::fit(survival::Surv(time, status) ~ ., data = veteran) + + exp <- explain(rf_spec, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + predict(exp) + predict(exp, output_type = "chf") + predict(exp, output_type = "risk") + + # test manually setting predict survival function / chf / predict function + # the functions DO NOT WORK this is just a test if everything is set properly + explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), predict_survival_function = pec::predictSurvProb, verbose = FALSE) + explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), predict_cumulative_hazard_function = pec::predictSurvProb, verbose = FALSE) + explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), predict_function = predict, verbose = FALSE) + + detach("package:censored", unload = TRUE) ### explain.default ### @@ -370,6 +463,10 @@ test_that("warnings in explain_survival work correctly", { cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) cph_exp <- explain(cph, verbose = FALSE, colorize = FALSE) + tmp_data <- veteran[, -c(3, 4)] + class(tmp_data) <- c("tbl", class(tmp_data)) + + expect_warning(explain_survival(cph, data = veteran, survival::Surv(veteran$time, veteran$status), @@ -413,6 +510,44 @@ test_that("warnings in explain_survival work correctly", { predict_survival_function = pec::predictSurvProb, predict_cumulative_hazard_function = "", verbose = FALSE)) + + expect_warning(explain_survival(cph, + data = tmp_data, + 
survival::Surv(veteran$time, veteran$status), + times = c(1,2,3), + predict_survival_function = pec::predictSurvProb, + predict_cumulative_hazard_function = "", + verbose = FALSE)) + + expect_warning(explain_survival(cph, + data = veteran, + survival::Surv(veteran$time, veteran$status), + times_generation = "uniform", + predict_survival_function = pec::predictSurvProb, + predict_cumulative_hazard_function = "", + verbose = FALSE)) + + custom_info <- list(ver = "1.0", + package = "custom", + type = "typee") + + expect_warning(explain_survival(cph, + data = veteran, + survival::Surv(veteran$time, veteran$status), + times_generation = "uniform", + predict_survival_function = pec::predictSurvProb, + predict_cumulative_hazard_function = "", + verbose = FALSE, + type = "weird type", + model_info = custom_info)) + + expect_error(explain_survival(cph, + data = veteran, + survival::Surv(veteran$time, veteran$status), + times_generation = "nonexistent", + predict_survival_function = pec::predictSurvProb, + predict_cumulative_hazard_function = "", + verbose = FALSE)) }) test_that("default method for creating explainers for mlr3proba works correctly", { diff --git a/tests/testthat/test-model_parts.R b/tests/testthat/test-model_parts.R index 55a84ca7..101c9ca0 100644 --- a/tests/testthat/test-model_parts.R +++ b/tests/testthat/test-model_parts.R @@ -93,7 +93,7 @@ test_that("Brier score fpi works", { ### groups - mp_groups_1 <- model_parts(cph_exp, type = "ratio", variable_groups = list(group1 = c("celltype", "trt"), group2 = c("age", "prior"))) + mp_groups_1 <- model_parts(cph_exp, type = "ratio", variable_groups = list(group1 = c("celltype", "trt"), group2 = c("age", "prior")), N= 10) mp_groups_2 <- model_parts(cph_exp, type = "difference", variable_groups = list(group1 = c("celltype", "trt"), group2 = c("age", "prior")), N = 70) plot(mp_groups_2) @@ -101,6 +101,9 @@ test_that("Brier score fpi works", { expect_error(model_parts(cph_exp, variable_groups = list(group1 = c("sss", 
"ss"), group2 = c("f", "f")))) expect_warning(model_parts(cph_exp, variable_groups = list(c("celltype", "trt"), c("age", "prior")))) + expect_error(model_parts(cph_exp, B = 10, type = "raw", variable_groups = "not a list")) + expect_error(model_parts(cph_exp, B = 10, type = "raw", variable_groups = list(group1 = c(1, 2)))) + }) @@ -146,7 +149,7 @@ test_that("integrated metrics fpi works", { rsf_src_exp <- explain(rsf_src, verbose = FALSE) # auc - cph_model_parts_int_auc <- model_parts(cph_exp, loss = loss_one_minus_integrated_cd_auc, B = 1, type = "raw") + cph_model_parts_int_auc <- model_parts(cph_exp, loss = loss_one_minus_integrated_cd_auc, B = 1, type = "raw", N = 10) rsf_ranger_model_parts_int_auc <- model_parts(rsf_ranger_exp, loss = loss_one_minus_integrated_cd_auc, B = 1, type = "raw") expect_equal(nrow(cph_model_parts_int_auc[cph_model_parts_int_auc$permutation == 0, ]), ncol(cph_exp$data) + 2) @@ -191,4 +194,14 @@ test_that("integrated metrics fpi works", { mp_groups_2 <- model_parts(cph_exp, loss = loss_integrated_brier_score, type = "difference", variable_groups = list(group1 = c("celltype", "trt"), group2 = c("age", "prior")), N = 70) plot(mp_groups_2) + expect_error(model_parts(rsf_ranger_exp, loss = loss_integrated_brier_score, B = 10, type = "raw", variable_groups = "not a list")) + expect_error(model_parts(rsf_ranger_exp, loss = loss_integrated_brier_score, B = 10, type = "raw", variable_groups = list(group1 = c(1, 2), group2 = c("f", "f")))) + + rsf_ranger_exp$y <- NULL + expect_error(model_parts(rsf_ranger_exp, loss = loss_integrated_brier_score, B = 10, type = "raw")) + + rsf_ranger_exp$data <- NULL + expect_error(model_parts(rsf_ranger_exp, loss = loss_integrated_brier_score, B = 10, type = "raw")) + + }) diff --git a/tests/testthat/test-model_performance.R b/tests/testthat/test-model_performance.R index f09c9b2b..eeae53e6 100644 --- a/tests/testthat/test-model_performance.R +++ b/tests/testthat/test-model_performance.R @@ -51,4 +51,6 @@ 
test_that("model_performance works", { plot(cph_rot_perf_roc) plot(rsf_rot_perf_roc) + expect_error(model_performance(rsf_exp_rot, type = "roc", times = NULL)) + }) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index bd0557a3..3988b7b0 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -10,7 +10,7 @@ test_that("model_profile with type = 'partial' works", { rsf_src_exp <- explain(rsf_src, verbose = FALSE) - mp_cph_cat <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", grid_points = 6) + mp_cph_cat <- model_profile(cph_exp, output_type = "survival", variable_splits_type = "quantiles", grid_points = 6, N = 4) plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") expect_s3_class(mp_cph_cat, "model_profile_survival") diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index c86f2163..630e531c 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -26,7 +26,7 @@ test_that("global survshap explanations with kernelshap work for ranger, using n plot(ranger_global_survshap, kind = "profile") plot(ranger_global_survshap, kind = "profile", variable = "karno", color_variable = "celltype") plot(ranger_global_survshap, kind = "profile", variable = "karno", color_variable = "age") - + expect_error(plot(ranger_global_survshap, kind = "nonexistent")) expect_s3_class(ranger_global_survshap, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(ranger_global_survshap$eval_times), length(rsf_ranger_exp$times)) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 17151b8a..0af02af5 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -102,7 +102,7 @@ test_that("setting theme works",{ expect_error(set_theme_survex(1, 1)) expect_error(set_theme_survex("nonexistant")) - expect_error(set_theme_survex_vertical("ema", 5)) + 
expect_error(set_theme_survex("ema", 5)) }) From 42aa8abb395ad3fb525586413073fa6660280d00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 13:57:44 +0200 Subject: [PATCH 103/207] Add better subtitle for single time --- R/plot_model_profile_survival.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index fefed694..926fb254 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -195,6 +195,9 @@ plot2.model_profile_survival <- function(x, if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") + if (single_timepoint){ + subtitle <- paste0(subtitle, " and time=", times) + } } single_timepoint <- ((length(times) == 1) || marginalize_over_time) From 68ee0e62824d5adce2a5a24992e16f98d5c2712c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 4 Aug 2023 14:06:09 +0200 Subject: [PATCH 104/207] Fix test --- tests/testthat/test-explain.R | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index 44438744..a702c2c0 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -422,18 +422,6 @@ test_that("default methods for creating explainers work correctly", { predict(bt_exp, output_type = "chf") predict(bt_exp, output_type = "risk") - - rf_spec <- parsnip::rand_forest(trees = 200) %>% - parsnip::set_engine("partykit") %>% - parsnip::set_mode("censored regression") %>% - generics::fit(survival::Surv(time, status) ~ ., data = veteran) - - exp <- explain(rf_spec, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) - - predict(exp) - predict(exp, output_type = "chf") - predict(exp, output_type = "risk") - # test manually setting predict survival function / chf / predict function # the functions DO NOT WORK this is just a test if 
everything is set properly explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), predict_survival_function = pec::predictSurvProb, verbose = FALSE) From 2cb9a596154c1fb1fa3793698c965938e5217d5b Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 14:48:32 +0200 Subject: [PATCH 105/207] change example to show ale --- R/model_profile.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/model_profile.R b/R/model_profile.R index ec60bff5..5f7be226 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -40,7 +40,8 @@ #' plot(cph_model_profile) #' #' rsf_model_profile <- model_profile(rsf_src_exp, output_type = "survival", -#' variables = c("age", "celltype")) +#' variables = c("age", "celltype"), +#' type = "accumulated") #' #' head(rsf_model_profile$result) #' From 60fb27ca1bbc1529ee5ef10dea2524598c764663 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 14:50:18 +0200 Subject: [PATCH 106/207] fix example and binding results for different types of variables --- R/model_profile_2d.R | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index ab7d159c..3b786a8d 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -23,21 +23,18 @@ #' library(survex) #' #' cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) -#' rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) -#' #' cph_exp <- explain(cph) -#' rsf_src_exp <- explain(rsf_src) #' #' cph_model_profile_2d <- model_profile_2d(cph_exp, #' variables = list(c("age", "celltype"))) #' head(cph_model_profile_2d$result) #' plot(cph_model_profile_2d) #' -#' rsf_model_profile_2d <- model_profile_2d(rsf_src_exp, +#' cph_model_profile_2d_ale <- model_profile_2d(cph_exp, #' variables = list(c("age", "karno")), #' type = "accumulated") -#' head(rsf_model_profile_2d$result) -#' plot(rsf_model_profile_2d) +#' 
head(cph_model_profile_2d_ale$result) +#' plot(cph_model_profile_2d_ale) #' } #' #' @rdname model_profile_2d.surv_explainer @@ -159,8 +156,8 @@ surv_pdp_2d <- function(x, "_v2name_" = var2, "_v1type_" = ifelse(var1 %in% categorical_variables, "categorical", "numerical"), "_v2type_" = ifelse(var2 %in% categorical_variables, "categorical", "numerical"), - "_v1value_" = rep(expanded_data[,var1], each=length(times)), - "_v2value_" = rep(expanded_data[,var2], each=length(times)), + "_v1value_" = as.character(rep(expanded_data[,var1], each=length(times))), + "_v2value_" = as.character(rep(expanded_data[,var2], each=length(times))), "_times_" = rep(times, nrow(expanded_data)), "_yhat_" = c(t(predictions)), "_label_" = label, From 4f6174e626c0cda49ef71b5309a66bf9b1e1dcd5 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 14:51:04 +0200 Subject: [PATCH 107/207] add description and warnings for marginalization and times --- R/plot_model_profile_2d.R | 47 +++++++++++++++++++- man/plot.model_profile_2d_survival.Rd | 64 +++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 man/plot.model_profile_2d_survival.Rd diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index 7ec1000a..6ea739fb 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -1,3 +1,41 @@ +#' Plot 2-Dimensional Model Profile for Survival Models +#' +#' This function plots objects of class `"model_profile_2d_survival"` created +#' using the `model_profile_2d()` function. +#' +#' @param x an object of class `model_profile_2d_survival` to be plotted +#' @param ... additional objects of class `model_profile_2d_survival` to be plotted together +#' @param variables list of character vectors of length 2, names of pairs of variables to be plotted +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. 
If `NULL` (default) then the median time from the explainer object is used. +#' @param facet_ncol number of columns for arranging subplots +#' @param title character, title of the plot. `'default'` automatically generates either "2D partial dependence survival profiles" or "2D accumulated local effects survival profiles" depending on the explanation type. +#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model", where XXX is the explainer labels, if `marginalize_over_time = FALSE`, time is also added to the subtitle +#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") +#' +#' @return A collection of `ggplot` objects arranged with the `patchwork` package. +#' +#' +#' @examples +#' \donttest{ +#' library(survival) +#' library(survex) +#' +#' cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) +#' cph_exp <- explain(cph) +#' +#' cph_model_profile_2d <- model_profile_2d(cph_exp, +#' variables = list(c("age", "celltype"), +#' c("age", "karno"))) +#' head(cph_model_profile_2d$result) +#' plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 88.5) +#' +#' cph_model_profile_2d_ale <- model_profile_2d(cph_exp, +#' variables = list(c("age", "karno")), +#' type = "accumulated") +#' head(cph_model_profile_2d_ale$result) +#' plot(cph_model_profile_2d_ale, times = c(4, 88.5), marginalize_over_time = TRUE) +#' } +#' #' @export plot.model_profile_2d_survival <- function(x, ..., @@ -77,10 +115,12 @@ prepare_model_profile_2d_plots <- function(x, ){ if (is.null(times)) { times <- quantile(x$eval_times, p = 0.5, type = 1) + warning("Plot will be prepared for the median time point from the `times` vector. 
For another time point, set the value of `times`.") } - if (!marginalize_over_time) { + if (!marginalize_over_time && length(times) > 1) { times <- times[1] + warning("Plot will be prepared for the first time point in the `times` vector. For aggregation over time, set the option `marginalize_over_time = TRUE`.") } if (!all(times %in% x$eval_times)) { @@ -93,7 +133,10 @@ prepare_model_profile_2d_plots <- function(x, all_profiles <- x$result df_time <- all_profiles[all_profiles$`_times_` %in% times, ] - + df_time$`_times_` <- NULL + if (marginalize_over_time){ + df_time <- aggregate(`_yhat_`~., data=df_time, FUN=mean) + } sf_range <- range(df_time$`_yhat_`) pl <- lapply(seq_along(variables), function(i){ diff --git a/man/plot.model_profile_2d_survival.Rd b/man/plot.model_profile_2d_survival.Rd new file mode 100644 index 00000000..8c87c322 --- /dev/null +++ b/man/plot.model_profile_2d_survival.Rd @@ -0,0 +1,64 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot_model_profile_2d.R +\name{plot.model_profile_2d_survival} +\alias{plot.model_profile_2d_survival} +\title{Plot 2-Dimensional Model Profile for Survival Models} +\usage{ +\method{plot}{model_profile_2d_survival}( + x, + ..., + variables = NULL, + times = NULL, + marginalize_over_time = FALSE, + facet_ncol = NULL, + title = "default", + subtitle = "default", + colors = NULL +) +} +\arguments{ +\item{x}{an object of class \code{model_profile_2d_survival} to be plotted} + +\item{...}{additional objects of class \code{model_profile_2d_survival} to be plotted together} + +\item{variables}{list of character vectors of length 2, names of pairs of variables to be plotted} + +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. 
If \code{NULL} (default) then the median time from the explainer object is used.} + +\item{facet_ncol}{number of columns for arranging subplots} + +\item{title}{character, title of the plot. \code{'default'} automatically generates either "2D partial dependence survival profiles" or "2D accumulated local effects survival profiles" depending on the explanation type.} + +\item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for the XXX model", where XXX is the explainer labels, if \code{marginalize_over_time = FALSE}, time is also added to the subtitle} + +\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} +} +\value{ +A collection of \code{ggplot} objects arranged with the \code{patchwork} package. +} +\description{ +This function plots objects of class \code{"model_profile_2d_survival"} created +using the \code{model_profile_2d()} function. +} +\examples{ +\donttest{ +library(survival) +library(survex) + +cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) +cph_exp <- explain(cph) + +cph_model_profile_2d <- model_profile_2d(cph_exp, + variables = list(c("age", "celltype"), + c("age", "karno"))) +head(cph_model_profile_2d$result) +plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 88.5) + +cph_model_profile_2d_ale <- model_profile_2d(cph_exp, + variables = list(c("age", "karno")), + type = "accumulated") +head(cph_model_profile_2d_ale$result) +plot(cph_model_profile_2d_ale, times = c(4, 88.5), marginalize_over_time = TRUE) +} + +} From 04421c9fd06d0591dbbd0c9a76f61f5e9ff04197 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 14:51:45 +0200 Subject: [PATCH 108/207] add warning --- R/plot_model_profile_survival.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 
fefed694..9ccf6bd6 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -4,7 +4,7 @@ #' using the `model_profile()` function. #' #' @param x an object of class `model_profile_survival` to be plotted -#' @param ... additional objects of class `"model_profile_survival"` to be plotted together +#' @param ... additional objects of class `model_profile_survival` to be plotted together #' @param variables character, names of the variables to be plotted #' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all #' @param facet_ncol number of columns for arranging subplots @@ -174,6 +174,7 @@ plot2.model_profile_survival <- function(x, if (is.null(times)) { times <- quantile(x$eval_times, p = 0.5, type = 1) + warning("Plot will be prepared for the median time point from the `times` vector. For another time point, set the value of `times`.") } if (!all(times %in% x$eval_times)) { From 9623fcb15d5440695776147c1e295cec0b381fa6 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 14:52:01 +0200 Subject: [PATCH 109/207] update man --- man/model_profile.surv_explainer.Rd | 3 ++- man/model_profile_2d.surv_explainer.Rd | 9 +++------ man/plot.model_profile_survival.Rd | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/man/model_profile.surv_explainer.Rd b/man/model_profile.surv_explainer.Rd index 2c39b243..c8d54d4e 100644 --- a/man/model_profile.surv_explainer.Rd +++ b/man/model_profile.surv_explainer.Rd @@ -83,7 +83,8 @@ head(cph_model_profile$result) plot(cph_model_profile) rsf_model_profile <- model_profile(rsf_src_exp, output_type = "survival", - variables = c("age", "celltype")) + variables = c("age", "celltype"), + type = "accumulated") head(rsf_model_profile$result) diff --git a/man/model_profile_2d.surv_explainer.Rd b/man/model_profile_2d.surv_explainer.Rd index 8a6b3250..6eaf1e27 100644 --- 
a/man/model_profile_2d.surv_explainer.Rd +++ b/man/model_profile_2d.surv_explainer.Rd @@ -61,21 +61,18 @@ library(survival) library(survex) cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) -rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) - cph_exp <- explain(cph) -rsf_src_exp <- explain(rsf_src) cph_model_profile_2d <- model_profile_2d(cph_exp, variables = list(c("age", "celltype"))) head(cph_model_profile_2d$result) plot(cph_model_profile_2d) -rsf_model_profile_2d <- model_profile_2d(rsf_src_exp, +cph_model_profile_2d_ale <- model_profile_2d(cph_exp, variables = list(c("age", "karno")), type = "accumulated") -head(rsf_model_profile_2d$result) -plot(rsf_model_profile_2d) +head(cph_model_profile_2d_ale$result) +plot(cph_model_profile_2d_ale) } } diff --git a/man/plot.model_profile_survival.Rd b/man/plot.model_profile_survival.Rd index 338f3cff..9608c4f7 100644 --- a/man/plot.model_profile_survival.Rd +++ b/man/plot.model_profile_survival.Rd @@ -21,7 +21,7 @@ \arguments{ \item{x}{an object of class \code{model_profile_survival} to be plotted} -\item{...}{additional objects of class \code{"model_profile_survival"} to be plotted together} +\item{...}{additional objects of class \code{model_profile_survival} to be plotted together} \item{variables}{character, names of the variables to be plotted} From 8a1ae80cc60e4f000a20461484a1d10558155070 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 15:00:05 +0200 Subject: [PATCH 110/207] fix timepoints in example --- R/plot_model_profile_2d.R | 6 +++--- man/plot.model_profile_2d_survival.Rd | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index 6ea739fb..04f17dbd 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -27,13 +27,13 @@ #' variables = list(c("age", "celltype"), #' c("age", "karno"))) #' head(cph_model_profile_2d$result) -#' 
plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 88.5) +#' plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 103) #' #' cph_model_profile_2d_ale <- model_profile_2d(cph_exp, #' variables = list(c("age", "karno")), #' type = "accumulated") #' head(cph_model_profile_2d_ale$result) -#' plot(cph_model_profile_2d_ale, times = c(4, 88.5), marginalize_over_time = TRUE) +#' plot(cph_model_profile_2d_ale, times = c(8, 103), marginalize_over_time = TRUE) #' } #' #' @export @@ -180,7 +180,7 @@ prepare_model_profile_2d_plots <- function(x, }) } - if (i != length(x$variables)) + if (i != length(variables)) p <- p + guides(fill = "none") return(p) }) diff --git a/man/plot.model_profile_2d_survival.Rd b/man/plot.model_profile_2d_survival.Rd index 8c87c322..7d3e0968 100644 --- a/man/plot.model_profile_2d_survival.Rd +++ b/man/plot.model_profile_2d_survival.Rd @@ -52,13 +52,13 @@ cph_model_profile_2d <- model_profile_2d(cph_exp, variables = list(c("age", "celltype"), c("age", "karno"))) head(cph_model_profile_2d$result) -plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 88.5) +plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 103) cph_model_profile_2d_ale <- model_profile_2d(cph_exp, variables = list(c("age", "karno")), type = "accumulated") head(cph_model_profile_2d_ale$result) -plot(cph_model_profile_2d_ale, times = c(4, 88.5), marginalize_over_time = TRUE) +plot(cph_model_profile_2d_ale, times = c(8, 103), marginalize_over_time = TRUE) } } From a79ea0f3d5c16f94d0e5374afc31c12dfaccbcb8 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 15:17:57 +0200 Subject: [PATCH 111/207] move subtitle creation --- R/plot_model_profile_survival.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index d6f11adc..9e696b43 100644 --- a/R/plot_model_profile_survival.R +++ 
b/R/plot_model_profile_survival.R @@ -194,14 +194,14 @@ plot2.model_profile_survival <- function(x, } } + single_timepoint <- ((length(times) == 1) || marginalize_over_time) if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") - if (single_timepoint){ + if (single_timepoint & !marginalize_over_time){ subtitle <- paste0(subtitle, " and time=", times) } } - single_timepoint <- ((length(times) == 1) || marginalize_over_time) is_categorical <- (unique(x$result[x$result$`_vname_` == variable, "_vtype_"]) == "categorical") ice_needed <- plot_type %in% c("pdp+ice", "ice") From fd3287b8f344c386e08901eab75e37e8afde2aa5 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 15:18:14 +0200 Subject: [PATCH 112/207] add missing params --- R/plot_model_profile_2d.R | 1 + R/plot_surv_shap.R | 20 +++++++++++++++----- man/plot.aggregated_surv_shap.Rd | 18 ++++++++++++++---- man/plot.model_profile_2d_survival.Rd | 2 ++ 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index 04f17dbd..a207d77c 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -7,6 +7,7 @@ #' @param ... additional objects of class `model_profile_2d_survival` to be plotted together #' @param variables list of character vectors of length 2, names of pairs of variables to be plotted #' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median time from the explainer object is used. +#' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately #' @param facet_ncol number of columns for arranging subplots #' @param title character, title of the plot. 
`'default'` automatically generates either "2D partial dependence survival profiles" or "2D accumulated local effects survival profiles" depending on the explanation type. #' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model", where XXX is the explainer labels, if `marginalize_over_time = FALSE`, time is also added to the subtitle diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index cfc901e8..19c5e961 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -97,7 +97,8 @@ plot.surv_shap <- function(x, #' explanations of survival models created using the `model_survshap()` function. #' #' @param x an object of class `aggregated_surv_shap` to be plotted -#' @param kind character, one of `"importance"`, `"swarm"`, or `"profile"`. Type of chart to be plotted: `"importance"` shows the importance of variables over time and aggregated, `"swarm"` shows the distribution of SurvSHAP(t) values for variables and observations, `"profile"` shows the dependence of SurvSHAP(t) values on variable values. +#' @param kind character, one of `"importance"`, `"swarm"`, or `"profile"`. Type of chart to be plotted; `"importance"` shows the importance of variables over time and aggregated, `"swarm"` shows the distribution of SurvSHAP(t) values for variables and observations, `"profile"` shows the dependence of SurvSHAP(t) values on variable values. +#' @param ... 
additional parameters passed to internal functions #' @param title character, title of the plot #' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations #' @param max_vars maximum number of variables to be plotted (least important variables are ignored), by default 7 @@ -116,7 +117,7 @@ plot.surv_shap <- function(x, #' #' ## `plot.aggregated_surv_shap(type = "swarm")` #' -#' * no additional options +#' * no additional parameters #' #' #' ## `plot.aggregated_surv_shap(type = "swarm")` @@ -141,6 +142,9 @@ plot.surv_shap <- function(x, plot.aggregated_surv_shap <- function(x, kind = "importance", ..., + title="default", + subtitle="default", + max_vars=7, colors = NULL){ if (is.null(colors)){ colors <- c(low = "#9fe5bd", @@ -165,7 +169,7 @@ plot.aggregated_surv_shap <- function(x, plot_shap_global_importance <- function(x, ..., - title = "Feature importance according to aggregated |SurvSHAP(t)|", + title = "default", subtitle = "default", max_vars = 7, colors = NULL, @@ -190,6 +194,8 @@ plot_shap_global_importance <- function(x, long_df <- stack(x$aggregate) long_df <- long_df[order(long_df$values, decreasing = TRUE),][1:min(max_vars, length(x$aggregate)), ] + if (!is.null(subtitle) && subtitle == "default") + title <- "Feature importance according to aggregated |SurvSHAP(t)|" if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0( "created for the ", label, " model ", @@ -220,7 +226,7 @@ plot_shap_global_importance <- function(x, plot_shap_global_swarm <- function(x, ..., - title = "Aggregated SurvSHAP(t) values summary", + title = "default", subtitle = "default", max_vars = 7, colors = NULL){ @@ -236,6 +242,8 @@ plot_shap_global_swarm <- function(x, df <- cbind(df, var_value) label <- attr(x, "label") + if (!is.null(subtitle) && subtitle == "default") + title <- "Aggregated SurvSHAP(t) 
values summary" if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0( "created for the ", label, " model ", @@ -269,7 +277,7 @@ plot_shap_global_profile <- function(x, ..., variable = NULL, color_variable = NULL, - title = "Aggregated SurvSHAP(t) profile", + title = "default", subtitle = "default", max_vars = 7, colors = NULL){ @@ -292,6 +300,8 @@ plot_shap_global_profile <- function(x, colnames(df) <- c("shap_val", "variable_val", "color_variable_val") label <- attr(x, "label") + if (!is.null(subtitle) && subtitle == "default") + title <- "Aggregated SurvSHAP(t) profile" if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0( "created for the ", label, " model ", diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd index fabbc8d9..173cd1aa 100644 --- a/man/plot.aggregated_surv_shap.Rd +++ b/man/plot.aggregated_surv_shap.Rd @@ -4,20 +4,30 @@ \alias{plot.aggregated_surv_shap} \title{Plot Aggregated SurvSHAP(t) Explanations for Survival Models} \usage{ -\method{plot}{aggregated_surv_shap}(x, kind = "importance", ..., colors = NULL) +\method{plot}{aggregated_surv_shap}( + x, + kind = "importance", + ..., + title = "default", + subtitle = "default", + max_vars = 7, + colors = NULL +) } \arguments{ \item{x}{an object of class \code{aggregated_surv_shap} to be plotted} -\item{kind}{character, one of \code{"importance"}, \code{"swarm"}, or \code{"profile"}. Type of chart to be plotted: \code{"importance"} shows the importance of variables over time and aggregated, \code{"swarm"} shows the distribution of SurvSHAP(t) values for variables and observations, \code{"profile"} shows the dependence of SurvSHAP(t) values on variable values.} +\item{kind}{character, one of \code{"importance"}, \code{"swarm"}, or \code{"profile"}. 
Type of chart to be plotted; \code{"importance"} shows the importance of variables over time and aggregated, \code{"swarm"} shows the distribution of SurvSHAP(t) values for variables and observations, \code{"profile"} shows the dependence of SurvSHAP(t) values on variable values.} -\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} +\item{...}{additional parameters passed to internal functions} \item{title}{character, title of the plot} \item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations} \item{max_vars}{maximum number of variables to be plotted (least important variables are ignored), by default 7} + +\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} } \value{ An object of the class \code{ggplot}. @@ -38,7 +48,7 @@ explanations of survival models created using the \code{model_survshap()} functi \subsection{\code{plot.aggregated_surv_shap(type = "swarm")}}{ \itemize{ -\item no additional options +\item no additional parameters } } diff --git a/man/plot.model_profile_2d_survival.Rd b/man/plot.model_profile_2d_survival.Rd index 7d3e0968..36bad02f 100644 --- a/man/plot.model_profile_2d_survival.Rd +++ b/man/plot.model_profile_2d_survival.Rd @@ -25,6 +25,8 @@ \item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. 
If \code{NULL} (default) then the median time from the explainer object is used.} +\item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately} + \item{facet_ncol}{number of columns for arranging subplots} \item{title}{character, title of the plot. \code{'default'} automatically generates either "2D partial dependence survival profiles" or "2D accumulated local effects survival profiles" depending on the explanation type.} From 662763a8f7827a3cbcac55a79f7bb36ea07273e6 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 4 Aug 2023 15:27:00 +0200 Subject: [PATCH 113/207] update tests for handling warnings --- tests/testthat/test-model_profile.R | 18 +++++++++--------- tests/testthat/test-model_profile_2d.R | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 3988b7b0..c0b19b8d 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -29,9 +29,9 @@ test_that("model_profile with type = 'partial' works", { ### Add tests for plot2 for numerical PDP # single timepoint - plot2(mp_cph_num, variable = "karno", plot_type = "pdp+ice") - plot2(mp_cph_num, variable = "karno", plot_type = "pdp") - plot2(mp_cph_num, variable = "karno", plot_type = "ice") + plot2(mp_cph_num, variable = "karno", plot_type = "pdp+ice", times = cph_exp$times[1]) + plot2(mp_cph_num, variable = "karno", plot_type = "pdp", times = cph_exp$times[1]) + plot2(mp_cph_num, variable = "karno", plot_type = "ice", times = cph_exp$times[1]) # multiple timepoints plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "pdp+ice") plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "pdp") @@ -50,9 +50,9 @@ test_that("model_profile with type = 'partial' works", { plot(mp_cph_cat, mp_rsf_cat) ### Add 
tests for plot2 for categorical PDP # single timepoint - plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp+ice") - plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp") - plot2(mp_rsf_cat, variable = "celltype", plot_type = "ice") + plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp+ice", times = rsf_ranger_exp$times[1]) + plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp", times = rsf_ranger_exp$times[1]) + plot2(mp_rsf_cat, variable = "celltype", plot_type = "ice", times = rsf_ranger_exp$times[1]) # multiple timepoints plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "pdp+ice") plot2(mp_rsf_cat, times = c(4, 80.7), marginalize_over_time = T, variable = "celltype", plot_type = "pdp+ice") @@ -76,8 +76,8 @@ test_that("model_profile with type = 'partial' works", { expect_true(all(unique(mp_rsf_num$result$`_vname_`) %in% colnames(rsf_ranger_exp$data))) expect_output(print(mp_cph_num)) + expect_warning(plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp+ice")) expect_error(plot(mp_rsf_num, variables = "nonexistent", grid_points = 6)) - expect_error(model_profile(rsf_ranger_exp, type = "conditional")) expect_error(plot2(mp_rsf_num, variable = "nonexistent")) expect_error(plot2(mp_rsf_num, variable = "age", times = -1)) @@ -104,7 +104,7 @@ test_that("model_profile with type = 'accumulated' works", { ### Add tests for plot2 for categorical ALE # single timepoint - plot2(mp_cph_cat, variable = "celltype") + plot2(mp_cph_cat, variable = "celltype", times=cph_exp$times[1]) # multiple timepoints plot2(mp_cph_cat, times = c(4, 80.7), variable = "celltype", plot_type = "ale") @@ -128,7 +128,7 @@ test_that("model_profile with type = 'accumulated' works", { ### Add tests for plot2 for numerical ALE # single timepoint - plot2(mp_cph_num, variable = "karno", plot_type = "ale") + plot2(mp_cph_num, variable = "karno", plot_type = "ale", times=cph_exp$times[1]) # multiple timepoints plot2(mp_cph_num, times = c(4, 80.7), 
variable = "karno", plot_type = "ale") diff --git a/tests/testthat/test-model_profile_2d.R b/tests/testthat/test-model_profile_2d.R index 71463aaa..56a98838 100644 --- a/tests/testthat/test-model_profile_2d.R +++ b/tests/testthat/test-model_profile_2d.R @@ -19,14 +19,14 @@ test_that("model_profile_2d with type = 'partial' works", { categorical_variables = 1, grid_points = 2, N = 2) - plot(mp_cph_pdp) + plot(mp_cph_pdp, times=cph_exp$times[1]) mp_rsf_pdp <- model_profile_2d(rsf_exp, variables = list(c("karno", "age")), grid_points = 6, output_type = "survival", N = 25) - plot(mp_cph_pdp, mp_rsf_pdp, variables = list(c("karno", "age"))) + plot(mp_cph_pdp, mp_rsf_pdp, variables = list(c("karno", "age")), times=cph_exp$times[1]) expect_output(print(mp_cph_pdp)) expect_s3_class(mp_cph_pdp, "model_profile_2d_survival") @@ -35,6 +35,7 @@ test_that("model_profile_2d with type = 'partial' works", { expect_true(all(unique(c(mp_cph_pdp$result$`_v1name_`, mp_cph_pdp$result$`_v2name_`)) %in% colnames(cph_exp$data))) + expect_warning(plot(mp_cph_pdp)) expect_error(model_profile_2d(rsf_exp)) expect_error(model_profile_2d(rsf_exp, type = "conditional", variables = list(c("karno", "age")))) @@ -67,8 +68,7 @@ test_that("model_profile_2d with type = 'accumulated' works", { expect_error(model_profile_2d(rsf_exp, type = "accumulated", variables = list(c("karno", "celltype")))) - - plot(mp_rsf_ale) + plot(mp_rsf_ale, times=rsf_exp$times[1]) expect_output(print(mp_rsf_ale)) From a01f2e1df3a4b5c08dc881ddb0a95683d0315686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 7 Aug 2023 10:14:49 +0200 Subject: [PATCH 114/207] Add info to vignette --- vignettes/survex-usage.Rmd | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vignettes/survex-usage.Rmd b/vignettes/survex-usage.Rmd index cb3362ec..dd911927 100644 --- a/vignettes/survex-usage.Rmd +++ b/vignettes/survex-usage.Rmd @@ -37,6 +37,14 @@ rsf <- randomForestSRC::rfsrc(Surv(time, 
status)~., data = vet) rsf_exp <- explain(rsf) ``` +However, for some models, not all data can be extracted automatically. If we want to create an explainer for a Random Survival Forest from the `ranger` package, we need to supply `data`, and `y` on our own. It is important to remember, that we should supply the data parameter **without** the columns containing survival information. + +``` {r} +library(ranger) + +ranger_rsf <- ranger(Surv(time, status)~., data = vet) +ranger_rsf_exp <- explain(ranger_rsf, data = vet[, -c(3,4)], y = Surv(vet$time, vet$status)) +``` # Making predictions @@ -53,7 +61,6 @@ predict(cph_exp, veteran[1:2,], output_type="chf", times=seq(1, 600, 100)) predict(rsf_exp, veteran[1:2,], output_type="chf", times=seq(1, 600, 100)) ``` - # Measuring performance Another helpful thing is the functionality for calculating different metrics of the models. For this we use the `model_performance()` function. It calculates a set of performance measures we can plot next to each other and easily compare. @@ -111,8 +118,12 @@ plot(model_parts_cph_auc,model_parts_rsf_auc) We observe that the results are consistent with `loss_brier_score()`. That's good news - a simple sanity check that these variables are the most important for these models. +**Important note (new functionality):** We can also measure the importance of variables using global SurvSHAP(t) explanations. Details on how to do so, are presented in this vignette. + ## Partial dependence +**Important note (new functionality):** A more detailed description of the partial dependence explanation, with added functionality is presented in this vignette. + The next type of global explanation this package provides is partial dependence plots. This is calculated using the `model_profile()` function. These plots show how setting one variable to a different value would, on average, influence the model's prediction. 
Again, this is an extension of [partial dependence](https://ema.drwhy.ai/partialDependenceProfiles.html) known from regression and classification tasks, applied to survival models by extending it to take the time dimension into account. Note that we need to set the `categorical_variables` parameter in order to avoid nonsensical values, such as the treatment value of 0.5. All factors are automatically detected as categories, but if you want to treat a numerical variable as a categorical one, you need to set it here. From 96365f646fc29035c10d7b0ec3ad1b9e2f8f2bde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 7 Aug 2023 13:43:24 +0200 Subject: [PATCH 115/207] Update NEWS.md --- NEWS.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/NEWS.md b/NEWS.md index 9b4406cb..4f178cf6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,13 @@ # survex (development) * Fix not being able to plot or print SurvLIME results for the cph model sometimes. ([#72](https://github.com/ModelOriented/survex/issues/72)) +* Add global explanations via the SurvSHAP(t) method (see `model_survshap()` function) +* Add plots for global SurvSHAP(t) explanations (see `plot.aggregated_surv_shap()`) +* Add Accumulated Local Effects (ALE) explanations (see `model_profile(..., type = "accumulated")`) +* Add 2-dimensional PDP and ALE plots (see `model_profile_2d()` function) +* Add `plot2()` function for plotting PDP and ALE explanations without the time dimension +* Improvement on the vignettes for the package (see `vignette("pdp")` and `vignette("global-survshap")`) +* Increase the test coverage of the package. 
# survex 1.0.0 From ec653fd5819178c332cfb1f953c6f1055225fc48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 7 Aug 2023 13:44:01 +0200 Subject: [PATCH 116/207] Update and add new vignettes --- vignettes/global-survshap.Rmd | 56 ++++++++++++++++++ vignettes/pdp.Rmd | 103 ++++++++++++++++++++++++++++++++++ vignettes/survex-usage.Rmd | 4 +- 3 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 vignettes/global-survshap.Rmd create mode 100644 vignettes/pdp.Rmd diff --git a/vignettes/global-survshap.Rmd b/vignettes/global-survshap.Rmd new file mode 100644 index 00000000..182ce301 --- /dev/null +++ b/vignettes/global-survshap.Rmd @@ -0,0 +1,56 @@ +--- +title: "Global explanations with SurvSHAP(t)" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Global explanations with SurvSHAP(t)} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +This vignette demonstrates how the `model_survshap()` function can be used to calculate global explanations for survival models. These explanations can be used to assess variable importance, and can be plotted in different ways to focus on different aspects. + +To create this explanation we follow the standard way of working with survex i.e. we create a model, and an explainer. + + +```{r setup} +library(survex) +library(survival) +library(ranger) + +vet <- survival::veteran + +cph <- coxph(Surv(time, status) ~ ., data = vet, x = TRUE, model = TRUE) +exp <- explain(cph, data = vet[, -c(3,4)], y = Surv(vet$time, vet$status)) + +``` + +We use the explainer and the `model_survshap()` function to calculate SurvSHAP(t) explanations. We can specify the observations for which we want to calculate the explanations. In this example we calculate the explanations for the first 20 observations in the veteran dataset. 
**Note:** The background for generating SHAP values is the `data` field of the explainer! If you want to calculate explanations with a background that is not the training data, you need to manually specify the `data` argument, when creating the explainer. + +```{r} +shap <- model_survshap(exp, veteran[1:20, -c(3,4)]) +``` + +We plot these explanations using the `plot.aggregated_surv_shap()` function. By default the plot features 2 panels, the one on the left depicts overall importance of variables. The panel on the right demonstrates the time-dependent importance of each variable calculated as the mean absolute SHAP value at each time point across all observations. + +```{r} +plot(shap) +``` + +The `plot.aggregated_surv_shap()` function can also be used to plot the explanations for a single variable. The `variable` argument specifies the variable for which the explanations are plotted. The `kind` argument specifies the type of plot. For `kind = "profile"` a plot is generated that shows the mean SHAP value (averaged across the time domain) depending on the value of the variable. + +```{r} +plot(shap, variable = "karno", kind = "profile") +``` + +For `kind = "swarm"` a swarm plot is generated that shows the SHAP values for each observation. The swarm plot is a good way to assess the distribution of SHAP values for each variable. 
+ +```{r} +plot(shap, kind = "swarm") +``` \ No newline at end of file diff --git a/vignettes/pdp.Rmd b/vignettes/pdp.Rmd new file mode 100644 index 00000000..f417f822 --- /dev/null +++ b/vignettes/pdp.Rmd @@ -0,0 +1,103 @@ +--- +title: "Partial Dependence Explanations" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Partial Dependence Explanations} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +This vignette demonstrates how to use the Partial Dependence explanations in `survex`, as well as Accumulated Local Effects explanations. It especially demonstrates the usage of new kinds of plots that are available in the 1.1 version of the package. To create these explanations we follow the standard way of working with survex i.e. we create a model, and an explainer. + +```{r setup} +library(survex) +library(survival) +library(ranger) + +set.seed(123) + +vet <- survival::veteran + +rsf <- ranger(Surv(time, status) ~ ., data = vet) +exp <- explain(rsf, data = vet[, -c(3,4)], y = Surv(vet$time, vet$status)) + +``` + +We use the explainer and the `model_profile()` function to calculate Partial Dependence explanations. We can specify the variables for which we want to calculate the explanations. In this example we calculate the explanations for the variables `karno` and `celltype`. **Note:** The background for generating PD values is the `data` field of the explainer! If you want to calculate explanations with a background that is not the training data, you need to manually specify the `data` argument, when creating the explainer. + +We can calculate Accumulated Local Effects in the same way, by setting the `type` argument to `"accumulated"`. 
+ +```{r} +pdp <- model_profile(exp, variables = c("karno", "celltype"), n = 20) +ale <- model_profile(exp, variables = c("karno", "celltype"), n = 20, type = "accumulated") +``` + +To plot these explanations you can use the plot function. By default the explanations for all calculated variables are plotted. This example demonstrates this for the `pdp` object which contains the explanations for 2 variables. + +```{r} +plot(pdp) +``` + +We can plot ALE explanations in the same way as PD explanations, as is demonstrated by the example below. For the rest of the vignette we only focus on PD explanations. + +```{r} +plot(ale) +``` + +The `plot()` function can also be used to plot the explanations for a subset of variables. The `variables` argument specifies the variables for which the explanations are plotted. The `numerical_plot_type` argument specifies the type of plot for numerical variables. For `numerical_plot_type = "lines"` (default), the y-axis represents the mean prediction (survival function), x-axis represents the time dimension, and different colors represent values of the studied variable. For `numerical_plot_type = "contours"`, the y-axis represents the values of the studied variable, x-axis represents time, and different colors represent the mean prediction (survival function). + +```{r} +plot(pdp, variables = c("karno"), numerical_plot_type = "contours") +``` + +The plots above make use of the time dependent output of survival models, by placing the time dimension on the x-axis. However, for people familiar with Partial Dependence explanations in classification and regression, it might be more intuitive to place the variable values on the x-axis. For this reason, we provide the `plot2()` function, which can display the explanations without the aspect of time. + +To use this function a specific time of interest has to be chosen. This time needs to be one of the values in the `times` field of the explainer. 
If the automatically generated times do not contain the time of interest, one needs to manually specify the `times` argument when creating the explainer. + +The example below shows the PD explanations for the `karno` variable at time 80. The y-axis represents the mean prediction (survival function), x-axis represents the values of the studied variable. Thin background lines are individual ceteris paribus profiles (otherwise known as ICE profiles). + +```{r} +plot2(pdp, variable = "karno", times = 80) +``` + +The same plot can be generated for the categorical `celltype` variable. In this case the x-axis represents the different values of the studied variable, boxplots present the distribution of individual ceteris paribus profiles, and the line represents the mean prediction (survival function), which is the PD explanation. + +```{r} +plot2(pdp, variable = "celltype", times = 80) +``` + +Of course, the plots can be prepared for multiple timepoints, at the same time and presented on one plot. + +```{r} +plot2(pdp, variable = "karno", times = c(1, 80, 151.72)) +``` + +```{r} +plot2(pdp, variable = "celltype", times = c(1, 80, 151.72)) +``` + + +`survex` also implements 2 dimensional PD and ALE explanations. These can be used to study the interaction between variables. The `model_profile_2d()` function can be used to calculate these explanations. The `variables` argument specifies the variables for which the explanations are calculated. The `type` argument specifies the type of explanation, and can be set to `"partial"` (default) or `"accumulated"`. + +```{r} +pdp_2d <- model_profile_2d(exp, variables = list(c("karno", "age"))) +pdp_2d_num_cat <- model_profile_2d(exp, variables = list(c("karno", "celltype"))) +``` + +These explanations can be plotted using the plot function. 
+ +```{r} +plot(pdp_2d, times = 80) +``` + +```{r} +plot(pdp_2d_num_cat, times = 80) +``` + diff --git a/vignettes/survex-usage.Rmd b/vignettes/survex-usage.Rmd index dd911927..aaad1a7a 100644 --- a/vignettes/survex-usage.Rmd +++ b/vignettes/survex-usage.Rmd @@ -118,11 +118,11 @@ plot(model_parts_cph_auc,model_parts_rsf_auc) We observe that the results are consistent with `loss_brier_score()`. That's good news - a simple sanity check that these variables are the most important for these models. -**Important note (new functionality):** We can also measure the importance of variables using global SurvSHAP(t) explanations. Details on how to do so, are presented in this vignette. +**Important note (new functionality):** We can also measure the importance of variables using global SurvSHAP(t) explanations. Details on how to do so, are presented in this vignette: `vignette("global-survshap")`. ## Partial dependence -**Important note (new functionality):** A more detailed description of the partial dependence explanation, with added functionality is presented in this vignette. +**Important note (new functionality):** A more detailed description of the partial dependence explanation, with added functionality is presented in this vignette: `vignette("pdp")`. The next type of global explanation this package provides is partial dependence plots. This is calculated using the `model_profile()` function. These plots show how setting one variable to a different value would, on average, influence the model's prediction. Again, this is an extension of [partial dependence](https://ema.drwhy.ai/partialDependenceProfiles.html) known from regression and classification tasks, applied to survival models by extending it to take the time dimension into account. 
From 5d3b3655862eece4e3b02e97a197380a36e7a1a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 7 Aug 2023 13:45:18 +0200 Subject: [PATCH 117/207] Add knitr version dependency to avoid vignette problems while running R CMD check --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 57ba6d7c..9cff15fb 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -35,7 +35,7 @@ Suggests: generics, glmnet, ingredients, - knitr, + knitr (>= 4.2), mboost, parsnip, progressr, From dda77fb98e07903c886375e8f563c12ca32f42b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 7 Aug 2023 13:51:01 +0200 Subject: [PATCH 118/207] Fix broken version of knitr --- DESCRIPTION | 2 +- survex.Rproj | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 9cff15fb..40b4e529 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -35,7 +35,7 @@ Suggests: generics, glmnet, ingredients, - knitr (>= 4.2), + knitr (>= 1.42), mboost, parsnip, progressr, diff --git a/survex.Rproj b/survex.Rproj index fa0d77bf..5d346a0e 100644 --- a/survex.Rproj +++ b/survex.Rproj @@ -19,4 +19,6 @@ LineEndingConversion: Posix BuildType: Package PackageUseDevtools: Yes PackageInstallArgs: --no-multiarch --with-keep.source +PackageBuildArgs: --no-build-vignettes +PackageCheckArgs: --ignore-vignettes PackageRoxygenize: rd,collate,namespace From f545615e5a6c8ebab2519a5547cb47ff6fbecfa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 7 Aug 2023 15:42:01 +0200 Subject: [PATCH 119/207] Add list of functionlaities and roadmap to readme --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index a943713c..bc60f59c 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,23 @@ plot(model_parts(explainer)) plot(predict_parts(explainer, veteran[1, -c(3, 4)])) ``` +## Functionalities and roadmap + +Existing functionalities: 
+- [x] calculation of performance metrics (Brier Score, Time-dependent C/D AUC, metrics from `mlr3proba`) - `model_performance()` +- [x] calculation of feature importance (Permutation Feature Importance - PFI) - `model_parts()` +- [x] calculation of partial dependence (Partial Dependence Profiles - PDP, Accumulated Local Effects - ALE) - `model_profile()` +- [x] calculation of 2-dimensional partial dependence (2D PDP, 2D ALE) - `model_profile_2d()` +- [x] calculation of local feature attributions (SurvSHAP(t), SurvLIME) - `predict_parts()` +- [x] calculation of local ceteris paribus explanations (Ceteris Paribus profiles - CP/ Individual Conditional Expectations - ICE) - `predict_profile()` +- [x] calculation of global feature attributions using SurvSHAP(t) - `model_survshap()` + +Currently in development: +- [ ] ... + +Future plans: +- [ ] ... (raise an Issue on GitHub if you have any suggestions) + ## Usage [![`survex` usage cheatsheet](man/figures/cheatsheet.png)](https://github.com/ModelOriented/survex/blob/main/misc/cheatsheet.pdf) From 236aacf370ee650d0e1d608e7101cd2c560cdb5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Mon, 7 Aug 2023 15:44:14 +0200 Subject: [PATCH 120/207] Add predict to readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bc60f59c..ebfc227b 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ plot(predict_parts(explainer, veteran[1, -c(3, 4)])) ## Functionalities and roadmap Existing functionalities: +- [x] unified prediction interface using the explainer object - `predict()` - [x] calculation of performance metrics (Brier Score, Time-dependent C/D AUC, metrics from `mlr3proba`) - `model_performance()` - [x] calculation of feature importance (Permutation Feature Importance - PFI) - `model_parts()` - [x] calculation of partial dependence (Partial Dependence Profiles - PDP, Accumulated Local Effects - ALE) - `model_profile()` From
705cae29aa449245d7a7ce70c8e41de35d810c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Thu, 10 Aug 2023 14:28:46 +0200 Subject: [PATCH 121/207] Change `plot2` to plot(..., `geom = "variable"`) --- DESCRIPTION | 2 +- NAMESPACE | 2 - R/plot_model_profile_survival.R | 121 ++++++++++++++-------------- man/plot.model_profile_survival.Rd | 34 ++++++-- man/plot2.model_profile_survival.Rd | 65 --------------- tests/testthat/test-model_profile.R | 58 ++++++------- vignettes/pdp.Rmd | 10 +-- 7 files changed, 124 insertions(+), 168 deletions(-) delete mode 100644 man/plot2.model_profile_survival.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 40b4e529..6a6555f1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9100 +Version: 1.0.0.9101 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), diff --git a/NAMESPACE b/NAMESPACE index ecadd196..8666414f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -27,7 +27,6 @@ S3method(plot,surv_lime) S3method(plot,surv_model_performance) S3method(plot,surv_model_performance_rocs) S3method(plot,surv_shap) -S3method(plot2,model_profile_survival) S3method(predict,surv_explainer) S3method(predict_parts,default) S3method(predict_parts,surv_explainer) @@ -66,7 +65,6 @@ export(model_performance) export(model_profile) export(model_profile_2d) export(model_survshap) -export(plot2) export(predict_parts) export(predict_profile) export(risk_from_chf) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 9e696b43..f72bdb40 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -4,15 +4,19 @@ #' using the `model_profile()` function. #' #' @param x an object of class `model_profile_survival` to be plotted -#' @param ... 
additional objects of class `model_profile_survival` to be plotted together -#' @param variables character, names of the variables to be plotted -#' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all -#' @param facet_ncol number of columns for arranging subplots -#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots -#' @param title character, title of the plot -#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels -#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") -#' @param rug character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. +#' @param ... additional objects of class `model_profile_survival` to be plotted together. Only available for `geom = "time"` +#' @param geom character, either "time" or "variable". Selects the type of plot to be prepared. If `"time"` then the x-axis represents survival times, and variable is denoted by colors, if `"variable"` then the x-axis represents the variable values, and mean predictions at selected timepoints. +#' @param variables character, names of the variables to be plotted. When `geom = "variable"` it needs to be a name of a single variable, when `geom = "time"` it can be a vector of variable names. If `NULL` (default) then all variables are plotted. +#' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all. Only used when `geom = "time"` +#' @param facet_ncol number of columns for arranging subplots. 
Only used when `geom = "time"` +#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots. Only used when `geom = "time"` +#' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately. Only used when `geom = "variable"` +#' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"`, or `NULL` (default). If `NULL` then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when `geom = "variable"` +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If `NULL` (default) then the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE +#' @param title character, title of the plot +#' @param subtitle character, subtitle of the plot, `"default"` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels +#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). +#' @param rug character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. Only used when `geom = "time"`. #' @param rug_colors character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. #' #' @return A collection of `ggplot` objects arranged with the `patchwork` package. 
@@ -32,21 +36,58 @@ #' plot(m_prof, numerical_plot_type = "contours") #' #' plot(m_prof, variables = c("trt", "age"), facet_ncol = 1) +#' +#' plot(m_prof, geom = "variable", variables = "karno", plot_type = "pdp+ice") +#' +#' plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "karno", plot_type = "pdp+ice") +#' +#' plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "trt", plot_type = "pdp+ice") #' } #' #' @export plot.model_profile_survival <- function(x, ..., + geom = "time", variables = NULL, variable_type = NULL, facet_ncol = NULL, numerical_plot_type = "lines", + times = NULL, + marginalize_over_time = FALSE, + plot_type = NULL, title = "default", subtitle = "default", colors = NULL, rug = "all", rug_colors = c("#dd0000", "#222222")) { + + if (!geom %in% c("time", "variable")) { + stop("`geom` must be one of 'time' or 'survival'.") + } + + if (geom == "variable") { + + pl <- plot2( + x = x, + variable = variables, + times = times, + marginalize_over_time = marginalize_over_time, + plot_type = plot_type, + ... = ..., + title = title, + subtitle = subtitle, + colors = colors + ) + return(pl) + } + + lapply(list(x, ...), function(x) { + if (!inherits(x, "model_profile_survival")) { + stop("All ... must be objects of class `model_profile_survival`.") + } + }) explanations_list <- c(list(x), list(...)) + num_models <- length(explanations_list) if (title == "default") { if (x$type == "partial") { @@ -97,55 +138,17 @@ plot.model_profile_survival <- function(x, return(return_plot) } -#' @rdname plot2.model_profile_survival -#' @export -plot2 <- function(x, ...) UseMethod("plot2") -#' Plot Model Profile for Survival Models (without continuous time aspect) -#' -#' This function plots objects of class `"model_profile_survival"` created -#' using the `model_profile()` function. 
-#' -#' @param x an object of class `model_profile_survival` to be plotted -#' @param variable character, name of a single variable to be plotted -#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median time from the explainer object is used. -#' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately -#' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"`, or `"ale"` selects the type of plot to be drawn -#' @param ... other parameters. Currently ignored. -#' @param title character, title of the plot. `'default'` automatically generates either "Partial dependence survival profiles" or "Accumulated local effects survival profiles" depending on the explanation type. -#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels -#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") -#' -#' @return A `ggplot` object. 
-#' -#' @rdname plot2.model_profile_survival -#' @examples -#' \donttest{ -#' library(survival) -#' library(survex) -#' -#' model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) -#' exp <- explain(model) -#' -#' m_prof <- model_profile(exp, categorical_variables = "trt") -#' -#' plot2(m_prof, variable = "karno", plot_type = "pdp+ice") -#' -#' plot2(m_prof, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") -#' -#' plot2(m_prof, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") -#' } -#' -#' @export -plot2.model_profile_survival <- function(x, - variable, - times = NULL, - marginalize_over_time = FALSE, - plot_type = NULL, - ..., - title = "default", - subtitle = "default", - colors = NULL) { +#' @keywords internal +plot2 <- function(x, + variable, + times = NULL, + marginalize_over_time = FALSE, + plot_type = NULL, + ..., + title = "default", + subtitle = "default", + colors = NULL) { if (is.null(plot_type)) { if (x$type == "accumulated") { plot_type <- "ale" @@ -161,11 +164,11 @@ plot2.model_profile_survival <- function(x, } if (is.null(variable) || !is.character(variable)) { - stop("A variable must be specified by name") + stop("The variable must be specified by name") } if (length(variable) > 1) { - stop("Only one variable can be specified") + stop("Only one variable can be specified for `geom`='variable'") } if (!variable %in% x$result$`_vname_`) { diff --git a/man/plot.model_profile_survival.Rd b/man/plot.model_profile_survival.Rd index 9608c4f7..9c548ab9 100644 --- a/man/plot.model_profile_survival.Rd +++ b/man/plot.model_profile_survival.Rd @@ -7,10 +7,14 @@ \method{plot}{model_profile_survival}( x, ..., + geom = "time", variables = NULL, variable_type = NULL, facet_ncol = NULL, numerical_plot_type = "lines", + times = NULL, + marginalize_over_time = FALSE, + plot_type = NULL, title = "default", subtitle = "default", colors = NULL, @@ -21,23 +25,31 @@ \arguments{ \item{x}{an object of class 
\code{model_profile_survival} to be plotted} -\item{...}{additional objects of class \code{model_profile_survival} to be plotted together} +\item{...}{additional objects of class \code{model_profile_survival} to be plotted together. Only available for \code{geom = "time"}} -\item{variables}{character, names of the variables to be plotted} +\item{geom}{character, either "time" or "variable". Selects the type of plot to be prepared. If \code{"time"} then the x-axis represents survival times, and variable is denoted by colors, if \code{"variable"} then the x-axis represents the variable values, and mean predictions at selected timepoints.} -\item{variable_type}{character, either \code{"numerical"}, \code{"categorical"} or \code{NULL} (default), select only one type of variable for plotting, or leave \code{NULL} for all} +\item{variables}{character, names of the variables to be plotted. When \code{geom = "variable"} it needs to be a name of a single variable, when \code{geom = "time"} it can be a vector of variable names. If \code{NULL} (default) then all variables are plotted.} -\item{facet_ncol}{number of columns for arranging subplots} +\item{variable_type}{character, either \code{"numerical"}, \code{"categorical"} or \code{NULL} (default), select only one type of variable for plotting, or leave \code{NULL} for all. Only used when \code{geom = "time"}} -\item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots} +\item{facet_ncol}{number of columns for arranging subplots. Only used when \code{geom = "time"}} + +\item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots. Only used when \code{geom = "time"}} + +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used. 
Only used when \code{geom = "variable"} and `marginalize_over_time = FALSE} + +\item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately. Only used when \code{geom = "variable"}} + +\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"}, or \code{NULL} (default). If \code{NULL} then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when \code{geom = "variable"}} \item{title}{character, title of the plot} -\item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} +\item{subtitle}{character, subtitle of the plot, \code{"default"} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} -\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} +\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue").} -\item{rug}{character, one of \code{"all"}, \code{"events"}, \code{"censors"}, \code{"none"} or \code{NULL}. Which times to mark on the x axis in \code{geom_rug()}.} +\item{rug}{character, one of \code{"all"}, \code{"events"}, \code{"censors"}, \code{"none"} or \code{NULL}. Which times to mark on the x axis in \code{geom_rug()}. Only used when \code{geom = "time"}.} \item{rug_colors}{character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). 
The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times.} } @@ -63,6 +75,12 @@ plot(m_prof) plot(m_prof, numerical_plot_type = "contours") plot(m_prof, variables = c("trt", "age"), facet_ncol = 1) + +plot(m_prof, geom = "variable", variables = "karno", plot_type = "pdp+ice") + +plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "karno", plot_type = "pdp+ice") + +plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "trt", plot_type = "pdp+ice") } } diff --git a/man/plot2.model_profile_survival.Rd b/man/plot2.model_profile_survival.Rd deleted file mode 100644 index 172048fa..00000000 --- a/man/plot2.model_profile_survival.Rd +++ /dev/null @@ -1,65 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot_model_profile_survival.R -\name{plot2} -\alias{plot2} -\alias{plot2.model_profile_survival} -\title{Plot Model Profile for Survival Models (without continuous time aspect)} -\usage{ -plot2(x, ...) - -\method{plot2}{model_profile_survival}( - x, - variable, - times = NULL, - marginalize_over_time = FALSE, - plot_type = NULL, - ..., - title = "default", - subtitle = "default", - colors = NULL -) -} -\arguments{ -\item{x}{an object of class \code{model_profile_survival} to be plotted} - -\item{...}{other parameters. Currently ignored.} - -\item{variable}{character, name of a single variable to be plotted} - -\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. 
If \code{NULL} (default) then the median time from the explainer object is used.} - -\item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately} - -\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"}, or \code{"ale"} selects the type of plot to be drawn} - -\item{title}{character, title of the plot. \code{'default'} automatically generates either "Partial dependence survival profiles" or "Accumulated local effects survival profiles" depending on the explanation type.} - -\item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} - -\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} -} -\value{ -A \code{ggplot} object. -} -\description{ -This function plots objects of class \code{"model_profile_survival"} created -using the \code{model_profile()} function. 
-} -\examples{ -\donttest{ -library(survival) -library(survex) - -model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) -exp <- explain(model) - -m_prof <- model_profile(exp, categorical_variables = "trt") - -plot2(m_prof, variable = "karno", plot_type = "pdp+ice") - -plot2(m_prof, times = c(1, 2.72), variable = "karno", plot_type = "pdp+ice") - -plot2(m_prof, times = c(1, 2.72), variable = "celltype", plot_type = "pdp+ice") -} - -} diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index c0b19b8d..c781a672 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -27,15 +27,15 @@ test_that("model_profile with type = 'partial' works", { plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") - ### Add tests for plot2 for numerical PDP + ### Add tests for plot for numerical PDP # single timepoint - plot2(mp_cph_num, variable = "karno", plot_type = "pdp+ice", times = cph_exp$times[1]) - plot2(mp_cph_num, variable = "karno", plot_type = "pdp", times = cph_exp$times[1]) - plot2(mp_cph_num, variable = "karno", plot_type = "ice", times = cph_exp$times[1]) + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "pdp+ice", times = cph_exp$times[1]) + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "pdp", times = cph_exp$times[1]) + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "ice", times = cph_exp$times[1]) # multiple timepoints - plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "pdp+ice") - plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "pdp") - plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "ice") + plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "pdp+ice") + plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "pdp") + plot(mp_cph_num, geom 
= "variable", times = c(4, 80.7), variables = "karno", plot_type = "ice") expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) @@ -48,16 +48,16 @@ test_that("model_profile with type = 'partial' works", { plot(mp_cph_cat, mp_rsf_cat) - ### Add tests for plot2 for categorical PDP + ### Add tests for plot for categorical PDP # single timepoint - plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp+ice", times = rsf_ranger_exp$times[1]) - plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp", times = rsf_ranger_exp$times[1]) - plot2(mp_rsf_cat, variable = "celltype", plot_type = "ice", times = rsf_ranger_exp$times[1]) + plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp+ice", times = rsf_ranger_exp$times[1]) + plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp", times = rsf_ranger_exp$times[1]) + plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "ice", times = rsf_ranger_exp$times[1]) # multiple timepoints - plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "pdp+ice") - plot2(mp_rsf_cat, times = c(4, 80.7), marginalize_over_time = T, variable = "celltype", plot_type = "pdp+ice") - plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "pdp") - plot2(mp_rsf_cat, times = c(4, 80.7), variable = "celltype", plot_type = "ice") + plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "pdp+ice") + plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), marginalize_over_time = T, variables = "celltype", plot_type = "pdp+ice") + plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "pdp") + plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "ice") expect_s3_class(mp_rsf_cat, "model_profile_survival") @@ -76,11 +76,13 @@ test_that("model_profile with type = 'partial' works", { 
expect_true(all(unique(mp_rsf_num$result$`_vname_`) %in% colnames(rsf_ranger_exp$data))) expect_output(print(mp_cph_num)) - expect_warning(plot2(mp_rsf_cat, variable = "celltype", plot_type = "pdp+ice")) + expect_warning(plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp+ice")) expect_error(plot(mp_rsf_num, variables = "nonexistent", grid_points = 6)) expect_error(model_profile(rsf_ranger_exp, type = "conditional")) - expect_error(plot2(mp_rsf_num, variable = "nonexistent")) - expect_error(plot2(mp_rsf_num, variable = "age", times = -1)) + expect_error(plot(mp_rsf_num, geom = "variable", variables = "nonexistent")) + expect_error(plot(mp_rsf_num, geom = "variable", variables = "age", times = -1)) + expect_error(plot(mp_rsf_num, geom = "nonexistent")) + expect_error(plot(mp_rsf_num, nonsense_argument = "character")) }) test_that("model_profile with type = 'accumulated' works", { @@ -102,20 +104,20 @@ test_that("model_profile with type = 'accumulated' works", { categorical_variables = "trt") plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") - ### Add tests for plot2 for categorical ALE + ### Add tests for plot for categorical ALE # single timepoint - plot2(mp_cph_cat, variable = "celltype", times=cph_exp$times[1]) + plot(mp_cph_cat, geom = "variable", variables = "celltype", times=cph_exp$times[1]) # multiple timepoints - plot2(mp_cph_cat, times = c(4, 80.7), variable = "celltype", plot_type = "ale") + plot(mp_cph_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "ale") expect_s3_class(mp_cph_cat, "model_profile_survival") expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) expect_equal(ncol(mp_cph_cat$result), 7) expect_true(all(unique(mp_cph_cat$result$`_vname_`) %in% colnames(cph_exp$data))) - expect_error(plot2(mp_cph_cat, variable = "celltype", plot_type = "pdp")) - expect_error(plot2(mp_cph_cat, variable = "celltype", plot_type = "nonexistent")) - expect_error(plot2(mp_cph_cat, 
variable = 1, plot_type = "nonexistent")) - expect_error(plot2(mp_cph_cat, variable = c("celltype", "trt"), plot_type = "nonexistent")) + expect_error(plot(mp_cph_cat, geom = "variable", variables = "celltype", plot_type = "pdp")) + expect_error(plot(mp_cph_cat, geom = "variable", variables = "celltype", plot_type = "nonexistent")) + expect_error(plot(mp_cph_cat, geom = "variable", variables = 1, plot_type = "nonexistent")) + expect_error(plot(mp_cph_cat, geom = "variable", variables = c("celltype", "trt"), plot_type = "nonexistent")) mp_cph_num <- model_profile(cph_exp, @@ -126,11 +128,11 @@ test_that("model_profile with type = 'accumulated' works", { plot(mp_cph_num, variable_type = "numerical") plot(mp_cph_num, numerical_plot_type = "contours") - ### Add tests for plot2 for numerical ALE + ### Add tests for plot for numerical ALE # single timepoint - plot2(mp_cph_num, variable = "karno", plot_type = "ale", times=cph_exp$times[1]) + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "ale", times=cph_exp$times[1]) # multiple timepoints - plot2(mp_cph_num, times = c(4, 80.7), variable = "karno", plot_type = "ale") + plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "ale") expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) diff --git a/vignettes/pdp.Rmd b/vignettes/pdp.Rmd index f417f822..9ee786ab 100644 --- a/vignettes/pdp.Rmd +++ b/vignettes/pdp.Rmd @@ -57,30 +57,30 @@ The `plot()` function can also be used to plot the explanations for a subset of plot(pdp, variables = c("karno"), numerical_plot_type = "contours") ``` -The plots above make use of the time dependent output of survival models, by placing the time dimension on the x-axis. However, for people familiar with Partial Dependence explanations in classification and regression, it might be more intuitive to place the variable values on the x-axis. 
For this reason, we provide the `plot2()` function, which can display the explanations without the aspect of time. +The plots above make use of the time dependent output of survival models, by placing the time dimension on the x-axis. However, for people familiar with Partial Dependence explanations in classification and regression, it might be more intuitive to place the variable values on the x-axis. For this reason, we provide the `geom = "variable"` argument, which can display the explanations without the aspect of time. To use this function a specific time of interest has to be chosen. This time needs to be one of the values in the `times` field of the explainer. If the automatically generated times do not contain the time of interest, one needs to manually specify the `times` argument when creating the explainer. The example below shows the PD explanations for the `karno` variable at time 80. The y-axis represents the mean prediction (survival function), x-axis represents the values of the studied variable. Thin background lines are individual ceteris paribus profiles (otherwise known as ICE profiles). ```{r} -plot2(pdp, variable = "karno", times = 80) +plot(pdp, geom = "variable", variables = "karno", times = 80) ``` The same plot can be generated for the categorical `celltype` variable. In this case the x-axis represents the different values of the studied variable, boxplots present the distribution of individual ceteris paribus profiles, and the line represents the mean prediction (survival function), which is the PD explanation. ```{r} -plot2(pdp, variable = "celltype", times = 80) +plot(pdp, geom = "variable", variables = "celltype", times = 80) ``` Of course, the plots can be prepared for multiple timepoints, at the same time and presented on one plot. 
```{r} -plot2(pdp, variable = "karno", times = c(1, 80, 151.72)) +plot(pdp, geom = "variable", variables = "karno", times = c(1, 80, 151.72)) ``` ```{r} -plot2(pdp, variable = "celltype", times = c(1, 80, 151.72)) +plot(pdp, geom = "variable", variables = "celltype", times = c(1, 80, 151.72)) ``` From d19fdd8e95b36a998fb969d60f66ed71ceca2026 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Thu, 10 Aug 2023 14:36:06 +0200 Subject: [PATCH 122/207] Fix failing pkgdown build --- _pkgdown.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/_pkgdown.yml b/_pkgdown.yml index 1ae967af..a8b45479 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -31,7 +31,6 @@ reference: - subtitle: Model Profile - contents: - plot.model_profile_survival - - plot2.model_profile_survival - plot.model_profile_2d_survival - subtitle: Model SurvSHAP(t) - contents: plot.aggregated_surv_shap From 87583d3b76ca699faa5ed48777658763c7723c9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 11 Aug 2023 10:32:11 +0200 Subject: [PATCH 123/207] Fix `&` ->` &&` --- R/plot_model_profile_survival.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index f72bdb40..0e8b56c1 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -13,9 +13,9 @@ #' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately. Only used when `geom = "variable"` #' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"`, or `NULL` (default). If `NULL` then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when `geom = "variable"` #' @param times numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. 
If `NULL` (default) then the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE -#' @param title character, title of the plot -#' @param subtitle character, subtitle of the plot, `"default"` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels -#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). +#' @param title character, title of the plot +#' @param subtitle character, subtitle of the plot, `"default"` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels +#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). #' @param rug character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. Only used when `geom = "time"`. #' @param rug_colors character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. #' @@ -200,7 +200,7 @@ plot2 <- function(x, single_timepoint <- ((length(times) == 1) || marginalize_over_time) if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") - if (single_timepoint & !marginalize_over_time){ + if (single_timepoint && !marginalize_over_time){ subtitle <- paste0(subtitle, " and time=", times) } } From 96b3f3f7c09dcbfbcf89c5f45b97f3ea151abfac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 11 Aug 2023 10:32:33 +0200 Subject: [PATCH 124/207] Remove unnecessary `requireNamespace()` calls. 
--- DESCRIPTION | 2 +- NEWS.md | 17 +++++++++-------- R/surv_feature_importance.R | 12 ++++++++---- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6a6555f1..545b8b09 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9101 +Version: 1.0.0.9102 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), diff --git a/NEWS.md b/NEWS.md index 4f178cf6..3a649c47 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,13 +1,14 @@ # survex (development) -* Fix not being able to plot or print SurvLIME results for the cph model sometimes. ([#72](https://github.com/ModelOriented/survex/issues/72)) -* Add global explanations via the SurvSHAP(t) method (see `model_survshap()` function) -* Add plots for global SurvSHAP(t) explanations (see `plot.aggregated_surv_shap()`) -* Add Accumulated Local Effects (ALE) explanations (see `model_profile(..., type = "accumulated")`) -* Add 2-dimensional PDP and ALE plots (see `model_profile_2d()` function) -* Add `plot2()` function for plotting PDP and ALE explanations without the time dimension -* Improvement on the vignettes for the package (see `vignette("pdp")` and `vignette("global-survshap")`) -* Increase the test coverage of the pacakge. +* fixed not being able to plot or print SurvLIME results for the cph model sometimes. 
([#72](https://github.com/ModelOriented/survex/issues/72)) +* added global explanations via the SurvSHAP(t) method (see `model_survshap()` function) +* added plots for global SurvSHAP(t) explanations (see `plot.aggregated_surv_shap()`) +* added Accumulated Local Effects (ALE) explanations (see `model_profile(..., type = "accumulated")`) +* added 2-dimensional PDP and ALE plots (see `model_profile_2d()` function) +* added `plot2()` function for plotting PDP and ALE explanations without the time dimension +* made improvements on the vignettes for the package (see `vignette("pdp")` and `vignette("global-survshap")`) +* increased the test coverage of the pacakge +* reduced the number of expensive `requireNamespace()` calls ([#83](https://github.com/ModelOriented/survex/issues/83)) # survex 1.0.0 diff --git a/R/surv_feature_importance.R b/R/surv_feature_importance.R index 9e156e71..28704f99 100644 --- a/R/surv_feature_importance.R +++ b/R/surv_feature_importance.R @@ -121,7 +121,11 @@ surv_feature_importance.default <- function(x, # start: actual calculations # one permutation round: subsample data, permute variables and compute losses - if (requireNamespace("progressr", quietly = TRUE)) prog <- progressr::progressor(along = 1:((length(variables) + 2) * B)) + if (requireNamespace("progressr", quietly = TRUE)) { + prog <- progressr::progressor(along = 1:((length(variables) + 2) * B)) + } else { + prog <- function() NULL + } sampled_rows <- 1:nrow(data) loss_after_permutation <- function() { if (!is.null(N)) { @@ -138,17 +142,17 @@ surv_feature_importance.default <- function(x, risk_true <- predict_function(x, sampled_data) # loss on the full model or when outcomes are permuted loss_full <- loss_function(observed, risk_true, surv_true, times) - if (requireNamespace("progressr", quietly = TRUE)) prog() + prog() chosen <- sample(1:nrow(observed)) loss_baseline <- loss_function(observed[chosen, ], risk_true, surv_true, times) - if (requireNamespace("progressr", quietly = 
TRUE)) prog() + prog() # loss upon dropping a single variable (or a single group) loss_variables <- sapply(variables, function(variables_set) { ndf <- sampled_data ndf[, variables_set] <- ndf[sample(1:nrow(ndf)), variables_set] predicted <- predict_function(x, ndf) predicted_surv <- predict_survival_function(x, ndf, times) - if (requireNamespace("progressr", quietly = TRUE)) prog() + prog() loss_function(observed, predicted, predicted_surv, times) }) From d5fd329d76893b9e076e1b7fa70313ee6cd4cd4b Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 11 Aug 2023 11:05:45 +0200 Subject: [PATCH 125/207] change sample for `model_survshap` --- R/model_survshap.R | 6 +++--- man/model_survshap.surv_explainer.Rd | 6 +++--- tests/testthat/test-model_survshap.R | 4 ++-- vignettes/global-survshap.Rmd | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/R/model_survshap.R b/R/model_survshap.R index fb5151d9..fbca5dea 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -32,9 +32,9 @@ #' #' ranger_global_survshap <- model_survshap( #' explainer = rsf_ranger_exp, -#' new_observation = veteran[1:40, !colnames(veteran) %in% c("time", "status")], -#' y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), -#' aggregation_method = "mean_absolute", +#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), !colnames(veteran) %in% c("time", "status")], +#' y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], veteran$status[c(1:4, 17:20, 110:113, 126:129)]), +#' aggregation_method = "integral", #' calculation_method = "kernelshap", #' ) #' plot(ranger_global_survshap) diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 47bfbd27..4c3b74aa 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -60,9 +60,9 @@ rsf_ranger_exp <- explain( ranger_global_survshap <- model_survshap( explainer = rsf_ranger_exp, - new_observation = veteran[1:40, 
!colnames(veteran) \%in\% c("time", "status")], - y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), - aggregation_method = "mean_absolute", + new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), !colnames(veteran) \%in\% c("time", "status")], + y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], veteran$status[c(1:4, 17:20, 110:113, 126:129)]), + aggregation_method = "integral", calculation_method = "kernelshap", ) plot(ranger_global_survshap) diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index 35be19e7..8e5626bc 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -16,8 +16,8 @@ test_that("global survshap explanations with kernelshap work for ranger, using n ranger_global_survshap <- model_survshap( explainer = rsf_ranger_exp, - new_observation = veteran[c(1:3, 16:18, 111:113, 126:128), !colnames(veteran) %in% c("time", "status")], - y_true = Surv(veteran$time[c(1:3, 16:18, 111:113, 126:128)], veteran$status[c(1:3, 16:18, 111:113, 126:128)]), + new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], veteran$status[c(1:4, 17:20, 110:113, 126:129)]), aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) diff --git a/vignettes/global-survshap.Rmd b/vignettes/global-survshap.Rmd index 182ce301..bba0257d 100644 --- a/vignettes/global-survshap.Rmd +++ b/vignettes/global-survshap.Rmd @@ -34,7 +34,7 @@ exp <- explain(cph, data = vet[, -c(3,4)], y = Surv(vet$time, vet$status)) We use the explainer and the `model_survshap()` function to calculate SurvSHAP(t) explanations. We can specify the observations for which we want to calculate the explanations. In this example we calculate the explanations for the first 20 observations in the veteran dataset. 
**Note:** The background for generating SHAP values is the `data` field of the explainer! If you want to calculate explanations with a background that is not the training data, you need to manually specify the `data` argument, when creating the explainer. ```{r} -shap <- model_survshap(exp, veteran[1:20, -c(3,4)]) +shap <- model_survshap(exp, veteran[c(1:4, 17:20, 110:113, 126:129), -c(3,4)]) ``` We plot these explanations using the `plot.aggregated_surv_shap()` function. By default the plot features 2 panels, the one on the left depicts overall importance of variables. The panel on the right demonstates the time-dependent importance of each variable calculated as the mean absolute SHAP value at each time point across all observations. @@ -53,4 +53,4 @@ For `kind = "swarm"` a swarm plot is generated that shows the SHAP values for ea ```{r} plot(shap, kind = "swarm") -``` \ No newline at end of file +``` From 57eff0e29547a7ef8b6372ee04e76476d96cd39e Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 11 Aug 2023 11:05:59 +0200 Subject: [PATCH 126/207] fix jitter height in swarm plot --- R/plot_surv_shap.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 19c5e961..2bb199c1 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -253,7 +253,7 @@ plot_shap_global_swarm <- function(x, with(df, { ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + geom_vline(xintercept = 0, color = "#ceced9", linetype="solid") + - geom_jitter(width=0) + + geom_jitter(width=0, height=0.15) + scale_color_gradient2( name = "Variable value", low = colors[1], From dc11395c86fa0a194f9f6553edd149000014373f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 11 Aug 2023 11:57:17 +0200 Subject: [PATCH 127/207] fix too long lines in example --- R/model_survshap.R | 6 ++++-- man/model_survshap.surv_explainer.Rd | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/R/model_survshap.R 
b/R/model_survshap.R index fbca5dea..43624bce 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -32,8 +32,10 @@ #' #' ranger_global_survshap <- model_survshap( #' explainer = rsf_ranger_exp, -#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), !colnames(veteran) %in% c("time", "status")], -#' y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], veteran$status[c(1:4, 17:20, 110:113, 126:129)]), +#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), +#' !colnames(veteran) %in% c("time", "status")], +#' y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], +#' veteran$status[c(1:4, 17:20, 110:113, 126:129)]), #' aggregation_method = "integral", #' calculation_method = "kernelshap", #' ) diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 4c3b74aa..3448eb19 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -60,8 +60,10 @@ rsf_ranger_exp <- explain( ranger_global_survshap <- model_survshap( explainer = rsf_ranger_exp, - new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), !colnames(veteran) \%in\% c("time", "status")], - y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], veteran$status[c(1:4, 17:20, 110:113, 126:129)]), + new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), + !colnames(veteran) \%in\% c("time", "status")], + y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], + veteran$status[c(1:4, 17:20, 110:113, 126:129)]), aggregation_method = "integral", calculation_method = "kernelshap", ) From 1a033fda08a57ff2ea479102f79689011f1bd0fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Fri, 11 Aug 2023 12:39:55 +0200 Subject: [PATCH 128/207] Change progressr to steps --- R/surv_feature_importance.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/surv_feature_importance.R b/R/surv_feature_importance.R index 28704f99..1a70238d 
100644 --- a/R/surv_feature_importance.R +++ b/R/surv_feature_importance.R @@ -122,7 +122,7 @@ surv_feature_importance.default <- function(x, # start: actual calculations # one permutation round: subsample data, permute variables and compute losses if (requireNamespace("progressr", quietly = TRUE)) { - prog <- progressr::progressor(along = 1:((length(variables) + 2) * B)) + prog <- progressr::progressor(steps = (length(variables) + 2) * B) } else { prog <- function() NULL } From fd7e04278c8756eb3de01f2e4232e125d4409c05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Krzyzi=C5=84ski?= Date: Thu, 17 Aug 2023 11:39:45 +0200 Subject: [PATCH 129/207] Update vignettes/pdp.Rmd Co-authored-by: sophhan <79644349+sophhan@users.noreply.github.com> --- vignettes/pdp.Rmd | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vignettes/pdp.Rmd b/vignettes/pdp.Rmd index 9ee786ab..26247d00 100644 --- a/vignettes/pdp.Rmd +++ b/vignettes/pdp.Rmd @@ -35,8 +35,8 @@ We use the explainer and the `model_profile()` function to calculate Partial Dep We can calculate Accumulated Local Effects in the same way, by setting the `type` argument to `"accumulated"`. ```{r} -pdp <- model_profile(exp, variables = c("karno", "celltype"), n = 20) -ale <- model_profile(exp, variables = c("karno", "celltype"), n = 20, type = "accumulated") +pdp <- model_profile(exp, variables = c("karno", "celltype"), N = 20) +ale <- model_profile(exp, variables = c("karno", "celltype"), N = 20, type = "accumulated") ``` To plot these explanations you can use the plot function. By default the explanations for all calculated variables are plotted. This example demonstrates this for the `pdp` object which contains the explanations for 2 variables. 
From 02023e3fdc6c5b625f658bead8f86d81e65bc6cf Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 17 Aug 2023 13:07:18 +0200 Subject: [PATCH 130/207] remove commented print --- R/surv_ceteris_paribus.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index a6a211b1..afc5053d 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -189,7 +189,6 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m `_ids_` = rep(ids, each = length(times) * length(split_points)), check.names = FALSE ) - # print(table(ids)) prog() new_data }) From 660141b5657f65ec70578c5b7706add9a29b6cb9 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 17 Aug 2023 13:07:55 +0200 Subject: [PATCH 131/207] fix titles for model_profile plots --- R/plot_model_profile_survival.R | 43 +++++++++++++++++---------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 0e8b56c1..6afb3bcd 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -36,11 +36,11 @@ #' plot(m_prof, numerical_plot_type = "contours") #' #' plot(m_prof, variables = c("trt", "age"), facet_ncol = 1) -#' +#' #' plot(m_prof, geom = "variable", variables = "karno", plot_type = "pdp+ice") -#' +#' #' plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "karno", plot_type = "pdp+ice") -#' +#' #' plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "trt", plot_type = "pdp+ice") #' } #' @@ -65,6 +65,18 @@ plot.model_profile_survival <- function(x, stop("`geom` must be one of 'time' or 'survival'.") } + if (title == "default") { + if (x$type == "partial") { + title <- "Partial dependence survival profiles" + if (geom == "variable") { + title <- "default" + } + } + if (x$type == "accumulated") { + title <- "Accumulated local effects survival profiles" + } + } + if (geom == "variable") { pl <- plot2( @@ -87,16 
+99,8 @@ plot.model_profile_survival <- function(x, } }) explanations_list <- c(list(x), list(...)) - + num_models <- length(explanations_list) - if (title == "default") { - if (x$type == "partial") { - title <- "Partial dependence survival profiles" - } - if (x$type == "accumulated") { - title <- "Accumulated local effects survival profiles" - } - } if (num_models == 1) { result <- prepare_model_profile_plots(x, @@ -155,6 +159,12 @@ plot2 <- function(x, } else if (x$type == "partial") plot_type <- "pdp+ice" } + if (plot_type == "ice") { + title <- "Individual conditional expectation survival profiles" + } else if (plot_type == "pdp+ice") { + title <- "Partial dependence with individual conditional expectation survival profiles" + } + if (x$type == "accumulated" && plot_type != "ale") { stop("For accumulated local effects explanations only plot_type = 'ale' is available") } @@ -188,15 +198,6 @@ plot2 <- function(x, )) } - if (title == "default") { - if (x$type == "partial") { - title <- "Partial dependence survival profiles" - } - if (x$type == "accumulated") { - title <- "Accumulated local effects survival profiles" - } - } - single_timepoint <- ((length(times) == 1) || marginalize_over_time) if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") From f02fc0124254a3301a1933cdc213dd9cd0362e5f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 17 Aug 2023 13:08:14 +0200 Subject: [PATCH 132/207] add character variables as categorical ones --- R/surv_model_profiles.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index a8659369..0bf2b8de 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -111,8 +111,9 @@ surv_ale <- function(x, categorical_variables <- colnames(data)[categorical_variables] additional_categorical_variables <- categorical_variables factor_variables <- colnames(data)[sapply(data, is.factor)] + 
character_variables <- colnames(data)[sapply(data, is.character)] categorical_variables <- - unique(c(additional_categorical_variables, factor_variables)) + unique(c(additional_categorical_variables, character_variables, factor_variables)) model <- x$model label <- x$label @@ -259,7 +260,6 @@ surv_ale <- function(x, # Quantile points vector quantile_vec <- c(min(variable_values), quantile_vals) quantile_vec <- unique(quantile_vec) - quantile_df <- data.frame(id = 1:length(quantile_vec), value = quantile_vec) @@ -279,6 +279,7 @@ surv_ale <- function(x, predict_survival_function(model = model, newdata = X_lower, times = times) + predictions_upper <- predict_survival_function(model = model, newdata = X_upper, From d634f2acaabdf6215c4923683b2471c82fdd8095 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 17 Aug 2023 16:48:13 +0200 Subject: [PATCH 133/207] fix survival function predictions from censored --- R/explain.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/explain.R b/R/explain.R index 785ddc87..8e6559b2 100644 --- a/R/explain.R +++ b/R/explain.R @@ -707,7 +707,7 @@ explain.model_fit <- } if (is.null(predict_survival_function)) { - predict_survival_function <- function(model, newdata, times) { prediction <- predict(model, new_data = newdata, type = "survival", time = times )$.pred + predict_survival_function <- function(model, newdata, times) { prediction <- predict(model, new_data = newdata, type = "survival", eval_time = times )$.pred return_matrix <- t(sapply(prediction, function(x) x$.pred_survival)) return_matrix[is.na(return_matrix)] <- 0 return_matrix From b1a7b206990e25eecbbfae3f5608cda47420dced Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 17 Aug 2023 16:48:41 +0200 Subject: [PATCH 134/207] fix centering for model_profile --- R/model_profile.R | 15 +++++++++------ R/plot_model_profile_survival.R | 2 +- man/model_profile.surv_explainer.Rd | 8 ++++---- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git 
a/R/model_profile.R b/R/model_profile.R index 5f7be226..05022f4b 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -13,7 +13,7 @@ #' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for percentiles or `"uniform"` (default) to get uniform grid of points. #' @param groups if `output_type == "risk"` a variable name that will be used for grouping. By default `NULL`, so no groups are calculated. If `output_type == "survival"` then ignored #' @param k passed to `DALEX::model_profile` if `output_type == "risk"`, otherwise ignored -#' @param center logical, should profiles be centered before clustering +#' @param center logical, should profiles be centered around the average prediction #' @param type the type of variable profile, `"partial"` for Partial Dependence, `"accumulated"` for Accumulated Local Effects, or `"conditional"` (available only for `output_type == "risk"`) #' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the `DALEX::model_profile` function. #' @@ -56,8 +56,8 @@ model_profile <- function(explainer, ..., groups = NULL, k = NULL, - center = FALSE, type = "partial", + center = FALSE, output_type = "survival") UseMethod("model_profile", explainer) #' @rdname model_profile.surv_explainer @@ -71,7 +71,7 @@ model_profile.surv_explainer <- function(explainer, variable_splits_type = "uniform", groups = NULL, k = NULL, - center = TRUE, + center = FALSE, type = "partial", output_type = "survival") { @@ -101,10 +101,11 @@ model_profile.surv_explainer <- function(explainer, categorical_variables = categorical_variables, grid_points = grid_points, variable_splits_type = variable_splits_type, + center = center, ...) 
+ result <- surv_aggregate_profiles(cp_profiles, ..., - variables = variables, - center = center) + variables = variables) } else if (type == "accumulated"){ cp_profiles <- list(variable_values = data.frame(ndata)) result <- surv_ale(explainer, @@ -112,6 +113,7 @@ model_profile.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, + center = center, ...) } else { stop("Currently only `partial` and `accumulated` types are implemented") @@ -120,7 +122,8 @@ model_profile.surv_explainer <- function(explainer, ret <- list(eval_times = unique(result$`_times_`), cp_profiles = cp_profiles, result = result, - type = type) + type = type, + center = center) class(ret) <- c("model_profile_survival", "list") ret$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] ret$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 6afb3bcd..8e97cf11 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -474,7 +474,7 @@ prepare_model_profile_plots <- function(x, aggregated_profiles$`_real_point_` <- FALSE - pl <- plot_individual_ceteris_paribus_survival(aggregated_profiles, variables, colors, numerical_plot_type, rug_df, rug, rug_colors) + pl <- plot_individual_ceteris_paribus_survival(aggregated_profiles, variables, colors, numerical_plot_type, rug_df, rug, rug_colors, x$center) patchwork::wrap_plots(pl, ncol = facet_ncol) + patchwork::plot_annotation( diff --git a/man/model_profile.surv_explainer.Rd b/man/model_profile.surv_explainer.Rd index c8d54d4e..6d56bdef 100644 --- a/man/model_profile.surv_explainer.Rd +++ b/man/model_profile.surv_explainer.Rd @@ -12,8 +12,8 @@ model_profile( ..., groups = NULL, k = NULL, - center = FALSE, type = "partial", + center = FALSE, output_type = "survival" ) @@ -27,7 +27,7 @@ model_profile( variable_splits_type = 
"uniform", groups = NULL, k = NULL, - center = TRUE, + center = FALSE, type = "partial", output_type = "survival" ) @@ -45,10 +45,10 @@ model_profile( \item{k}{passed to \code{DALEX::model_profile} if \code{output_type == "risk"}, otherwise ignored} -\item{center}{logical, should profiles be centered before clustering} - \item{type}{the type of variable profile, \code{"partial"} for Partial Dependence, \code{"accumulated"} for Accumulated Local Effects, or \code{"conditional"} (available only for \code{output_type == "risk"})} +\item{center}{logical, should profiles be centered around the average prediction} + \item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the \code{DALEX::model_profile} function.} \item{categorical_variables}{character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). 
If it contains variable names not present in the \code{variables} argument, they will be added at the end.} From cc5cc084b8e58eee51b7bcd60a4b2d4408ffc4f5 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 17 Aug 2023 16:49:28 +0200 Subject: [PATCH 135/207] centering for ceteris paribus --- R/plot_surv_ceteris_paribus.R | 24 +++++++++++++----------- R/predict_profile.R | 10 ++++++++-- R/surv_ceteris_paribus.R | 21 ++++++++++++++++----- man/predict_profile.surv_explainer.Rd | 8 ++++++-- man/surv_ceteris_paribus.Rd | 1 + 5 files changed, 44 insertions(+), 20 deletions(-) diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index 4b180202..85b6925f 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -54,13 +54,11 @@ plot.surv_ceteris_paribus <- function(x, rug_colors = c("#dd0000", "#222222")) { if (!is.null(variable_type)) check_variable_type(variable_type) - check_numerical_plot_type(numerical_plot_type) explanations_list <- c(list(x), list(...)) num_models <- length(explanations_list) - if (num_models == 1){ result <- prepare_ceteris_paribus_plots(x, colors, @@ -112,11 +110,12 @@ prepare_ceteris_paribus_plots <- function(x, subtitle = "default", rug = "all", rug_colors = c("#dd0000", "#222222")){ - rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses), label = unique(x$result$`_label_`)) - obs <- x$variable_values + obs <- as.data.frame(x$variable_values) + center <- x$center x <- x$result + all_profiles <- x class(all_profiles) <- "data.frame" @@ -161,11 +160,9 @@ prepare_ceteris_paribus_plots <- function(x, "_yhat_", "_label_", "_ids_")] - key <- obs[, sv, drop = FALSE] - tmp$`_real_point_` <- - tmp[, sv] == key[as.character(tmp$`_ids_`), sv] + tmp$`_real_point_` <- tmp[, sv] == key[, sv] colnames(tmp)[1] <- "_x_" tmp$`_x_` <- as.character(tmp$`_x_`) @@ -181,7 +178,8 @@ prepare_ceteris_paribus_plots <- function(x, numerical_plot_type = numerical_plot_type, rug_df = rug_df, rug = 
rug, - rug_colors = rug_colors) + rug_colors = rug_colors, + center = center) patchwork::wrap_plots(pl, ncol = facet_ncol) + patchwork::plot_annotation(title = title, @@ -196,7 +194,9 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, numerical_plot_type, rug_df, rug, - rug_colors) { + rug_colors, + center) { + pl <- lapply(variables, function(var) { df <- all_profiles[all_profiles$`_vname_` == var, ] @@ -228,10 +228,11 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, ) + geom_line(data = df[df$`_real_point_`, ], color = "red", linewidth = 0.8) + - xlab("") + ylab("survival function value") + ylim(c(0, 1)) + xlim(c(0,NA))+ + xlab("") + ylab("survival function value") + xlim(c(0,NA))+ theme_default_survex() + facet_wrap(~`_vname_`) }) + if (!center) base_plot <- base_plot + ylim(c(0, 1)) } else { base_plot <- with(df, { ggplot( @@ -277,8 +278,9 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, scale_color_manual(name = paste0(unique(df$`_vname_`), " value"), values = generate_discrete_color_scale(n_colors, colors)) + theme_default_survex() + - xlab("") + ylab("survival function value") + ylim(c(0, 1)) + xlim(c(0,NA))+ + xlab("") + ylab("survival function value") + xlim(c(0,NA))+ facet_wrap(~`_vname_`) }) + if (!center) base_plot <- base_plot + ylim(c(0, 1)) } return_plot <- add_rug_to_plot(base_plot, rug_df, rug, rug_colors) diff --git a/R/predict_profile.R b/R/predict_profile.R index 0a3146b6..2e6e7a5d 100644 --- a/R/predict_profile.R +++ b/R/predict_profile.R @@ -10,6 +10,7 @@ #' @param type character, only `"ceteris_paribus"` is implemented #' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the `DALEX::predict_profile` function. 
#' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for percentiles or `"uniform"` (default) to get uniform grid of points. +#' @param center logical, should profiles be centered around the average prediction #' #' @return An object of class `c("predict_profile_survival", "surv_ceteris_paribus")`. It is a list with the final result in the `result` element. #' @@ -43,7 +44,8 @@ predict_profile <- function(explainer, categorical_variables = NULL, ..., type = "ceteris_paribus", - variable_splits_type = "uniform") + variable_splits_type = "uniform", + center = FALSE) UseMethod("predict_profile", explainer) #' @rdname predict_profile.surv_explainer @@ -55,12 +57,15 @@ predict_profile.surv_explainer <- function(explainer, ..., type = "ceteris_paribus", output_type = "survival", - variable_splits_type = "uniform") + variable_splits_type = "uniform", + center = FALSE) { variables <- unique(variables, categorical_variables) if (!type %in% "ceteris_paribus") stop("Type not supported") if (!output_type %in% c("risk", "survival")) stop("output_type not supported") + if (length(dim(new_observation)) != 2 & nrow(new_observation) != 1) + stop("new_observation should be a single row data.frame") if (output_type == "risk") { return(DALEX::predict_profile(explainer = explainer, @@ -78,6 +83,7 @@ predict_profile.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, variable_splits_type = variable_splits_type, + center = center, ...) 
class(res) <- c("predict_profile_survival", class(res)) res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index afc5053d..f8dbd057 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -26,9 +26,9 @@ surv_ceteris_paribus.surv_explainer <- function(x, variable_splits = NULL, grid_points = 101, variable_splits_type = "uniform", + center = FALSE, ...) { test_explainer(x, has_data = TRUE, has_survival = TRUE, has_y = TRUE, function_name = "ceteris_paribus_survival") - data <- x$data model <- x$model label <- x$label @@ -46,6 +46,7 @@ surv_ceteris_paribus.surv_explainer <- function(x, grid_points = grid_points, variable_splits_type = variable_splits_type, variable_splits_with_obs = TRUE, + center = center, label = label, times = times, ... @@ -62,6 +63,7 @@ surv_ceteris_paribus.default <- function(x, grid_points = 101, variable_splits_type = "uniform", variable_splits_with_obs = TRUE, + center = center, label = NULL, times = times, ...) { @@ -97,8 +99,10 @@ surv_ceteris_paribus.default <- function(x, new_observation, variable_splits, x, + center, predict_survival_function, - times + times, + ... ) profiles$`_vtype_` <- ifelse(profiles$`_vname_` %in% categorical_variables, "categorical", "numerical") @@ -113,7 +117,8 @@ surv_ceteris_paribus.default <- function(x, ret <- list( eval_times = times, variable_values = new_observation, - result = cbind(profiles, `_label_` = label) + result = cbind(profiles, `_label_` = label), + center = center ) class(ret) <- c("surv_ceteris_paribus", "list") @@ -160,11 +165,11 @@ calculate_variable_split.default <- function(data, variables = colnames(data), c -calculate_variable_survival_profile <- function(data, variable_splits, model, predict_survival_function = NULL, times = NULL, ...) { +calculate_variable_survival_profile <- function(data, variable_splits, model, center, predict_survival_function = NULL, times = NULL, ...) 
{ UseMethod("calculate_variable_survival_profile") } -calculate_variable_survival_profile.default <- function(data, variable_splits, model, predict_survival_function = NULL, times = NULL, ...) { +calculate_variable_survival_profile.default <- function(data, variable_splits, model, center, predict_survival_function = NULL, times = NULL, ...) { variables <- names(variable_splits) prog <- progressr::progressor(along = 1:(length(variables))) @@ -174,6 +179,9 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m ids <- rownames(data) } + predictions_original <- predict_survival_function(model, data, times) + mean_pred <- colMeans(predictions_original) + profiles <- lapply(variables, function(variable) { split_points <- variable_splits[[variable]] @@ -181,6 +189,9 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m new_data[, variable] <- rep(split_points, nrow(data)) yhat <- c(t(predict_survival_function(model, new_data, times))) + if (center){ + yhat <- yhat - mean_pred + } new_data <- data.frame(new_data[rep(seq_len(nrow(new_data)), each = length(times)), ], `_times_` = rep(times, times = nrow(new_data)), diff --git a/man/predict_profile.surv_explainer.Rd b/man/predict_profile.surv_explainer.Rd index 2d09d05b..9818a637 100644 --- a/man/predict_profile.surv_explainer.Rd +++ b/man/predict_profile.surv_explainer.Rd @@ -12,7 +12,8 @@ predict_profile( categorical_variables = NULL, ..., type = "ceteris_paribus", - variable_splits_type = "uniform" + variable_splits_type = "uniform", + center = FALSE ) \method{predict_profile}{surv_explainer}( @@ -23,7 +24,8 @@ predict_profile( ..., type = "ceteris_paribus", output_type = "survival", - variable_splits_type = "uniform" + variable_splits_type = "uniform", + center = FALSE ) } \arguments{ @@ -41,6 +43,8 @@ predict_profile( \item{variable_splits_type}{character, decides how variable grids should be calculated. 
Use \code{"quantiles"} for percentiles or \code{"uniform"} (default) to get uniform grid of points.} +\item{center}{logical, should profiles be centered around the average prediction} + \item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the \code{DALEX::predict_profile} function.} } \value{ diff --git a/man/surv_ceteris_paribus.Rd b/man/surv_ceteris_paribus.Rd index e202abbd..da5d1e70 100644 --- a/man/surv_ceteris_paribus.Rd +++ b/man/surv_ceteris_paribus.Rd @@ -15,6 +15,7 @@ surv_ceteris_paribus(x, ...) variable_splits = NULL, grid_points = 101, variable_splits_type = "uniform", + center = FALSE, ... ) } From 92134472d1efc2cddd41dc9532e09f296f102237 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 11:28:05 +0200 Subject: [PATCH 136/207] improve contour plots and labs --- R/plot_surv_ceteris_paribus.R | 53 ++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index 85b6925f..15b3ee46 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -228,11 +228,13 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, ) + geom_line(data = df[df$`_real_point_`, ], color = "red", linewidth = 0.8) + - xlab("") + ylab("survival function value") + xlim(c(0,NA))+ + xlab("") + ylab("centered profile value") + xlim(c(0,NA))+ theme_default_survex() + facet_wrap(~`_vname_`) }) - if (!center) base_plot <- base_plot + ylim(c(0, 1)) + if (!center) { + base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") + } } else { base_plot <- with(df, { ggplot( @@ -242,16 +244,37 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, y = as.numeric(as.character(`_x_`)), z = `_yhat_` ) - ) + - 
geom_contour_filled(breaks = seq(1, 0, -0.1)) + - scale_fill_manual(name = "SF value", values = grDevices::colorRampPalette(c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3"))(10), - labels = seq(1, 0, -0.1)) + - guides(fill = guide_legend(nrow = 1, label.position = "top")) + - xlab("") + ylab("variable value") + xlim(c(0,NA))+ - theme_default_survex() + - theme(legend.spacing = grid::unit(0.1, 'line')) + - facet_wrap(~`_vname_`) - }) + ) }) + if (!center){ + base_plot <- base_plot + + geom_contour_filled(binwidth=0.1) + + scale_fill_manual(name = "SF value", + values = grDevices::colorRampPalette(c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3"))(10), + drop = FALSE) + + guides(fill = guide_colorsteps(direction = "horizontal", + barwidth = 0.5*unit(par("pin")[1], "in"), + barheight = 0.02*unit(par("pin")[2], "in"), + reverse = TRUE, + show.limits = TRUE)) + + xlab("") + ylab("variable value") + xlim(c(0,NA)) + + theme_default_survex() + + facet_wrap(~`_vname_`) + } else { + base_plot <- base_plot + + geom_contour_filled(bins=10) + + scale_fill_manual(name = "centered profile value", + values = grDevices::colorRampPalette(c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3"))(10), + drop = FALSE) + + guides(fill = guide_colorsteps(direction = "horizontal", + barwidth = 0.5*unit(par("pin")[1], "in"), + barheight = 0.02*unit(par("pin")[2], "in"), + reverse = TRUE, + show.limits = TRUE, + label.theme = element_text(size=7))) + + xlab("") + ylab("variable value") + xlim(c(0,NA)) + + theme_default_survex() + + facet_wrap(~`_vname_`) + } if (any(df$`_real_point_`)) { range_time <- range(df["_times_"]) var_val <- as.numeric(unique(df[df$`_real_point_`, "_x_"])) @@ -278,9 +301,11 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, scale_color_manual(name = paste0(unique(df$`_vname_`), " value"), values = generate_discrete_color_scale(n_colors, colors)) + theme_default_survex() + - xlab("") + ylab("survival function value") + xlim(c(0,NA))+ 
+ xlab("") + ylab("centered profle value") + xlim(c(0,NA))+ facet_wrap(~`_vname_`) }) - if (!center) base_plot <- base_plot + ylim(c(0, 1)) + if (!center) { + base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") + } } return_plot <- add_rug_to_plot(base_plot, rug_df, rug, rug_colors) From c448622b121884444c7b0a625f3ffe708fe4860e Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 11:28:30 +0200 Subject: [PATCH 137/207] change error messages --- R/plot_model_profile_2d.R | 4 ++-- R/plot_model_profile_survival.R | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index a207d77c..362df684 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -116,12 +116,12 @@ prepare_model_profile_2d_plots <- function(x, ){ if (is.null(times)) { times <- quantile(x$eval_times, p = 0.5, type = 1) - warning("Plot will be prepared for the median time point from the `times` vector. For another time point, set the value of `times`.") + warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") } if (!marginalize_over_time && length(times) > 1) { times <- times[1] - warning("Plot will be prepared for the first time point in the `times` vector. For aggregation over time, set the option `marginalize_over_time = TRUE`.") + warning("Plot will be prepared for the first time point in the provided `times` vector. 
For aggregation over time, set the option `marginalize_over_time = TRUE`.") } if (!all(times %in% x$eval_times)) { diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 8e97cf11..7112c1e9 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -90,6 +90,9 @@ plot.model_profile_survival <- function(x, subtitle = subtitle, colors = colors ) + if (x$center) { + pl <- pl + labs(y = "centered profile value") + } return(pl) } @@ -187,7 +190,7 @@ plot2 <- function(x, if (is.null(times)) { times <- quantile(x$eval_times, p = 0.5, type = 1) - warning("Plot will be prepared for the median time point from the `times` vector. For another time point, set the value of `times`.") + warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") } if (!all(times %in% x$eval_times)) { From 48b547f50cc99a1d0a70af434466931c5f3969ed Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 12:14:36 +0200 Subject: [PATCH 138/207] add import from graphics --- NAMESPACE | 1 + R/plot_surv_ceteris_paribus.R | 1 + 2 files changed, 2 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 8666414f..65564719 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -82,6 +82,7 @@ importFrom(DALEX,theme_drwhy) importFrom(DALEX,theme_drwhy_vertical) importFrom(DALEX,theme_ema) importFrom(DALEX,theme_ema_vertical) +importFrom(graphics,par) importFrom(stats,aggregate) importFrom(stats,as.formula) importFrom(stats,ave) diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index 15b3ee46..33198771 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -188,6 +188,7 @@ prepare_ceteris_paribus_plots <- function(x, } #' @import ggplot2 +#' @importFrom graphics par plot_individual_ceteris_paribus_survival <- function(all_profiles, variables, colors, From 9ec52f696db2166027fd57135aa817904ff0a9fc Mon Sep 17 00:00:00 2001 
From: krzyzinskim Date: Fri, 18 Aug 2023 12:14:56 +0200 Subject: [PATCH 139/207] fix centering 2D model profiles --- R/model_profile_2d.R | 28 +++++++++++++++++++------- man/model_profile_2d.surv_explainer.Rd | 4 ++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index 3b786a8d..f8382f7e 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -9,7 +9,7 @@ #' @param N number of observations used for the calculation of aggregated profiles. By default `100`. If `NULL` all observations are used. #' @param categorical_variables character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the `variables` argument, they will be added at the end. #' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default `25`. -#' @param center logical, should profiles be centered at 0 +#' @param center logical, should profiles be centered around the average prediction #' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for quantiles or `"uniform"` (default) to get uniform grid of points. Used only if `type = "partial"`. #' @param type the type of variable profile, `"partial"` for Partial Dependence or `"accumulated"` for Accumulated Local Effects #' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. Currently only `"survival"` is available. 
@@ -44,7 +44,7 @@ model_profile_2d <- function(explainer, N = 100, categorical_variables = NULL, grid_points = 25, - center = TRUE, + center = FALSE, variable_splits_type = "uniform", type = "partial", output_type = "survival") @@ -58,7 +58,7 @@ model_profile_2d.surv_explainer <- function(explainer, N = 100, categorical_variables = NULL, grid_points = 25, - center = TRUE, + center = FALSE, variable_splits_type = "uniform", type = "partial", output_type = "survival" @@ -95,7 +95,8 @@ model_profile_2d.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, - variable_splits_type = variable_splits_type + variable_splits_type = variable_splits_type, + center = center ) } else if (type == "accumulated") { result <- surv_ale_2d( @@ -125,7 +126,8 @@ surv_pdp_2d <- function(x, variables, categorical_variables, grid_points, - variable_splits_type + variable_splits_type, + center ) { model <- x$model label <- x$label @@ -148,9 +150,21 @@ surv_pdp_2d <- function(x, names(expanded_data)[colnames(expanded_data) == "x"] <- var2 expanded_data <- expanded_data[,colnames(data)] + + predictions_original <- predict_survival_function(model = model, + newdata = data, + times = times) + mean_pred <- colMeans(predictions_original) + predictions <- predict_survival_function(model = model, newdata = expanded_data, times = times) + + preds <- c(t(predictions)) + if (center) { + preds <- preds - mean_pred + } + res <- data.frame( "_v1name_" = var1, "_v2name_" = var2, @@ -159,7 +173,7 @@ surv_pdp_2d <- function(x, "_v1value_" = as.character(rep(expanded_data[,var1], each=length(times))), "_v2value_" = as.character(rep(expanded_data[,var2], each=length(times))), "_times_" = rep(times, nrow(expanded_data)), - "_yhat_" = c(t(predictions)), + "_yhat_" = preds, "_label_" = label, check.names = FALSE ) @@ -363,7 +377,7 @@ surv_ale_2d_num_num <- function(model, ale <- merge(ale, ale1, by = c("time", "interval1")) ale <- 
merge(ale, ale2, by = c("time", "interval2")) ale$ale <- ale$yhat_cumsum - ale$ale1 - ale$ale2 - ale$fJ0 - ale <- ale[order(ale$time, ale$interval1, ale$interval2),] + ale <- ale[order(ale$interval1, ale$interval2, ale$time),] if (!center){ ale$ale <- ale$ale + mean_pred diff --git a/man/model_profile_2d.surv_explainer.Rd b/man/model_profile_2d.surv_explainer.Rd index 6eaf1e27..ce04ced3 100644 --- a/man/model_profile_2d.surv_explainer.Rd +++ b/man/model_profile_2d.surv_explainer.Rd @@ -11,7 +11,7 @@ model_profile_2d( N = 100, categorical_variables = NULL, grid_points = 25, - center = TRUE, + center = FALSE, variable_splits_type = "uniform", type = "partial", output_type = "survival" @@ -40,7 +40,7 @@ model_profile_2d( \item{grid_points}{maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default \code{25}.} -\item{center}{logical, should profiles be centered at 0} +\item{center}{logical, should profiles be centered around the average prediction} \item{variable_splits_type}{character, decides how variable grids should be calculated. Use \code{"quantiles"} for quantiles or \code{"uniform"} (default) to get uniform grid of points. 
Used only if \code{type = "partial"}.} From 562d7e3ad58d8510b0c5241d648cfa62373430ea Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 12:50:22 +0200 Subject: [PATCH 140/207] change default value of center to FALSE --- man/model_profile_2d.surv_explainer.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/model_profile_2d.surv_explainer.Rd b/man/model_profile_2d.surv_explainer.Rd index ce04ced3..56cf0e44 100644 --- a/man/model_profile_2d.surv_explainer.Rd +++ b/man/model_profile_2d.surv_explainer.Rd @@ -23,7 +23,7 @@ model_profile_2d( N = 100, categorical_variables = NULL, grid_points = 25, - center = TRUE, + center = FALSE, variable_splits_type = "uniform", type = "partial", output_type = "survival" From 3e9e09c29494d6c3eac0a237fd785a95cafa3d7e Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 21:34:24 +0200 Subject: [PATCH 141/207] add x axis labels --- R/plot_surv_ceteris_paribus.R | 12 ++++++++---- R/plot_surv_feature_importance.R | 6 ++---- R/plot_surv_lime.R | 5 ++--- R/plot_surv_model_performance.R | 10 +++------- R/plot_surv_shap.R | 8 ++------ 5 files changed, 17 insertions(+), 24 deletions(-) diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index 33198771..b432ba88 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -229,7 +229,8 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, ) + geom_line(data = df[df$`_real_point_`, ], color = "red", linewidth = 0.8) + - xlab("") + ylab("centered profile value") + xlim(c(0,NA))+ + labs(x = "time", y = "centered profile value") + + xlim(c(0,NA)) + theme_default_survex() + facet_wrap(~`_vname_`) }) @@ -257,7 +258,8 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, barheight = 0.02*unit(par("pin")[2], "in"), reverse = TRUE, show.limits = TRUE)) + - xlab("") + ylab("variable value") + xlim(c(0,NA)) + + labs(x = "time", y = "variable value") + + xlim(c(0,NA)) + 
theme_default_survex() + facet_wrap(~`_vname_`) } else { @@ -272,7 +274,8 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, reverse = TRUE, show.limits = TRUE, label.theme = element_text(size=7))) + - xlab("") + ylab("variable value") + xlim(c(0,NA)) + + labs(x = "time", y = "variable value") + + xlim(c(0,NA)) + theme_default_survex() + facet_wrap(~`_vname_`) } @@ -302,7 +305,8 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, scale_color_manual(name = paste0(unique(df$`_vname_`), " value"), values = generate_discrete_color_scale(n_colors, colors)) + theme_default_survex() + - xlab("") + ylab("centered profle value") + xlim(c(0,NA))+ + labs(x = "time", y = "centered profile value") + + xlim(c(0,NA))+ facet_wrap(~`_vname_`) }) if (!center) { base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") diff --git a/R/plot_surv_feature_importance.R b/R/plot_surv_feature_importance.R index c0ec5a6b..3c3a98c4 100644 --- a/R/plot_surv_feature_importance.R +++ b/R/plot_surv_feature_importance.R @@ -88,11 +88,9 @@ plot.surv_feature_importance <- function(x, ..., ggplot(data = plotting_df, aes(x = `_times_`, y = values, color = ind, label = ind)) + geom_line(linewidth = 0.8) + theme_default_survex() + - xlab("") + - ylab(y_lab) + - xlim(c(0,NA))+ + labs(x = "time", y = y_lab, title = title, subtitle = subtitle) + + xlim(c(0,NA)) + scale_color_manual(name = "Variable", values = c("#000000", generate_discrete_color_scale(num_vars, colors))) + - labs(title = title, subtitle = subtitle) + facet_wrap(~label) }) diff --git a/R/plot_surv_lime.R b/R/plot_surv_lime.R index 520be24f..99f958c0 100644 --- a/R/plot_surv_lime.R +++ b/R/plot_surv_lime.R @@ -90,9 +90,8 @@ plot.surv_lime <- function(x, ggplot(data = sf_df, aes(x = times, y = sfs, group = type, color = type)) + geom_line(linewidth = 0.8) + theme_default_survex() + - xlab("") + - xlim(c(0,NA))+ - ylab("survival function value") + + labs(x = "time", y = "survival function 
value") + + xlim(c(0,NA)) + scale_color_manual("", values = generate_discrete_color_scale(2, colors)) }) return(patchwork::wrap_plots(pl, pl2, nrow = 1, widths = c(3, 5))) diff --git a/R/plot_surv_model_performance.R b/R/plot_surv_model_performance.R index 7ddc6f22..92cc0f58 100644 --- a/R/plot_surv_model_performance.R +++ b/R/plot_surv_model_performance.R @@ -73,10 +73,8 @@ plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, ggplot(data = df[df$ind %in% metrics, ], aes(x = times, y = values, group = label, color = label)) + geom_line(linewidth = 0.8) + theme_default_survex() + - xlab("") + - ylab("metric value") + - xlim(c(0,NA))+ - labs(title = title, subtitle = subtitle) + + labs(x = "time", y = "metric value", title = title, subtitle = subtitle) + + xlim(c(0,NA)) + scale_color_manual("", values = generate_discrete_color_scale(num_colors, colors)) + facet_wrap(~ind, ncol = facet_ncol, scales = "free_y") }) @@ -102,9 +100,7 @@ plot_scalar_surv_model_performance <- function(x, ..., metrics = NULL, title = N ggplot(data = df, aes(x = label, y = values, fill = label)) + geom_col() + theme_default_survex() + - xlab("") + - ylab("metric value") + - labs(title = title, subtitle = subtitle) + + labs(x = "model", y = "metric value", title = title, subtitle = subtitle) + scale_fill_manual("", values = generate_discrete_color_scale(num_colors, colors)) + facet_wrap(~ind, ncol = facet_ncol, scales = "free_y") }) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 2bb199c1..d03b7328 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -70,15 +70,11 @@ plot.surv_shap <- function(x, n_colors <- length(unique(long_df$ind)) - y_lab <- "SurvSHAP(t) value" - - base_plot <- with(long_df, { ggplot(data = long_df, aes(x = times, y = values, color = ind)) + geom_line(linewidth = 0.8) + - ylab(y_lab) + xlab("") + - xlim(c(0,NA))+ - labs(title = title, subtitle = subtitle) + + labs(x = "time", y = "SurvSHAP(t) value", title = title, subtitle = 
subtitle) + + xlim(c(0,NA)) + scale_color_manual("variable", values = generate_discrete_color_scale(n_colors, colors)) + theme_default_survex() + facet_wrap(~label, ncol = 1, scales = "free_y") From 2333c7b6a3f821ed2e44355b8f27e7f5eae1b990 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 22:05:00 +0200 Subject: [PATCH 142/207] add new engine from censored --- R/explain.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/explain.R b/R/explain.R index 8e6559b2..d60b0d82 100644 --- a/R/explain.R +++ b/R/explain.R @@ -728,7 +728,7 @@ explain.model_fit <- } if (is.null(predict_function)) { - if (model$spec$engine %in% c("mboost", "survival", "glmnet", "flexsurv")){ + if (model$spec$engine %in% c("mboost", "survival", "glmnet", "flexsurv", "flexsurvspline")){ predict_function <- function(model, newdata, times) predict(model, new_data = newdata, type = "linear_pred")$.pred_linear_pred attr(predict_function, "verbose_info") <- "predict.model_fit with type = 'linear_pred' will be used" } else { From 917973342044054f9c5e640bd7fd51c82095a366 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 22:05:12 +0200 Subject: [PATCH 143/207] fix condition for error --- R/predict_profile.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/predict_profile.R b/R/predict_profile.R index 2e6e7a5d..8ea935e6 100644 --- a/R/predict_profile.R +++ b/R/predict_profile.R @@ -64,7 +64,7 @@ predict_profile.surv_explainer <- function(explainer, variables <- unique(variables, categorical_variables) if (!type %in% "ceteris_paribus") stop("Type not supported") if (!output_type %in% c("risk", "survival")) stop("output_type not supported") - if (length(dim(new_observation)) != 2 & nrow(new_observation) != 1) + if (length(dim(new_observation)) != 2 | nrow(new_observation) != 1) stop("new_observation should be a single row data.frame") if (output_type == "risk") { From f4b574eeaf4ddcafb52a727d5572b51f15c52390 Mon Sep 17 00:00:00 2001 From: 
krzyzinskim Date: Fri, 18 Aug 2023 22:05:27 +0200 Subject: [PATCH 144/207] add tests for centered profiles --- tests/testthat/test-predict_profile.R | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-predict_profile.R b/tests/testthat/test-predict_profile.R index a6884261..e28b4fdd 100644 --- a/tests/testthat/test-predict_profile.R +++ b/tests/testthat/test-predict_profile.R @@ -22,12 +22,19 @@ test_that("ceteris_paribus works", { expect_error(plot(cph_pp, variable_type = "nonexistent")) expect_error(plot(cph_pp, numerical_plot_type = "nonexistent")) + expect_error(predict_profile(cph_exp, veteran[2, -c(3, 4)], output_type = "nonexistent")) + expect_error(predict_profile(cph_exp, veteran[2:3, -c(3, 4)])) + expect_error(predict_profile(cph_exp, veteran[2, -c(3, 4)], type = "nonexistent")) + + cph_pp_centered <- predict_profile(cph_exp, veteran[2, -c(3, 4)], center = TRUE) + plot(cph_pp_centered) + plot(cph_pp_centered, numerical_plot_type = "contours") cph_pp_cat <- predict_profile(cph_exp, veteran[2, -c(3, 4)], variables = c("celltype")) + plot(predict_profile(cph_exp, veteran[2, -c(3, 4)], categorical_variables = 1)) plot(cph_pp_cat, variable_type = "categorical", colors = c("#ff0000", "#00ff00", "#0000ff")) plot(cph_pp_cat, variable_type = "categorical") - expect_s3_class(cph_pp, c("predict_profile_survival", "surv_ceteris_paribus")) expect_s3_class(cph_pp_cat, c("predict_profile_survival", "surv_ceteris_paribus")) @@ -44,7 +51,6 @@ test_that("ceteris_paribus works", { expect_setequal(cph_pp_cat$eval_times, cph_exp$times) expect_output(print(cph_pp)) - }) From bd9ed5aa19ad9f343b87439298efc740f58a7418 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 22:35:47 +0200 Subject: [PATCH 145/207] fix model_survshap naming convention --- R/model_survshap.R | 2 + R/plot_surv_shap.R | 55 ++++++++++++++++++++-------- man/model_survshap.surv_explainer.Rd | 2 + man/plot.aggregated_surv_shap.Rd | 44 
++++++++++++++++------ tests/testthat/test-model_survshap.R | 18 ++++----- vignettes/global-survshap.Rmd | 4 +- 6 files changed, 86 insertions(+), 39 deletions(-) diff --git a/R/model_survshap.R b/R/model_survshap.R index 43624bce..96bb2605 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -40,6 +40,8 @@ #' calculation_method = "kernelshap", #' ) #' plot(ranger_global_survshap) +#' plot(ranger_global_survshap, kind = "beeswarm") +#' plot(ranger_global_survshap, kind = "profile", color_variable = "karno") #' } #' #' @rdname model_survshap.surv_explainer diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index d03b7328..55edfc71 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -93,7 +93,7 @@ plot.surv_shap <- function(x, #' explanations of survival models created using the `model_survshap()` function. #' #' @param x an object of class `aggregated_surv_shap` to be plotted -#' @param kind character, one of `"importance"`, `"swarm"`, or `"profile"`. Type of chart to be plotted; `"importance"` shows the importance of variables over time and aggregated, `"swarm"` shows the distribution of SurvSHAP(t) values for variables and observations, `"profile"` shows the dependence of SurvSHAP(t) values on variable values. +#' @param geom character, one of `"importance"`, `"beeswarm"`, or `"profile"`. Type of chart to be plotted; `"importance"` shows the importance of variables over time and aggregated, `"beeswarm"` shows the distribution of SurvSHAP(t) values for variables and observations, `"profile"` shows the dependence of SurvSHAP(t) values on variable values. #' @param ... 
additional parameters passed to internal functions #' @param title character, title of the plot #' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations @@ -104,19 +104,19 @@ plot.surv_shap <- function(x, #' #' @section Plot options: #' -#' ## `plot.aggregated_surv_shap(type = "importance")` +#' ## `plot.aggregated_surv_shap(geom = "importance")` #' #' * `rug` - character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. #' * `rug_colors` - character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. #' * `xlab_left, ylab_right` - axis labels for left and right plots (due to different aggregation possibilities) #' #' -#' ## `plot.aggregated_surv_shap(type = "swarm")` +#' ## `plot.aggregated_surv_shap(geom = "beeswarm")` #' #' * no additional parameters #' #' -#' ## `plot.aggregated_surv_shap(type = "swarm")` +#' ## `plot.aggregated_surv_shap(geom = "profile")` #' #' * `variable` - variable for which the profile is to be plotted, by default first from result data #' * `color_variable` - variable used to denote the color, by default equal to `variable` @@ -124,19 +124,39 @@ plot.surv_shap <- function(x, #' #' @examples #' \donttest{ -#' library(survival) -#' library(survex) -#' -#' model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) -#' exp <- explain(model) +#' veteran <- survival::veteran +#' rsf_ranger <- ranger::ranger( +#' survival::Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 +#' ) +#' rsf_ranger_exp <- explain( +#' rsf_ranger, +#' data = veteran[, -c(3, 4)], 
+#' y = survival::Surv(veteran$time, veteran$status), +#' verbose = FALSE +#' ) #' -#' p_parts_shap <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survshap") -#' plot(p_parts_shap) +#' ranger_global_survshap <- model_survshap( +#' explainer = rsf_ranger_exp, +#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), +#' !colnames(veteran) %in% c("time", "status")], +#' y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], +#' veteran$status[c(1:4, 17:20, 110:113, 126:129)]), +#' aggregation_method = "integral", +#' calculation_method = "kernelshap", +#' ) +#' plot(ranger_global_survshap) +#' plot(ranger_global_survshap, kind = "beeswarm") +#' plot(ranger_global_survshap, kind = "profile", color_variable = "karno") #' } #' #'@export plot.aggregated_surv_shap <- function(x, - kind = "importance", + geom = "importance", ..., title="default", subtitle="default", @@ -148,18 +168,21 @@ plot.aggregated_surv_shap <- function(x, high = "#371ea3") } + if (geom == "swarm") + geom <- "beeswarm" + switch( - kind, + geom, "importance" = plot_shap_global_importance(x = x, ... = ..., colors = colors), - "swarm" = plot_shap_global_swarm(x = x, + "beeswarm" = plot_shap_global_beeswarm(x = x, ... = ..., colors = colors), "profile" = plot_shap_global_profile(x = x, ... 
= ..., colors = colors), - stop("`kind` must be one of 'importance', 'swarm' or 'profile'") + stop("`kind` must be one of 'importance', 'beeswarm' or 'profile'") ) } @@ -220,7 +243,7 @@ plot_shap_global_importance <- function(x, return(pl) } -plot_shap_global_swarm <- function(x, +plot_shap_global_beeswarm <- function(x, ..., title = "default", subtitle = "default", diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 3448eb19..54ed7612 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -68,6 +68,8 @@ ranger_global_survshap <- model_survshap( calculation_method = "kernelshap", ) plot(ranger_global_survshap) +plot(ranger_global_survshap, kind = "beeswarm") +plot(ranger_global_survshap, kind = "profile", color_variable = "karno") } } diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd index 173cd1aa..6f822ad3 100644 --- a/man/plot.aggregated_surv_shap.Rd +++ b/man/plot.aggregated_surv_shap.Rd @@ -6,7 +6,7 @@ \usage{ \method{plot}{aggregated_surv_shap}( x, - kind = "importance", + geom = "importance", ..., title = "default", subtitle = "default", @@ -17,7 +17,7 @@ \arguments{ \item{x}{an object of class \code{aggregated_surv_shap} to be plotted} -\item{kind}{character, one of \code{"importance"}, \code{"swarm"}, or \code{"profile"}. Type of chart to be plotted; \code{"importance"} shows the importance of variables over time and aggregated, \code{"swarm"} shows the distribution of SurvSHAP(t) values for variables and observations, \code{"profile"} shows the dependence of SurvSHAP(t) values on variable values.} +\item{geom}{character, one of \code{"importance"}, \code{"beeswarm"}, or \code{"profile"}. 
Type of chart to be plotted; \code{"importance"} shows the importance of variables over time and aggregated, \code{"beeswarm"} shows the distribution of SurvSHAP(t) values for variables and observations, \code{"profile"} shows the dependence of SurvSHAP(t) values on variable values.} \item{...}{additional parameters passed to internal functions} @@ -38,7 +38,7 @@ explanations of survival models created using the \code{model_survshap()} functi } \section{Plot options}{ -\subsection{\code{plot.aggregated_surv_shap(type = "importance")}}{ +\subsection{\code{plot.aggregated_surv_shap(geom = "importance")}}{ \itemize{ \item \code{rug} - character, one of \code{"all"}, \code{"events"}, \code{"censors"}, \code{"none"} or \code{NULL}. Which times to mark on the x axis in \code{geom_rug()}. \item \code{rug_colors} - character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. 
@@ -46,13 +46,13 @@ explanations of survival models created using the \code{model_survshap()} functi } } -\subsection{\code{plot.aggregated_surv_shap(type = "swarm")}}{ +\subsection{\code{plot.aggregated_surv_shap(geom = "beeswarm")}}{ \itemize{ \item no additional parameters } } -\subsection{\code{plot.aggregated_surv_shap(type = "swarm")}}{ +\subsection{\code{plot.aggregated_surv_shap(geom = "profile")}}{ \itemize{ \item \code{variable} - variable for which the profile is to be plotted, by default first from result data \item \code{color_variable} - variable used to denote the color, by default equal to \code{variable} @@ -62,14 +62,34 @@ explanations of survival models created using the \code{model_survshap()} functi \examples{ \donttest{ -library(survival) -library(survex) - -model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) -exp <- explain(model) +veteran <- survival::veteran +rsf_ranger <- ranger::ranger( + survival::Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 +) +rsf_ranger_exp <- explain( + rsf_ranger, + data = veteran[, -c(3, 4)], + y = survival::Surv(veteran$time, veteran$status), + verbose = FALSE +) -p_parts_shap <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survshap") -plot(p_parts_shap) +ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), + !colnames(veteran) \%in\% c("time", "status")], + y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], + veteran$status[c(1:4, 17:20, 110:113, 126:129)]), + aggregation_method = "integral", + calculation_method = "kernelshap", +) +plot(ranger_global_survshap) +plot(ranger_global_survshap, kind = "beeswarm") +plot(ranger_global_survshap, kind = "profile", color_variable = "karno") } } diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index 8e5626bc..acfe8f94 100644 --- 
a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -22,11 +22,11 @@ test_that("global survshap explanations with kernelshap work for ranger, using n calculation_method = "kernelshap" ) plot(ranger_global_survshap) - plot(ranger_global_survshap, kind = "swarm") - plot(ranger_global_survshap, kind = "profile") - plot(ranger_global_survshap, kind = "profile", variable = "karno", color_variable = "celltype") - plot(ranger_global_survshap, kind = "profile", variable = "karno", color_variable = "age") - expect_error(plot(ranger_global_survshap, kind = "nonexistent")) + plot(ranger_global_survshap, geom = "beeswarm") + plot(ranger_global_survshap, geom = "profile") + plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "celltype") + plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "age") + expect_error(plot(ranger_global_survshap, geom = "nonexistent")) expect_s3_class(ranger_global_survshap, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(ranger_global_survshap$eval_times), length(rsf_ranger_exp$times)) @@ -41,10 +41,10 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex calculation_method = "kernelshap" ) plot(cph_global_survshap) - plot(cph_global_survshap, kind = "swarm") - plot(cph_global_survshap, kind = "profile") - plot(cph_global_survshap, kind = "profile", variable = "karno", color_variable = "celltype") - plot(cph_global_survshap, kind = "profile", variable = "karno", color_variable = "age") + plot(cph_global_survshap, geom = "beeswarm") + plot(cph_global_survshap, geom = "profile") + plot(cph_global_survshap, geom = "profile", variable = "karno", color_variable = "celltype") + plot(cph_global_survshap, geom = "profile", variable = "karno", color_variable = "age") expect_s3_class(cph_global_survshap, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times)) diff 
--git a/vignettes/global-survshap.Rmd b/vignettes/global-survshap.Rmd index bba0257d..3a7b4108 100644 --- a/vignettes/global-survshap.Rmd +++ b/vignettes/global-survshap.Rmd @@ -49,8 +49,8 @@ The `plot.aggregated_surv_shap()` function can also be used to plot the explanat plot(shap, variable = "karno", kind = "profile") ``` -For `kind = "swarm"` a swarm plot is generated that shows the SHAP values for each observation. The swarm plot is a good way to assess the distribution of SHAP values for each variable. +For `kind = "beeswarm"` a bee swarm plot is generated that shows the SHAP values for each observation. The swarm plot is a good way to assess the distribution of SHAP values for each variable. ```{r} -plot(shap, kind = "swarm") +plot(shap, kind = "beeswarm") ``` From 2a7b3698e0293571f33777d52d8a3b8f9272a913 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 22:48:53 +0200 Subject: [PATCH 146/207] add extract predict survshap --- NAMESPACE | 1 + R/utils.R | 59 ++++++++++++++++++++++++++++ man/extract_predict_survshap.Rd | 47 ++++++++++++++++++++++ tests/testthat/test-model_survshap.R | 6 +++ 4 files changed, 113 insertions(+) create mode 100644 man/extract_predict_survshap.Rd diff --git a/NAMESPACE b/NAMESPACE index 65564719..5e354fb4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -51,6 +51,7 @@ export(cd_auc) export(cumulative_hazard_to_survival) export(explain) export(explain_survival) +export(extract_predict_survshap) export(integrated_brier_score) export(integrated_cd_auc) export(loss_adapt_mlr3proba) diff --git a/R/utils.R b/R/utils.R index acb9f5b5..755b85c5 100644 --- a/R/utils.R +++ b/R/utils.R @@ -168,6 +168,65 @@ risk_from_chf <- function(predict_cumulative_hazard_function, times) { function(model, newdata) rowSums(predict_cumulative_hazard_function(model, newdata, times)) } +#' Extract Local SurvSHAP(t) from Global SurvSHAP(t) +#' +#' Helper function to extract local SurvSHAP(t) explanation from global one. 
+#' Can be can be useful for creating SurvSHAP(t) plots for single observations. +#' +#' @param aggregated_survshap an object of class `aggregated_surv_shap` containing the computed global SHAP values +#' @param index a numeric value, position of an observation to be extracted in the result of global explanation +#' +#' @return An object of classes `c("predict_parts_survival", "surv_shap")`. It is a list with the element `result` containing the results of the explanation. +#' +#' @examples +#' veteran <- survival::veteran +#' rsf_ranger <- ranger::ranger( +#' survival::Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 +#' ) +#' rsf_ranger_exp <- explain( +#' rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = survival::Surv(veteran$time, veteran$status), +#' verbose = FALSE +#' ) +#' +#' ranger_global_survshap <- model_survshap( +#' explainer = rsf_ranger_exp, +#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), +#' !colnames(veteran) %in% c("time", "status")]) +#' +#' local_survshap_1 <- extract_predict_survshap(ranger_global_survshap, index = 1) +#' plot(local_survshap_1) +#' ) +#' +#' @export +extract_predict_survshap <- function(aggregated_survshap, index){ + if (class(aggregated_survshap) != "aggregated_surv_shap") + stop("`aggregated_survshap` object must be of class 'aggregated_surv_shap'") + + if (index > aggregated_survshap$n_observations) + stop(paste("Incorrect `index`, number of observations in `aggregated_survshap` is", aggregated_survshap$n_observations)) + + + res <- list() + res$eval_times <- aggregated_survshap$eval_times + res$event_times <- aggregated_survshap$event_times + res$event_statuses <- aggregated_survshap$event_statuses + res$variable_values <- aggregated_survshap$variable_values[index,] + res$result <- aggregated_survshap$result[[index]] + res$aggregate <- aggregated_survshap$aggregate[[index]] + class(res) <- c("predict_parts_survival", 
"surv_shap") + attr(res, "label") <- attr(aggregated_survshap, "label") + + res +} + + #' @keywords internal add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ if (rug == "all"){ diff --git a/man/extract_predict_survshap.Rd b/man/extract_predict_survshap.Rd new file mode 100644 index 00000000..93e69393 --- /dev/null +++ b/man/extract_predict_survshap.Rd @@ -0,0 +1,47 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{extract_predict_survshap} +\alias{extract_predict_survshap} +\title{Extract Local SurvSHAP(t) from Global SurvSHAP(t)} +\usage{ +extract_predict_survshap(aggregated_survshap, index) +} +\arguments{ +\item{aggregated_survshap}{an object of class \code{aggregated_surv_shap} containing the computed global SHAP values} + +\item{index}{a numeric value, position of an observation to be extracted in the result of global explanation} +} +\value{ +An object of classes \code{c("predict_parts_survival", "surv_shap")}. It is a list with the element \code{result} containing the results of the explanation. +} +\description{ +Helper function to extract local SurvSHAP(t) explanation from global one. +Can be can be useful for creating SurvSHAP(t) plots for single observations. 
+} +\examples{ +veteran <- survival::veteran +rsf_ranger <- ranger::ranger( + survival::Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 +) +rsf_ranger_exp <- explain( + rsf_ranger, + data = veteran[, -c(3, 4)], + y = survival::Surv(veteran$time, veteran$status), + verbose = FALSE +) + +ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), + !colnames(veteran) \%in\% c("time", "status")]) + +local_survshap_1 <- extract_predict_survshap(ranger_global_survshap, index = 1) +plot(local_survshap_1) +) + +} diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index acfe8f94..2508f4d1 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -28,6 +28,12 @@ test_that("global survshap explanations with kernelshap work for ranger, using n plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "age") expect_error(plot(ranger_global_survshap, geom = "nonexistent")) + single_survshap <- extract_predict_survshap(ranger_global_survshap, 5) + expect_s3_class(single_survshap, c("predict_parts_survival", "surv_shap")) + expect_error(extract_predict_survshap(ranger_global_survshap, 200)) + expect_error(extract_predict_survshap(single_survshap, 5)) + + expect_s3_class(ranger_global_survshap, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(ranger_global_survshap$eval_times), length(rsf_ranger_exp$times)) expect_true(all(names(ranger_global_survshap$variable_values) == colnames(rsf_ranger_exp$data))) From 678f98606c8a379104f1546e86333c5b1e145d9c Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 22:51:36 +0200 Subject: [PATCH 147/207] fix --- R/utils.R | 3 +-- man/extract_predict_survshap.Rd | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/R/utils.R b/R/utils.R index 755b85c5..2b79917c 
100644 --- a/R/utils.R +++ b/R/utils.R @@ -202,11 +202,10 @@ risk_from_chf <- function(predict_cumulative_hazard_function, times) { #' #' local_survshap_1 <- extract_predict_survshap(ranger_global_survshap, index = 1) #' plot(local_survshap_1) -#' ) #' #' @export extract_predict_survshap <- function(aggregated_survshap, index){ - if (class(aggregated_survshap) != "aggregated_surv_shap") + if (inherits(aggregated_survshap, "aggregated_surv_shap")) stop("`aggregated_survshap` object must be of class 'aggregated_surv_shap'") if (index > aggregated_survshap$n_observations) diff --git a/man/extract_predict_survshap.Rd b/man/extract_predict_survshap.Rd index 93e69393..855cbdc1 100644 --- a/man/extract_predict_survshap.Rd +++ b/man/extract_predict_survshap.Rd @@ -42,6 +42,5 @@ ranger_global_survshap <- model_survshap( local_survshap_1 <- extract_predict_survshap(ranger_global_survshap, index = 1) plot(local_survshap_1) -) } From 01c8d3b317bcfa94a0b04d3cc70afad03f39b744 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 22:53:54 +0200 Subject: [PATCH 148/207] fix typo --- R/utils.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index 2b79917c..146290b4 100644 --- a/R/utils.R +++ b/R/utils.R @@ -205,7 +205,7 @@ risk_from_chf <- function(predict_cumulative_hazard_function, times) { #' #' @export extract_predict_survshap <- function(aggregated_survshap, index){ - if (inherits(aggregated_survshap, "aggregated_surv_shap")) + if (!inherits(aggregated_survshap, "aggregated_surv_shap")) stop("`aggregated_survshap` object must be of class 'aggregated_surv_shap'") if (index > aggregated_survshap$n_observations) From 7df26964b42b7f1ddf7e0a24b41c29b3ab46cadb Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 23:00:16 +0200 Subject: [PATCH 149/207] kind -> geom --- R/model_survshap.R | 4 ++-- R/plot_surv_shap.R | 6 +++--- man/model_survshap.surv_explainer.Rd | 4 ++-- man/plot.aggregated_surv_shap.Rd | 4 ++-- 
vignettes/global-survshap.Rmd | 8 ++++---- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/R/model_survshap.R b/R/model_survshap.R index 96bb2605..256804eb 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -40,8 +40,8 @@ #' calculation_method = "kernelshap", #' ) #' plot(ranger_global_survshap) -#' plot(ranger_global_survshap, kind = "beeswarm") -#' plot(ranger_global_survshap, kind = "profile", color_variable = "karno") +#' plot(ranger_global_survshap, geom = "beeswarm") +#' plot(ranger_global_survshap, geom = "profile", color_variable = "karno") #' } #' #' @rdname model_survshap.surv_explainer diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 55edfc71..df1fc41b 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -150,8 +150,8 @@ plot.surv_shap <- function(x, #' calculation_method = "kernelshap", #' ) #' plot(ranger_global_survshap) -#' plot(ranger_global_survshap, kind = "beeswarm") -#' plot(ranger_global_survshap, kind = "profile", color_variable = "karno") +#' plot(ranger_global_survshap, geom = "beeswarm") +#' plot(ranger_global_survshap, geom = "profile", color_variable = "karno") #' } #' #'@export @@ -182,7 +182,7 @@ plot.aggregated_surv_shap <- function(x, "profile" = plot_shap_global_profile(x = x, ... 
= ..., colors = colors), - stop("`kind` must be one of 'importance', 'beeswarm' or 'profile'") + stop("`geom` must be one of 'importance', 'beeswarm' or 'profile'") ) } diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 54ed7612..491fe3e7 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -68,8 +68,8 @@ ranger_global_survshap <- model_survshap( calculation_method = "kernelshap", ) plot(ranger_global_survshap) -plot(ranger_global_survshap, kind = "beeswarm") -plot(ranger_global_survshap, kind = "profile", color_variable = "karno") +plot(ranger_global_survshap, geom = "beeswarm") +plot(ranger_global_survshap, geom = "profile", color_variable = "karno") } } diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd index 6f822ad3..b1e0bbb1 100644 --- a/man/plot.aggregated_surv_shap.Rd +++ b/man/plot.aggregated_surv_shap.Rd @@ -88,8 +88,8 @@ ranger_global_survshap <- model_survshap( calculation_method = "kernelshap", ) plot(ranger_global_survshap) -plot(ranger_global_survshap, kind = "beeswarm") -plot(ranger_global_survshap, kind = "profile", color_variable = "karno") +plot(ranger_global_survshap, geom = "beeswarm") +plot(ranger_global_survshap, geom = "profile", color_variable = "karno") } } diff --git a/vignettes/global-survshap.Rmd b/vignettes/global-survshap.Rmd index 3a7b4108..92614eaf 100644 --- a/vignettes/global-survshap.Rmd +++ b/vignettes/global-survshap.Rmd @@ -43,14 +43,14 @@ We plot these explanations using the `plot.aggregated_surv_shap()` function. By plot(shap) ``` -The `plot.aggregated_surv_shap()` function can also be used to plot the explanations for a single variable. The `variable` argument specifies the variable for which the explanations are plotted. The `kind` argument specifies the type of plot. 
For `kind = "profile"` a plot is generated that shows the mean SHAP value (averaged across the time domain) depending on the value of the variable. +The `plot.aggregated_surv_shap()` function can also be used to plot the explanations for a single variable. The `variable` argument specifies the variable for which the explanations are plotted. The `geom` argument specifies the type of plot. For `geom = "profile"` a plot is generated that shows the mean SHAP value (averaged across the time domain) depending on the value of the variable. ```{r} -plot(shap, variable = "karno", kind = "profile") +plot(shap, variable = "karno", geom = "profile") ``` -For `kind = "beeswarm"` a bee swarm plot is generated that shows the SHAP values for each observation. The swarm plot is a good way to assess the distribution of SHAP values for each variable. +For `geom = "beeswarm"` a bee swarm plot is generated that shows the SHAP values for each observation. The swarm plot is a good way to assess the distribution of SHAP values for each variable. 
```{r} -plot(shap, kind = "beeswarm") +plot(shap, geom = "beeswarm") ``` From 1e658234b22d3a4b00d72837140cc482aceedb46 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 18 Aug 2023 23:03:36 +0200 Subject: [PATCH 150/207] fix references --- _pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/_pkgdown.yml b/_pkgdown.yml index a8b45479..7ee2eff2 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -45,6 +45,7 @@ reference: - plot.surv_ceteris_paribus - title: Utility functions - contents: + - extract_predict_survshap - transform_to_stepfunction - risk_from_chf - cumulative_hazard_to_survival From 8412e67b3fd5cecf3dfc622335670e3ecafb9383 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 21 Aug 2023 16:54:37 +0200 Subject: [PATCH 151/207] change type of color scale for time when geom = "variable" --- R/plot_model_profile_survival.R | 34 +++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 7112c1e9..c3121859 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -166,6 +166,8 @@ plot2 <- function(x, title <- "Individual conditional expectation survival profiles" } else if (plot_type == "pdp+ice") { title <- "Partial dependence with individual conditional expectation survival profiles" + } else if (plot_type == "pdp"){ + title <- "Partial dependence survival profiles" } if (x$type == "accumulated" && plot_type != "ale") { @@ -221,6 +223,7 @@ plot2 <- function(x, pdp_df$time <- as.factor(pdp_df$time) } + ice_df <- NULL if (ice_needed) { ice_df <- x$cp_profiles$result[(x$cp_profiles$result$`_vname_` == variable) & (x$cp_profiles$result$`_times_` %in% times), ] @@ -266,7 +269,10 @@ plot2 <- function(x, ice_df <- aggregate(predictions ~ ., data = ice_df, mean) color_scale <- generate_discrete_color_scale(1, colors) } else { - color_scale <- generate_discrete_color_scale(length(times), colors) + if (is.null(colors) | 
length(colors) < 3) + color_scale <- c(low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3") } if (is_categorical) { @@ -346,29 +352,45 @@ plot_pdp_num <- function(pdp_dt, ylim(y_floor_pd, y_ceiling_pd) } } else { ## multiple timepoints + pdp_dt$time <- as.numeric(as.character(pdp_dt$time)) + if (!is.null(ice_dt)) + ice_dt$time <- as.numeric(as.character(ice_dt$time)) + if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(alpha = 0.2, mapping = aes(group = interaction(id, time), color = time)) + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + - scale_color_manual(name = "time", values = colors) + + scale_colour_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(pdp_dt$time)))) + ylim(y_floor_ice, y_ceiling_ice) } # PDP + ICE else if (plot_type == "pdp+ice") { ggplot() + geom_line(data = ice_dt, aes(x = !!feature_name_sym, y = predictions, group = interaction(id, time), color = time), alpha = 0.1) + - geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time), linewidth = 1.5, lineend = "round", linejoin = "round") + + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time, group = time), linewidth = 1.5, lineend = "round", linejoin = "round") + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, group = time), color = "black", linewidth = 0.5, linetype = "dashed", lineend = "round", linejoin = "round") + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + - scale_color_manual(name = "time", values = colors) + + scale_colour_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(pdp_dt$time)))) + ylim(y_floor_ice, y_ceiling_ice) } # PDP else if 
(plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + - geom_line(aes(color = time)) + + geom_line(aes(color = time, group = time)) + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + - scale_color_manual(name = "time", values = colors) + + scale_colour_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(pdp_dt$time)))) + ylim(y_floor_pd, y_ceiling_pd) } } From 8dcab853e240debfca6e9e292eebcd547400c02c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 22 Aug 2023 11:21:29 +0200 Subject: [PATCH 152/207] Reformat code and examples with the `styler` package --- R/explain.R | 911 ++++++++++++------------ R/metrics.R | 49 +- R/misc_set_theme_survex.R | 26 +- R/model_info.R | 15 +- R/model_parts.R | 68 +- R/model_performance.R | 40 +- R/model_profile.R | 125 ++-- R/model_profile_2d.R | 212 +++--- R/model_survshap.R | 47 +- R/plot_model_profile_2d.R | 145 ++-- R/plot_model_profile_survival.R | 51 +- R/plot_predict_profile_survival.R | 2 - R/plot_surv_ceteris_paribus.R | 261 ++++--- R/plot_surv_feature_importance.R | 25 +- R/plot_surv_lime.R | 58 +- R/plot_surv_model_performance.R | 83 ++- R/plot_surv_model_performance_rocs.R | 33 +- R/plot_surv_shap.R | 252 ++++--- R/predict_parts.R | 35 +- R/predict_profile.R | 50 +- R/predict_surv_explainer.R | 20 +- R/print.R | 2 - R/surv_ceteris_paribus.R | 11 +- R/surv_feature_importance.R | 38 +- R/surv_integrated_feature_importance.R | 11 +- R/surv_lime.R | 58 +- R/surv_model_performance.R | 47 +- R/surv_model_profiles.R | 297 ++++---- R/surv_shap.R | 87 ++- R/utils.R | 186 ++--- man/c_index.Rd | 3 +- man/explain_survival.Rd | 44 +- man/extract_predict_survshap.Rd | 31 +- man/loss_one_minus_c_index.Rd | 3 +- man/model_performance.surv_explainer.Rd | 24 +- man/model_profile.surv_explainer.Rd | 14 +- 
man/model_profile_2d.surv_explainer.Rd | 8 +- man/model_survshap.surv_explainer.Rd | 38 +- man/plot.aggregated_surv_shap.Rd | 38 +- man/plot.model_profile_2d_survival.Rd | 12 +- man/predict.surv_explainer.Rd | 17 +- man/predict_parts.surv_explainer.Rd | 7 +- man/predict_profile.surv_explainer.Rd | 5 +- man/risk_from_chf.Rd | 12 +- man/surv_model_info.Rd | 12 +- man/transform_to_stepfunction.Rd | 7 +- 46 files changed, 1864 insertions(+), 1656 deletions(-) diff --git a/R/explain.R b/R/explain.R index d60b0d82..049d4549 100644 --- a/R/explain.R +++ b/R/explain.R @@ -50,14 +50,20 @@ #' library(survival) #' library(survex) #' -#' cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, -#' model = TRUE, x = TRUE) +#' cph <- survival::coxph(survival::Surv(time, status) ~ ., +#' data = veteran, +#' model = TRUE, x = TRUE +#' ) #' cph_exp <- explain(cph) #' -#' rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, -#' respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) -#' rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], -#' y = Surv(veteran$time, veteran$status)) +#' rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5 +#' ) +#' rsf_ranger_exp <- explain(rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = Surv(veteran$time, veteran$status) +#' ) #' #' rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) #' rsf_src_exp <- explain(rsf_src) @@ -65,28 +71,30 @@ #' library(censored, quietly = TRUE) #' #' bt <- parsnip::boost_tree() %>% -#' parsnip::set_engine("mboost") %>% -#' parsnip::set_mode("censored regression") %>% -#' generics::fit(survival::Surv(time, status) ~ ., data = veteran) +#' parsnip::set_engine("mboost") %>% +#' parsnip::set_mode("censored regression") %>% +#' generics::fit(survival::Surv(time, status) ~ ., data = veteran) #' bt_exp <- explain(bt, 
data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status)) #' #' ###### explain_survival() ###### #' -#' cph <- coxph(Surv(time, status) ~ ., data=veteran) +#' cph <- coxph(Surv(time, status) ~ ., data = veteran) #' -#' veteran_data <- veteran[, -c(3,4)] +#' veteran_data <- veteran[, -c(3, 4)] #' veteran_y <- Surv(veteran$time, veteran$status) #' risk_pred <- function(model, newdata) predict(model, newdata, type = "risk") #' surv_pred <- function(model, newdata, times) pec::predictSurvProb(model, newdata, times) #' chf_pred <- function(model, newdata, times) -log(surv_pred(model, newdata, times)) #' -#' manual_cph_explainer <- explain_survival(model = cph, -#' data = veteran_data, -#' y = veteran_y, -#' predict_function = risk_pred, -#' predict_survival_function = surv_pred, -#' predict_cumulative_hazard_function = chf_pred, -#' label = "manual coxph") +#' manual_cph_explainer <- explain_survival( +#' model = cph, +#' data = veteran_data, +#' y = veteran_y, +#' predict_function = risk_pred, +#' predict_survival_function = surv_pred, +#' predict_cumulative_hazard_function = chf_pred, +#' label = "manual coxph" +#' ) #' } #' #' @import survival @@ -108,19 +116,19 @@ explain_survival <- ..., label = NULL, verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), + colorize = !isTRUE(getOption("knitr.in.progress")), model_info = NULL, type = NULL, - times = NULL, times_generation = "quantiles", predict_survival_function = NULL, - predict_cumulative_hazard_function = NULL) - { + predict_cumulative_hazard_function = NULL) { if (!colorize) { - color_codes <- list(yellow_start = "", yellow_end = "", - red_start = "", red_end = "", - green_start = "", green_end = "") + color_codes <- list( + yellow_start = "", yellow_end = "", + red_start = "", red_end = "", + green_start = "", green_end = "" + ) } @@ -145,7 +153,7 @@ explain_survival <- # verbose label if (is.null(label)) { label <- tail(class(model), 1) - verbose_cat(" -> model label : ", label, 
is.default = TRUE, verbose = verbose) + verbose_cat(" -> model label : ", label, is.default = TRUE, verbose = verbose) } else { if (!is.character(label)) { label <- substr(as.character(label), 1, 15) @@ -198,7 +206,7 @@ explain_survival <- } else { n_events <- sum(y[, 2]) n_censored <- length(y) - n_events - frac_censored <- round(n_censored/n, 3) + frac_censored <- round(n_censored / n, 3) if (!is.null(attr(y, "verbose_info")) && attr(y, "verbose_info") == "extracted") { verbose_cat(" -> target variable : ", length(y), " values (", n_events, "events and", n_censored, "censored , censoring rate =", frac_censored, ")", "(", color_codes$yellow_start, "extracted from the model", color_codes$yellow_end, ")", verbose = verbose) attr(y, "verbose_info") <- NULL @@ -219,15 +227,15 @@ explain_survival <- if (is.null(times)) { if (!is.null(y)) { switch(times_generation, - "uniform" = { - times <- seq(min(y[, 1]), max(y[, 1]), length.out = 50) - method_description <- "50 uniformly distributed time points from min to max" - }, - "quantiles" = { - times <- quantile(y[, 1], seq(0, 0.99, 0.02)) - method_description <- "50 time points being consecutive quantiles (0.00, 0.02, ..., 0.98)" - }, - stop("times_generation needs to be 'uniform' or 'quantiles'") + "uniform" = { + times <- seq(min(y[, 1]), max(y[, 1]), length.out = 50) + method_description <- "50 uniformly distributed time points from min to max" + }, + "quantiles" = { + times <- quantile(y[, 1], seq(0, 0.99, 0.02)) + method_description <- "50 time points being consecutive quantiles (0.00, 0.02, ..., 0.98)" + }, + stop("times_generation needs to be 'uniform' or 'quantiles'") ) times <- sort(unique(times)) times_stats <- get_times_stats(times) @@ -247,7 +255,7 @@ explain_survival <- if (is.null(predict_function)) { if (!is.null(predict_cumulative_hazard_function)) { predict_function <- function(model, newdata) risk_from_chf(predict_cumulative_hazard_function(model, newdata, times = times)) - verbose_cat(" -> predict 
function : ", "sum over the predict_cumulative_hazard_function will be used", is.default = TRUE, verbose = verbose) + verbose_cat(" -> predict function : ", "sum over the predict_cumulative_hazard_function will be used", is.default = TRUE, verbose = verbose) } else { verbose_cat(" -> predict function : not specified! (", color_codes$red_start, "WARNING", color_codes$red_end, ")", verbose = verbose) warning("Prediction function not specified") @@ -285,7 +293,7 @@ explain_survival <- if (!"function" %in% class(predict_survival_function)) { verbose_cat(" -> predict survival function : 'predict_survival_function' is not a 'function' class object! (", color_codes$red_start, "WARNING", color_codes$red_end, ")", verbose = verbose) warning("Survival function not available") - } + } } # verbose predict cumulative hazard function @@ -303,7 +311,7 @@ explain_survival <- if (!"function" %in% class(predict_cumulative_hazard_function)) { verbose_cat(" -> predict cumulative hazard function : 'predict_cumulative_hazard_function' is not a 'function' class object! 
(", color_codes$red_start, "WARNING", color_codes$red_end, ")", verbose = verbose) warning("'predict_cumulative_hazard_function' is not a 'function' class object") - } + } } # verbose model info @@ -337,7 +345,6 @@ explain_survival <- colorize = colorize, model_info = model_info, type = type, - times = times, predict_survival_function = predict_survival_function, predict_cumulative_hazard_function = predict_cumulative_hazard_function, @@ -349,42 +356,41 @@ explain_survival <- # verbose end - everything went OK verbose_cat("", color_codes$green_start, "A new explainer has been created!", color_codes$green_end, verbose = verbose) explainer - } #' @rdname explain_survival #' @export -explain <- - function(model, - data = NULL, - y = NULL, - predict_function = NULL, - predict_function_target_column = NULL, - residual_function = NULL, - weights = NULL, - ..., - label = NULL, - verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = NULL, - type = NULL) -UseMethod("explain", model) +explain <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL) { + UseMethod("explain", model) +} #' @rdname explain_survival #' @export -explain.default <- function(model, - data = NULL, - y = NULL, - predict_function = NULL, - predict_function_target_column = NULL, - residual_function = NULL, - weights = NULL, - ..., - label = NULL, - verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = NULL, - type = NULL) { +explain.default <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + 
type = NULL) { supported_models <- c("aalen", "riskRegression", "cox.aalen", "cph", "coxph", "selectCox", "pecCforest", "prodlim", "psm", "survfit", "pecRpart") if (inherits(model, supported_models)) { return( @@ -407,451 +413,444 @@ explain.default <- function(model, ) } else { DALEX::explain(model, - data = data, - y = y, - predict_function = predict_function, - predict_function_target_column = predict_function_target_column, - residual_function = residual_function, - weights = weights, - ..., - label = label, - verbose = verbose, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = model_info, - type = type) - } -} - -#' @export -explain.coxph <- - function(model, - data = NULL, - y = NULL, - predict_function = NULL, - predict_function_target_column = NULL, - residual_function = NULL, - weights = NULL, - ..., - label = NULL, - verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = NULL, - type = NULL, - - times = NULL, - times_generation = "quantiles", - predict_survival_function = NULL, - predict_cumulative_hazard_function = NULL) { - - if (is.null(data)) { - data <- model$model[, attr(model$terms, "term.labels")] - if (is.null(data)) { - stop( - "use `model=TRUE` and `x=TRUE` while creating coxph model or provide `data` manually" - ) - } - attr(data, "verbose_info") <- "extracted" - } - - if (is.null(y)) { - y <- model$y - if (is.null(y)) { - stop("use `y=TRUE` while creating coxph model or provide `y` manually") - } - attr(y, "verbose_info") <- "extracted" - } - - if (is.null(predict_survival_function)) { - predict_survival_function <- function(model, newdata, times) { - pec::predictSurvProb(model, newdata, times) - } - attr(predict_survival_function, "verbose_info") <- "predictSurvProb.coxph will be used" - attr(predict_survival_function, "is.default") <- TRUE - } else { - attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) - } - - if 
(is.null(predict_cumulative_hazard_function)) { - predict_cumulative_hazard_function <- - function(model, newdata, times) { - survival_to_cumulative_hazard(predict_survival_function(model, newdata, times)) - } - attr(predict_cumulative_hazard_function, "verbose_info") <- "-log(predict_survival_function) will be used" - attr(predict_cumulative_hazard_function, "is.default") <- TRUE - } else { - attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) - } - - if (is.null(predict_function)) { - predict_function <- function(model, newdata) { - predict(model, newdata, type = "risk") - } - attr(predict_function, "verbose_info") <- "predict.coxph with type = 'risk' will be used" - attr(predict_function, "is.default") <- TRUE - } else { - attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) - } - - explain_survival( - model, data = data, y = y, predict_function = predict_function, predict_function_target_column = predict_function_target_column, residual_function = residual_function, weights = weights, - ..., + ... 
= ..., label = label, verbose = verbose, - colorize = colorize, + colorize = !isTRUE(getOption("knitr.in.progress")), model_info = model_info, - type = type, - times = times, - times_generation = times_generation, - predict_survival_function = predict_survival_function, - predict_cumulative_hazard_function = predict_cumulative_hazard_function + type = type ) } - +} #' @export -explain.ranger <- - function(model, - data = NULL, - y = NULL, - predict_function = NULL, - predict_function_target_column = NULL, - residual_function = NULL, - weights = NULL, - ..., - label = NULL, - verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = NULL, - type = NULL, - - times = NULL, - times_generation = "quantiles", - predict_survival_function = NULL, - predict_cumulative_hazard_function = NULL) { +explain.coxph <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL, + times = NULL, + times_generation = "quantiles", + predict_survival_function = NULL, + predict_cumulative_hazard_function = NULL) { + if (is.null(data)) { + data <- model$model[, attr(model$terms, "term.labels")] + if (is.null(data)) { + stop( + "use `model=TRUE` and `x=TRUE` while creating coxph model or provide `data` manually" + ) + } + attr(data, "verbose_info") <- "extracted" + } - if (is.null(predict_survival_function)) { - predict_survival_function <- transform_to_stepfunction(predict, - type = "survival", - times_element = "unique.death.times", - prediction_element = "survival") - attr(predict_survival_function, "verbose_info") <- "stepfun based on predict.ranger()$survival will be used" - attr(predict_survival_function, "is.default") <- TRUE - } else { - attr(predict_survival_function, "verbose_info") <- 
deparse(substitute(predict_survival_function)) + if (is.null(y)) { + y <- model$y + if (is.null(y)) { + stop("use `y=TRUE` while creating coxph model or provide `y` manually") } + attr(y, "verbose_info") <- "extracted" + } - if (is.null(predict_cumulative_hazard_function)) { - predict_cumulative_hazard_function <- transform_to_stepfunction(predict, - type = "chf", - times_element = "unique.death.times", - prediction_element = "chf") - attr(predict_cumulative_hazard_function, "verbose_info") <- "stepfun based on predict.ranger()$chf will be used" - attr(predict_cumulative_hazard_function, "is.default") <- TRUE - } else { - attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + if (is.null(predict_survival_function)) { + predict_survival_function <- function(model, newdata, times) { + pec::predictSurvProb(model, newdata, times) } + attr(predict_survival_function, "verbose_info") <- "predictSurvProb.coxph will be used" + attr(predict_survival_function, "is.default") <- TRUE + } else { + attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) + } - if (is.null(predict_function)) { - predict_function <- function(model, newdata, times) { - rowSums(predict_cumulative_hazard_function(model, newdata, times)) + if (is.null(predict_cumulative_hazard_function)) { + predict_cumulative_hazard_function <- + function(model, newdata, times) { + survival_to_cumulative_hazard(predict_survival_function(model, newdata, times)) } - attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" - attr(predict_function, "is.default") <- TRUE - attr(predict_function, "use.times") <- TRUE - } else { - attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) - } + attr(predict_cumulative_hazard_function, "verbose_info") <- "-log(predict_survival_function) will be used" + attr(predict_cumulative_hazard_function, "is.default") 
<- TRUE + } else { + attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + } - explain_survival( - model, - data = data, - y = y, - predict_function = predict_function, - predict_function_target_column = predict_function_target_column, - residual_function = residual_function, - weights = weights, - ..., - label = label, - verbose = verbose, - colorize = colorize, - model_info = model_info, - type = type, - times = times, - times_generation = times_generation, - predict_survival_function = predict_survival_function, - predict_cumulative_hazard_function = predict_cumulative_hazard_function - ) + if (is.null(predict_function)) { + predict_function <- function(model, newdata) { + predict(model, newdata, type = "risk") + } + attr(predict_function, "verbose_info") <- "predict.coxph with type = 'risk' will be used" + attr(predict_function, "is.default") <- TRUE + } else { + attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) } + explain_survival( + model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... 
= ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type, + times = times, + times_generation = times_generation, + predict_survival_function = predict_survival_function, + predict_cumulative_hazard_function = predict_cumulative_hazard_function + ) +} + #' @export -explain.rfsrc <- - function(model, - data = NULL, - y = NULL, - predict_function = NULL, - predict_function_target_column = NULL, - residual_function = NULL, - weights = NULL, - ..., - label = NULL, - verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = NULL, - type = NULL, +explain.ranger <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL, + times = NULL, + times_generation = "quantiles", + predict_survival_function = NULL, + predict_cumulative_hazard_function = NULL) { + if (is.null(predict_survival_function)) { + predict_survival_function <- transform_to_stepfunction(predict, + type = "survival", + times_element = "unique.death.times", + prediction_element = "survival" + ) + attr(predict_survival_function, "verbose_info") <- "stepfun based on predict.ranger()$survival will be used" + attr(predict_survival_function, "is.default") <- TRUE + } else { + attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) + } - times = NULL, - times_generation = "quantiles", - predict_survival_function = NULL, - predict_cumulative_hazard_function = NULL) { - if (is.null(label)) { - label <- class(model)[1] - attr(label, "verbose_info") <- "default" - } + if (is.null(predict_cumulative_hazard_function)) { + predict_cumulative_hazard_function <- transform_to_stepfunction(predict, + type = "chf", + times_element = "unique.death.times", + prediction_element = "chf" 
+ ) + attr(predict_cumulative_hazard_function, "verbose_info") <- "stepfun based on predict.ranger()$chf will be used" + attr(predict_cumulative_hazard_function, "is.default") <- TRUE + } else { + attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + } - if (is.null(data)) { - data <- model$xvar - attr(data, "verbose_info") <- "extracted" + if (is.null(predict_function)) { + predict_function <- function(model, newdata, times) { + rowSums(predict_cumulative_hazard_function(model, newdata, times)) } + attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" + attr(predict_function, "is.default") <- TRUE + attr(predict_function, "use.times") <- TRUE + } else { + attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) + } - if (is.null(y)) { - tmp_y <- model$yvar - y <- survival::Surv(tmp_y[, 1], tmp_y[, 2]) - attr(y, "verbose_info") <- "extracted" - } + explain_survival( + model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... 
= ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type, + times = times, + times_generation = times_generation, + predict_survival_function = predict_survival_function, + predict_cumulative_hazard_function = predict_cumulative_hazard_function + ) +} - if (is.null(predict_survival_function)) { - predict_survival_function <- transform_to_stepfunction(predict, - type = "survival", - times_element = "time.interest", - prediction_element = "survival") - attr(predict_survival_function, "verbose_info") <- "stepfun based on predict.rfsrc()$survival will be used" - attr(predict_survival_function, "is.default") <- TRUE - } else { - attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) - } - if (is.null(predict_cumulative_hazard_function)) { - predict_cumulative_hazard_function <- transform_to_stepfunction(predict, - type = "chf", - times_element = "time.interest", - prediction_element = "chf") - attr(predict_cumulative_hazard_function, "verbose_info") <- "stepfun based on predict.rfsrc()$chf will be used" - attr(predict_cumulative_hazard_function, "is.default") <- TRUE - } else { - attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) - } +#' @export +explain.rfsrc <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL, + times = NULL, + times_generation = "quantiles", + predict_survival_function = NULL, + predict_cumulative_hazard_function = NULL) { + if (is.null(label)) { + label <- class(model)[1] + attr(label, "verbose_info") <- "default" + } - if (is.null(predict_function)) { - predict_function <- function(model, newdata, times) { - rowSums(predict_cumulative_hazard_function(model, 
newdata, times = times)) - } - attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" - attr(predict_function, "is.default") <- TRUE - attr(predict_function, "use.times") <- TRUE - } else { - attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) - } + if (is.null(data)) { + data <- model$xvar + attr(data, "verbose_info") <- "extracted" + } - explain_survival( - model, - data = data, - y = y, - predict_function = predict_function, - predict_function_target_column = predict_function_target_column, - residual_function = residual_function, - weights = weights, - ..., - label = label, - verbose = verbose, - colorize = colorize, - model_info = model_info, - type = type, - times = times, - times_generation = times_generation, - predict_survival_function = predict_survival_function, - predict_cumulative_hazard_function = predict_cumulative_hazard_function - ) + if (is.null(y)) { + tmp_y <- model$yvar + y <- survival::Surv(tmp_y[, 1], tmp_y[, 2]) + attr(y, "verbose_info") <- "extracted" } + if (is.null(predict_survival_function)) { + predict_survival_function <- transform_to_stepfunction(predict, + type = "survival", + times_element = "time.interest", + prediction_element = "survival" + ) + attr(predict_survival_function, "verbose_info") <- "stepfun based on predict.rfsrc()$survival will be used" + attr(predict_survival_function, "is.default") <- TRUE + } else { + attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) + } -#' @export -explain.model_fit <- - function(model, - data = NULL, - y = NULL, - predict_function = NULL, - predict_function_target_column = NULL, - residual_function = NULL, - weights = NULL, - ..., - label = NULL, - verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = NULL, - type = NULL, + if (is.null(predict_cumulative_hazard_function)) { + predict_cumulative_hazard_function <- 
transform_to_stepfunction(predict, + type = "chf", + times_element = "time.interest", + prediction_element = "chf" + ) + attr(predict_cumulative_hazard_function, "verbose_info") <- "stepfun based on predict.rfsrc()$chf will be used" + attr(predict_cumulative_hazard_function, "is.default") <- TRUE + } else { + attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + } - times = NULL, - times_generation = "quantiles", - predict_survival_function = NULL, - predict_cumulative_hazard_function = NULL) { - if (is.null(label)) { - label <- paste(rev(class(model)), collapse = "") - attr(label, "verbose_info") <- "default" + if (is.null(predict_function)) { + predict_function <- function(model, newdata, times) { + rowSums(predict_cumulative_hazard_function(model, newdata, times = times)) } + attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" + attr(predict_function, "is.default") <- TRUE + attr(predict_function, "use.times") <- TRUE + } else { + attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) + } - if (is.null(predict_survival_function)) { - predict_survival_function <- function(model, newdata, times) { prediction <- predict(model, new_data = newdata, type = "survival", eval_time = times )$.pred - return_matrix <- t(sapply(prediction, function(x) x$.pred_survival)) - return_matrix[is.na(return_matrix)] <- 0 - return_matrix - } - attr(predict_survival_function, "verbose_info") <- "predict.model_fit with type = 'survival' will be used" - attr(predict_survival_function, "is.default") <- TRUE - } else { - attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) - } + explain_survival( + model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... 
= ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type, + times = times, + times_generation = times_generation, + predict_survival_function = predict_survival_function, + predict_cumulative_hazard_function = predict_cumulative_hazard_function + ) +} - if (is.null(predict_cumulative_hazard_function)) { - predict_cumulative_hazard_function <- - function(object, newdata, times) survival_to_cumulative_hazard(predict_survival_function(object, newdata, times)) - attr(predict_cumulative_hazard_function, "verbose_info") <- "-log(predict_survival_function) will be used" - attr(predict_cumulative_hazard_function, "is.default") <- TRUE - } else { - attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + +#' @export +explain.model_fit <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL, + times = NULL, + times_generation = "quantiles", + predict_survival_function = NULL, + predict_cumulative_hazard_function = NULL) { + if (is.null(label)) { + label <- paste(rev(class(model)), collapse = "") + attr(label, "verbose_info") <- "default" + } + + if (is.null(predict_survival_function)) { + predict_survival_function <- function(model, newdata, times) { + prediction <- predict(model, new_data = newdata, type = "survival", eval_time = times)$.pred + return_matrix <- t(sapply(prediction, function(x) x$.pred_survival)) + return_matrix[is.na(return_matrix)] <- 0 + return_matrix } + attr(predict_survival_function, "verbose_info") <- "predict.model_fit with type = 'survival' will be used" + attr(predict_survival_function, "is.default") <- TRUE + } else { + attr(predict_survival_function, "verbose_info") <- 
deparse(substitute(predict_survival_function)) + } - if (is.null(predict_function)) { - if (model$spec$engine %in% c("mboost", "survival", "glmnet", "flexsurv", "flexsurvspline")){ - predict_function <- function(model, newdata, times) predict(model, new_data = newdata, type = "linear_pred")$.pred_linear_pred - attr(predict_function, "verbose_info") <- "predict.model_fit with type = 'linear_pred' will be used" - } else { - predict_function <- function(model, newdata, times) rowSums(predict_cumulative_hazard_function(model, newdata, times = times)) - attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" - } - attr(predict_function, "use.times") <- TRUE - attr(predict_function, "is.default") <- TRUE + if (is.null(predict_cumulative_hazard_function)) { + predict_cumulative_hazard_function <- + function(object, newdata, times) survival_to_cumulative_hazard(predict_survival_function(object, newdata, times)) + attr(predict_cumulative_hazard_function, "verbose_info") <- "-log(predict_survival_function) will be used" + attr(predict_cumulative_hazard_function, "is.default") <- TRUE + } else { + attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + } + + if (is.null(predict_function)) { + if (model$spec$engine %in% c("mboost", "survival", "glmnet", "flexsurv", "flexsurvspline")) { + predict_function <- function(model, newdata, times) predict(model, new_data = newdata, type = "linear_pred")$.pred_linear_pred + attr(predict_function, "verbose_info") <- "predict.model_fit with type = 'linear_pred' will be used" } else { - attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) + predict_function <- function(model, newdata, times) rowSums(predict_cumulative_hazard_function(model, newdata, times = times)) + attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" } - - explain_survival( - model, - 
data = data, - y = y, - predict_function = predict_function, - predict_function_target_column = predict_function_target_column, - residual_function = residual_function, - weights = weights, - ..., - label = label, - verbose = verbose, - colorize = colorize, - model_info = model_info, - type = type, - times = times, - times_generation = times_generation, - predict_survival_function = predict_survival_function, - predict_cumulative_hazard_function = predict_cumulative_hazard_function - ) + attr(predict_function, "use.times") <- TRUE + attr(predict_function, "is.default") <- TRUE + } else { + attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) } + explain_survival( + model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... = ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type, + times = times, + times_generation = times_generation, + predict_survival_function = predict_survival_function, + predict_cumulative_hazard_function = predict_cumulative_hazard_function + ) +} -#' @export -explain.LearnerSurv <- - function(model, - data = NULL, - y = NULL, - predict_function = NULL, - predict_function_target_column = NULL, - residual_function = NULL, - weights = NULL, - ..., - label = NULL, - verbose = TRUE, - colorize = !isTRUE(getOption('knitr.in.progress')), - model_info = NULL, - type = NULL, - times = NULL, - times_generation = "quantiles", - predict_survival_function = NULL, - predict_cumulative_hazard_function = NULL) { +#' @export +explain.LearnerSurv <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL, + 
times = NULL, + times_generation = "quantiles", + predict_survival_function = NULL, + predict_cumulative_hazard_function = NULL) { + if (is.null(label)) { + label <- class(model)[1] + attr(label, "verbose_info") <- "default" + } - if (is.null(label)) { - label <- class(model)[1] - attr(label, "verbose_info") <- "default" + if (is.null(predict_survival_function)) { + if ("distr" %in% model$predict_types) { + predict_survival_function <- function(model, newdata, times) t(model$predict_newdata(newdata)$distr$survival(times)) + attr(predict_survival_function, "verbose_info") <- "predict_newdata()$distr$survival will be used" + attr(predict_survival_function, "is.default") <- TRUE } + } else { + attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) + } - if (is.null(predict_survival_function)) { - if ("distr" %in% model$predict_types) { - predict_survival_function <- function(model, newdata, times) t(model$predict_newdata(newdata)$distr$survival(times)) - attr(predict_survival_function, "verbose_info") <- "predict_newdata()$distr$survival will be used" - attr(predict_survival_function, "is.default") <- TRUE - } - } else { - attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) + if (is.null(predict_cumulative_hazard_function)) { + if ("distr" %in% model$predict_types) { + predict_cumulative_hazard_function <- function(model, newdata, times) t(model$predict_newdata(newdata)$distr$cumHazard(times)) + attr(predict_cumulative_hazard_function, "verbose_info") <- "predict_newdata()$distr$cumHazard will be used" + attr(predict_cumulative_hazard_function, "is.default") <- TRUE } + } else { + attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + } - if (is.null(predict_cumulative_hazard_function)) { - if ("distr" %in% model$predict_types) { - predict_cumulative_hazard_function <- function(model, newdata, times) 
t(model$predict_newdata(newdata)$distr$cumHazard(times)) - attr(predict_cumulative_hazard_function, "verbose_info") <- "predict_newdata()$distr$cumHazard will be used" - attr(predict_cumulative_hazard_function, "is.default") <- TRUE - } + if (is.null(predict_function)) { + if ("crank" %in% model$predict_types) { + predict_function <- function(model, newdata, times) model$predict_newdata(newdata)$crank + attr(predict_function, "verbose_info") <- "predict_newdata()$crank will be used" } else { - attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) - } - - if (is.null(predict_function)) { - if ("crank" %in% model$predict_types) { - predict_function <- function(model, newdata, times) model$predict_newdata(newdata)$crank - attr(predict_function, "verbose_info") <- "predict_newdata()$crank will be used" - } else { - predict_function <- function(model, newdata, times) { - rowSums(predict_cumulative_hazard_function(model, newdata, times)) - } - attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" + predict_function <- function(model, newdata, times) { + rowSums(predict_cumulative_hazard_function(model, newdata, times)) } - attr(predict_function, "is.default") <- TRUE - attr(predict_function, "use.times") <- TRUE - } else { - attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) + attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" } - - explain_survival( - model, - data = data, - y = y, - predict_function = predict_function, - predict_function_target_column = predict_function_target_column, - residual_function = residual_function, - weights = weights, - ..., - label = label, - verbose = verbose, - colorize = colorize, - model_info = model_info, - type = type, - times = times, - times_generation = times_generation, - predict_survival_function = predict_survival_function, - 
predict_cumulative_hazard_function = predict_cumulative_hazard_function - ) + attr(predict_function, "is.default") <- TRUE + attr(predict_function, "use.times") <- TRUE + } else { + attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) } + explain_survival( + model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... = ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type, + times = times, + times_generation = times_generation, + predict_survival_function = predict_survival_function, + predict_cumulative_hazard_function = predict_cumulative_hazard_function + ) +} + verbose_cat <- function(..., is.default = NULL, verbose = TRUE) { if (verbose) { if (!is.null(is.default)) { - txt <- paste(..., "(", color_codes$yellow_start, "default", color_codes$yellow_end, ")") + txt <- paste(..., "(", color_codes$yellow_start, "default", color_codes$yellow_end, ")") cat(txt, "\n") } else { cat(..., "\n") @@ -866,6 +865,8 @@ get_times_stats <- function(times) { # # colors for WARNING, NOTE, DEFAULT # -color_codes <- list(yellow_start = "\033[33m", yellow_end = "\033[39m", - red_start = "\033[31m", red_end = "\033[39m", - green_start = "\033[32m", green_end = "\033[39m") +color_codes <- list( + yellow_start = "\033[33m", yellow_end = "\033[39m", + red_start = "\033[31m", red_end = "\033[39m", + green_start = "\033[32m", green_end = "\033[39m" +) diff --git a/R/metrics.R b/R/metrics.R index bf9bdd82..e9687f9b 100644 --- a/R/metrics.R +++ b/R/metrics.R @@ -15,17 +15,15 @@ utils::globalVariables(c("PredictionSurv")) #' - \[1\] Graf, Erika, et al. 
["Assessment and comparison of prognostic classification schemes for survival data."](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5) Statistics in Medicine 18.17‐18 (1999): 2529-2545. #' #' @export -loss_integrate <- function(loss_function, ..., normalization = NULL , max_quantile = 1){ - - if (!is.null(normalization)){ +loss_integrate <- function(loss_function, ..., normalization = NULL, max_quantile = 1) { + if (!is.null(normalization)) { if (!normalization %in% c("t_max", "survival")) stop("normalization should be either NULL, `t_max` or `survival`") } - integrated_loss_function <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL){ - - quantile_mask <- (times <= quantile(y_true[,1],max_quantile)) + integrated_loss_function <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { + quantile_mask <- (times <= quantile(y_true[, 1], max_quantile)) times <- times[quantile_mask] - surv <- surv[,quantile_mask] + surv <- surv[, quantile_mask] loss_values <- loss_function(y_true = y_true, risk = risk, surv = surv, times = times) @@ -35,7 +33,7 @@ loss_integrate <- function(loss_function, ..., normalization = NULL , max_quanti loss_values <- loss_values[na_mask] surv <- surv[na_mask] - calculate_integral(loss_values, times, normalization, y_true=y_true) + calculate_integral(loss_values, times, normalization, y_true = y_true) } attr(integrated_loss_function, "loss_type") <- "integrated" @@ -73,7 +71,8 @@ loss_integrate <- function(loss_function, ..., normalization = NULL , max_quanti #' rotterdam$year <- NULL #' cox_rotterdam_rec <- coxph(Surv(rtime, recur) ~ ., #' data = rotterdam, -#' model = TRUE, x = TRUE, y = TRUE) +#' model = TRUE, x = TRUE, y = TRUE +#' ) #' coxph_explainer <- explain(cox_rotterdam_rec) #' #' risk <- coxph_explainer$predict_function(coxph_explainer$model, coxph_explainer$data) @@ -82,7 +81,6 @@ loss_integrate <- function(loss_function, 
..., normalization = NULL , max_quanti #' #' @export c_index <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { - n_rows <- length(y_true[, 1]) yi <- matrix(rep(y_true[, 1], n_rows), ncol = n_rows) @@ -97,7 +95,6 @@ c_index <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { bot <- sum(ifelse(yj < yi, 1, 0) * dj) top / bot - } attr(c_index, "loss_name") <- "C-index" attr(c_index, "loss_type") <- "risk-based" @@ -126,7 +123,8 @@ attr(c_index, "loss_type") <- "risk-based" #' rotterdam$year <- NULL #' cox_rotterdam_rec <- coxph(Surv(rtime, recur) ~ ., #' data = rotterdam, -#' model = TRUE, x = TRUE, y = TRUE) +#' model = TRUE, x = TRUE, y = TRUE +#' ) #' coxph_explainer <- explain(cox_rotterdam_rec) #' #' risk <- coxph_explainer$predict_function(coxph_explainer$model, coxph_explainer$data) @@ -177,8 +175,11 @@ attr(loss_one_minus_c_index, "loss_type") <- "risk-based" #' @export brier_score <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { # if times is not provided use - if (is.null(times)) times <- sort(unique(y_true[, 1])) - else times <- sort(unique(times)) + if (is.null(times)) { + times <- sort(unique(y_true[, 1])) + } else { + times <- sort(unique(times)) + } # calculate the inverse probability of censoring weights @@ -195,7 +196,7 @@ brier_score <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { ti <- t(matrix(rep(times, n_rows), ncol = n_rows)) gti <- matrix(G(ti), ncol = n_cols, nrow = n_rows) - gy <- matrix(G(y), ncol = n_cols, nrow = n_rows) + gy <- matrix(G(y), ncol = n_cols, nrow = n_rows) ind_1 <- ifelse(y <= ti & delta == 1, 1, 0) ind_2 <- ifelse(y > ti, 1, 0) @@ -203,7 +204,6 @@ brier_score <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { brier_score <- ind_1 * (surv^2) / gy + ind_2 * ((1 - surv)^2) / gti apply(brier_score, 2, mean, na.rm = TRUE) - } attr(brier_score, "loss_name") <- "Brier score" attr(brier_score, "loss_type") <- "time-dependent" @@ -253,7 
+253,6 @@ attr(loss_brier_score, "loss_type") <- "time-dependent" #' #' @export cd_auc <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { - y_true[, 2] <- 1 - y_true[, 2] km <- survival::survfit(y_true ~ 1) G <- stepfun(km$time, c(1, km$surv)) @@ -275,7 +274,7 @@ cd_auc <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { survi <- matrix(rep(surv[, tt], n_rows), ncol = n_rows) survj <- t(matrix(rep(surv[, tt], n_rows), ncol = n_rows)) - top <- sum(ifelse(yj > time & yi <= time & survj > survi, 1, 0) / G(time)) + top <- sum(ifelse(yj > time & yi <= time & survj > survi, 1, 0) / G(time)) bl <- sum(ifelse(yi[, 1] > time, 1, 0)) br <- sum(ifelse(yi[, 1] <= time, 1, 0) / G(time)) @@ -284,7 +283,6 @@ cd_auc <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { }) unlist(results) - } attr(cd_auc, "loss_name") <- "C/D AUC" attr(cd_auc, "loss_type") <- "time-dependent" @@ -466,10 +464,8 @@ attr(loss_integrated_brier_score, "loss_type") <- "integrated" #' } #' #' @export -loss_adapt_mlr3proba <- function(measure, reverse = FALSE, ...){ - - loss_function <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL){ - +loss_adapt_mlr3proba <- function(measure, reverse = FALSE, ...) 
{ + loss_function <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) { colnames(surv) <- times surv_pred <- PredictionSurv$new( @@ -488,8 +484,11 @@ loss_adapt_mlr3proba <- function(measure, reverse = FALSE, ...){ return(output) } - if (reverse) attr(loss_function, "loss_name") <- paste("one minus", measure$id) - else attr(loss_function, "loss_name") <- measure$id + if (reverse) { + attr(loss_function, "loss_name") <- paste("one minus", measure$id) + } else { + attr(loss_function, "loss_name") <- measure$id + } attr(loss_function, "loss_type") <- "integrated" return(loss_function) diff --git a/R/misc_set_theme_survex.R b/R/misc_set_theme_survex.R index 27b24c93..e4efb9f5 100644 --- a/R/misc_set_theme_survex.R +++ b/R/misc_set_theme_survex.R @@ -19,7 +19,7 @@ #' plot(p_parts_lime) #' old <- set_theme_survex(ggplot2::theme_void(), ggplot2::theme_void()) #' plot(p_parts_lime) -#'} +#' } #' #' @importFrom DALEX theme_drwhy theme_drwhy_vertical theme_ema theme_ema_vertical #' @@ -30,12 +30,14 @@ set_theme_survex <- function(default_theme = "drwhy", default_theme_vertical = d # it should be either name or theme object if (!(any( class(default_theme) %in% c("character", "theme") - ))) + ))) { stop("The 'default_theme' shall be either character 'drwhy'/'ema' or ggplot2::theme object") + } if (!(any( class(default_theme_vertical) %in% c("character", "theme") - ))) + ))) { stop("The 'default_theme_vertical' shall be either character 'drwhy'/'ema' or ggplot2::theme object") + } # get default themes old <- .survex.env$default_themes @@ -43,10 +45,14 @@ set_theme_survex <- function(default_theme = "drwhy", default_theme_vertical = d # set themes if (is.character(default_theme)) { # from name - switch (default_theme, - drwhy = {.survex.env$default_themes <- list(default = theme_drwhy(), vertical = theme_drwhy_vertical())}, - ema = {.survex.env$default_themes <- list(default = theme_ema(), vertical = theme_ema_vertical())}, - stop("Only 'drwhy' or 'ema' names are 
allowed") + switch(default_theme, + drwhy = { + .survex.env$default_themes <- list(default = theme_drwhy(), vertical = theme_drwhy_vertical()) + }, + ema = { + .survex.env$default_themes <- list(default = theme_ema(), vertical = theme_ema_vertical()) + }, + stop("Only 'drwhy' or 'ema' names are allowed") ) } else { # from themes (ggplot2 objects) @@ -61,8 +67,9 @@ set_theme_survex <- function(default_theme = "drwhy", default_theme_vertical = d #' @export #' @rdname theme_survex theme_default_survex <- function() { - if (!exists("default_themes", envir = .survex.env)) + if (!exists("default_themes", envir = .survex.env)) { return(theme_drwhy()) + } .survex.env$default_themes[[1]] } @@ -70,8 +77,9 @@ theme_default_survex <- function() { #' @export #' @rdname theme_survex theme_vertical_default_survex <- function() { - if (!exists("default_themes", envir = .survex.env)) + if (!exists("default_themes", envir = .survex.env)) { return(theme_drwhy_vertical()) + } .survex.env$default_themes[[2]] } diff --git a/R/model_info.R b/R/model_info.R index c065da7d..0bb71b28 100644 --- a/R/model_info.R +++ b/R/model_info.R @@ -24,19 +24,24 @@ #' @examples #' library(survival) #' library(survex) -#' cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, -#' model = TRUE, x = TRUE, y = TRUE) +#' cph <- survival::coxph(survival::Surv(time, status) ~ ., +#' data = veteran, +#' model = TRUE, x = TRUE, y = TRUE +#' ) #' surv_model_info(cph) #' #' \donttest{ #' library(ranger) -#' rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, -#' num.trees = 50, mtry = 3, max.depth = 5) +#' rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., +#' data = veteran, +#' num.trees = 50, mtry = 3, max.depth = 5 +#' ) #' surv_model_info(rsf_ranger) #' } #' -surv_model_info <- function(model, ...) +surv_model_info <- function(model, ...) 
{ UseMethod("surv_model_info") +} #' @rdname surv_model_info diff --git a/R/model_parts.R b/R/model_parts.R index 57145e7c..40422801 100644 --- a/R/model_parts.R +++ b/R/model_parts.R @@ -65,41 +65,41 @@ model_parts.surv_explainer <- function(explainer, if (type == "variable_importance") type <- "raw" # it's an alias switch(output_type, - "risk" = DALEX::model_parts( - explainer = explainer, - loss_function = loss_function, - ... = ..., - type = type, - N = N - ), - "survival" = { - test_explainer(explainer, has_data = TRUE, has_y = TRUE, has_survival = TRUE, function_name = "model_parts") + "risk" = DALEX::model_parts( + explainer = explainer, + loss_function = loss_function, + ... = ..., + type = type, + N = N + ), + "survival" = { + test_explainer(explainer, has_data = TRUE, has_y = TRUE, has_survival = TRUE, function_name = "model_parts") - if (attr(loss_function, "loss_type") == "integrated") { - res <- surv_integrated_feature_importance( - x = explainer, - loss_function = loss_function, - type = type, - N = N, - ... - ) - class(res) <- c("model_parts_survival", class(res)) - return(res) - } else { - res <- surv_feature_importance( - x = explainer, - loss_function = loss_function, - type = type, - N = N, - ... - ) - res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] - res$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] - class(res) <- c("model_parts_survival", class(res)) - res - } - }, - stop("Type should be either `survival` or `risk`") + if (attr(loss_function, "loss_type") == "integrated") { + res <- surv_integrated_feature_importance( + x = explainer, + loss_function = loss_function, + type = type, + N = N, + ... + ) + class(res) <- c("model_parts_survival", class(res)) + return(res) + } else { + res <- surv_feature_importance( + x = explainer, + loss_function = loss_function, + type = type, + N = N, + ... 
+ ) + res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] + res$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] + class(res) <- c("model_parts_survival", class(res)) + res + } + }, + stop("Type should be either `survival` or `risk`") ) } diff --git a/R/model_performance.R b/R/model_performance.R index 4b6e568e..dd388b67 100644 --- a/R/model_performance.R +++ b/R/model_performance.R @@ -29,18 +29,22 @@ #' #' cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) #' rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., -#' data = veteran, -#' respect.unordered.factors = TRUE, -#' num.trees = 100, -#' mtry = 3, -#' max.depth = 5) +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 +#' ) #' #' rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., -#' data = veteran) +#' data = veteran +#' ) #' #' cph_exp <- explain(cph) -#' rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], -#' y = Surv(veteran$time, veteran$status)) +#' rsf_ranger_exp <- explain(rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = Surv(veteran$time, veteran$status) +#' ) #' rsf_src_exp <- explain(rsf_src) #' #' @@ -51,7 +55,9 @@ #' print(cph_model_performance) #' #' plot(rsf_ranger_model_performance, cph_model_performance, -#' rsf_src_model_performance, metrics_type = "scalar") +#' rsf_src_model_performance, +#' metrics_type = "scalar" +#' ) #' #' plot(rsf_ranger_model_performance, cph_model_performance, rsf_src_model_performance) #' @@ -64,11 +70,17 @@ model_performance <- function(explainer, ...) 
UseMethod("model_performance", exp #' @rdname model_performance.surv_explainer #' @export -model_performance.surv_explainer <- function(explainer, ..., type = "metrics", metrics = c("C-index" = c_index, - "Integrated C/D AUC" = integrated_cd_auc, - "Brier score" = brier_score, - "Integrated Brier score" = integrated_brier_score, - "C/D AUC" = cd_auc), times = NULL) { +model_performance.surv_explainer <- function(explainer, + ..., + type = "metrics", + metrics = c( + "C-index" = c_index, + "Integrated C/D AUC" = integrated_cd_auc, + "Brier score" = brier_score, + "Integrated Brier score" = integrated_brier_score, + "C/D AUC" = cd_auc + ), + times = NULL) { test_explainer(explainer, "model_performance", has_data = TRUE, has_y = TRUE, has_survival = TRUE, has_predict = TRUE) res <- surv_model_performance(explainer, ..., type = type, metrics = metrics, times = times) diff --git a/R/model_profile.R b/R/model_profile.R index 05022f4b..f0a536c9 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -32,16 +32,20 @@ #' cph_exp <- explain(cph) #' rsf_src_exp <- explain(rsf_src) #' -#' cph_model_profile <- model_profile(cph_exp, output_type = "survival", -#' variables = c("age")) +#' cph_model_profile <- model_profile(cph_exp, +#' output_type = "survival", +#' variables = c("age") +#' ) #' #' head(cph_model_profile$result) #' #' plot(cph_model_profile) #' -#' rsf_model_profile <- model_profile(rsf_src_exp, output_type = "survival", -#' variables = c("age", "celltype"), -#' type = "accumulated") +#' rsf_model_profile <- model_profile(rsf_src_exp, +#' output_type = "survival", +#' variables = c("age", "celltype"), +#' type = "accumulated" +#' ) #' #' head(rsf_model_profile$result) #' @@ -58,7 +62,9 @@ model_profile <- function(explainer, k = NULL, type = "partial", center = FALSE, - output_type = "survival") UseMethod("model_profile", explainer) + output_type = "survival") { + UseMethod("model_profile", explainer) +} #' @rdname model_profile.surv_explainer #' @export @@ 
-74,62 +80,67 @@ model_profile.surv_explainer <- function(explainer, center = FALSE, type = "partial", output_type = "survival") { - variables <- unique(variables, categorical_variables) switch(output_type, - "risk" = DALEX::model_profile(explainer = explainer, - variables = variables, - N = N, - ... = ..., - groups = groups, - k = k, - center = center, - type = type), - "survival" = { - test_explainer(explainer, "model_profile", has_data = TRUE, has_survival = TRUE) - data <- explainer$data - if (!is.null(N) && N < nrow(data)) { - ndata <- data[sample(1:nrow(data), N), , drop = FALSE] - } else { - ndata <- data[1:nrow(data), , drop = FALSE] - } - - if (type == "partial"){ - cp_profiles <- surv_ceteris_paribus(explainer, - new_observation = ndata, - variables = variables, - categorical_variables = categorical_variables, - grid_points = grid_points, - variable_splits_type = variable_splits_type, - center = center, - ...) + "risk" = DALEX::model_profile( + explainer = explainer, + variables = variables, + N = N, + ... = ..., + groups = groups, + k = k, + center = center, + type = type + ), + "survival" = { + test_explainer(explainer, "model_profile", has_data = TRUE, has_survival = TRUE) + data <- explainer$data + if (!is.null(N) && N < nrow(data)) { + ndata <- data[sample(1:nrow(data), N), , drop = FALSE] + } else { + ndata <- data[1:nrow(data), , drop = FALSE] + } - result <- surv_aggregate_profiles(cp_profiles, ..., - variables = variables) - } else if (type == "accumulated"){ - cp_profiles <- list(variable_values = data.frame(ndata)) - result <- surv_ale(explainer, - data = ndata, - variables = variables, - categorical_variables = categorical_variables, - grid_points = grid_points, - center = center, - ...) 
- } else { - stop("Currently only `partial` and `accumulated` types are implemented") - } + if (type == "partial") { + cp_profiles <- surv_ceteris_paribus(explainer, + new_observation = ndata, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + variable_splits_type = variable_splits_type, + center = center, + ... + ) - ret <- list(eval_times = unique(result$`_times_`), - cp_profiles = cp_profiles, - result = result, - type = type, - center = center) - class(ret) <- c("model_profile_survival", "list") - ret$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] - ret$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] - ret - }, + result <- surv_aggregate_profiles(cp_profiles, ..., + variables = variables + ) + } else if (type == "accumulated") { + cp_profiles <- list(variable_values = data.frame(ndata)) + result <- surv_ale(explainer, + data = ndata, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + center = center, + ... 
+ ) + } else { + stop("Currently only `partial` and `accumulated` types are implemented") + } + ret <- list( + eval_times = unique(result$`_times_`), + cp_profiles = cp_profiles, + result = result, + type = type, + center = center + ) + class(ret) <- c("model_profile_survival", "list") + ret$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] + ret$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] + ret + }, stop("Currently only `risk` and `survival` output types are implemented") ) } diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index f8382f7e..b99a45bd 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -26,13 +26,15 @@ #' cph_exp <- explain(cph) #' #' cph_model_profile_2d <- model_profile_2d(cph_exp, -#' variables = list(c("age", "celltype"))) +#' variables = list(c("age", "celltype")) +#' ) #' head(cph_model_profile_2d$result) #' plot(cph_model_profile_2d) #' #' cph_model_profile_2d_ale <- model_profile_2d(cph_exp, -#' variables = list(c("age", "karno")), -#' type = "accumulated") +#' variables = list(c("age", "karno")), +#' type = "accumulated" +#' ) #' head(cph_model_profile_2d_ale$result) #' plot(cph_model_profile_2d_ale) #' } @@ -47,33 +49,30 @@ model_profile_2d <- function(explainer, center = FALSE, variable_splits_type = "uniform", type = "partial", - output_type = "survival") + output_type = "survival") { UseMethod("model_profile_2d", explainer) +} #' @rdname model_profile_2d.surv_explainer #' @export model_profile_2d.surv_explainer <- function(explainer, - variables = NULL, - N = 100, - categorical_variables = NULL, - grid_points = 25, - center = FALSE, - variable_splits_type = "uniform", - type = "partial", - output_type = "survival" - ) { - - if (is.null(variables) | !is.list(variables) | !all(sapply(variables, length) == 2)) + variables = NULL, + N = 100, + categorical_variables = NULL, + grid_points = 25, + center = FALSE, + variable_splits_type = "uniform", + type = "partial", + 
output_type = "survival") { + if (is.null(variables) || !is.list(variables) || !all(sapply(variables, length) == 2)) { stop("'variables' must be specified as a list of pairs (two-element vectors)") + } if (output_type != "survival") { stop("Currently only `survival` output type is implemented") } - test_explainer(explainer, - "model_profile", - has_data = TRUE, - has_survival = TRUE) + test_explainer(explainer, "model_profile", has_data = TRUE, has_survival = TRUE) data <- explainer$data if (!is.null(N) && N < nrow(data)) { @@ -127,8 +126,7 @@ surv_pdp_2d <- function(x, categorical_variables, grid_points, variable_splits_type, - center - ) { + center) { model <- x$model label <- x$label predict_survival_function <- x$predict_survival_function @@ -136,29 +134,34 @@ surv_pdp_2d <- function(x, unique_variables <- unlist(variables) variable_splits <- calculate_variable_split(data, - variables = unique_variables, - categorical_variables = categorical_variables, - grid_points = grid_points, - variable_splits_type = variable_splits_type) + variables = unique_variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + variable_splits_type = variable_splits_type + ) - profiles <- lapply(variables, FUN = function(variables_pair){ + profiles <- lapply(variables, FUN = function(variables_pair) { var1 <- variables_pair[1] var2 <- variables_pair[2] - expanded_data <- merge(variable_splits[[var1]], data[,!colnames(data) %in% variables_pair]) + expanded_data <- merge(variable_splits[[var1]], data[, !colnames(data) %in% variables_pair]) names(expanded_data)[colnames(expanded_data) == "x"] <- var1 expanded_data <- merge(variable_splits[[var2]], expanded_data) names(expanded_data)[colnames(expanded_data) == "x"] <- var2 - expanded_data <- expanded_data[,colnames(data)] + expanded_data <- expanded_data[, colnames(data)] - predictions_original <- predict_survival_function(model = model, - newdata = data, - times = times) + predictions_original <- 
predict_survival_function( + model = model, + newdata = data, + times = times + ) mean_pred <- colMeans(predictions_original) - predictions <- predict_survival_function(model = model, - newdata = expanded_data, - times = times) + predictions <- predict_survival_function( + model = model, + newdata = expanded_data, + times = times + ) preds <- c(t(predictions)) if (center) { @@ -170,15 +173,15 @@ surv_pdp_2d <- function(x, "_v2name_" = var2, "_v1type_" = ifelse(var1 %in% categorical_variables, "categorical", "numerical"), "_v2type_" = ifelse(var2 %in% categorical_variables, "categorical", "numerical"), - "_v1value_" = as.character(rep(expanded_data[,var1], each=length(times))), - "_v2value_" = as.character(rep(expanded_data[,var2], each=length(times))), + "_v1value_" = as.character(rep(expanded_data[, var1], each = length(times))), + "_v2value_" = as.character(rep(expanded_data[, var2], each = length(times))), "_times_" = rep(times, nrow(expanded_data)), "_yhat_" = preds, "_label_" = label, check.names = FALSE ) - return(aggregate(`_yhat_`~., data = res, FUN=mean)) - }) + return(aggregate(`_yhat_` ~ ., data = res, FUN = mean)) + }) profiles <- do.call(rbind, profiles) profiles @@ -189,24 +192,25 @@ surv_ale_2d <- function(x, variables, categorical_variables, grid_points, - center - ){ + center) { model <- x$model label <- x$label predict_survival_function <- x$predict_survival_function times <- x$times - predictions_original <- predict_survival_function(model = model, - newdata = data, - times = times) + predictions_original <- predict_survival_function( + model = model, + newdata = data, + times = times + ) mean_pred <- colMeans(predictions_original) - profiles <- lapply(variables, FUN = function(variables_pair){ + profiles <- lapply(variables, FUN = function(variables_pair) { var1 <- variables_pair[1] var2 <- variables_pair[2] - if (all(!variables_pair %in% categorical_variables)){ + if (all(!variables_pair %in% categorical_variables)) { surv_ale_2d_num_num( 
model, data, @@ -222,7 +226,6 @@ surv_ale_2d <- function(x, } else { stop("Currently 2D ALE are implemented only for pairs of numerical variables") } - }) profiles <- do.call(rbind, profiles) @@ -239,23 +242,25 @@ surv_ale_2d_num_num <- function(model, var1, var2, mean_pred, - center){ + center) { # Number of quantile points for determined by grid length quantile_vals1 <- as.numeric(quantile(data[, var1], - seq(0.01, 1, length.out = grid_points), - type = 1)) + seq(0.01, 1, length.out = grid_points), + type = 1 + )) quantile_vals2 <- as.numeric(quantile(data[, var2], - seq(0.01, 1, length.out = grid_points), - type = 1)) + seq(0.01, 1, length.out = grid_points), + type = 1 + )) quantile_vec1 <- unique(c(min(data[, var1]), quantile_vals1)) quantile_vec2 <- unique(c(min(data[, var2]), quantile_vals2)) data <- data[(data[, var1] <= max(quantile_vec1)) & - (data[, var1] >= min(quantile_vec1)) & - (data[, var2] <= max(quantile_vec2)) & - (data[, var2] >= min(quantile_vec2)), ] + (data[, var1] >= min(quantile_vec1)) & + (data[, var2] <= max(quantile_vec2)) & + (data[, var2] >= min(quantile_vec2)), ] # Matching instances to the grids of both features interval_index1 <- findInterval(data[, var1], quantile_vec1, left.open = TRUE) @@ -265,14 +270,22 @@ surv_ale_2d_num_num <- function(model, interval_index2[interval_index2 == 0] <- 1 X_low1_low2 <- X_up1_low2 <- X_low1_up2 <- X_up1_up2 <- data - X_low1_low2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1], - quantile_vec2[interval_index2]) - X_up1_low2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1 + 1], - quantile_vec2[interval_index2]) - X_low1_up2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1], - quantile_vec2[interval_index2 + 1]) - X_up1_up2[, c(var1, var2)] <- cbind(quantile_vec1[interval_index1 + 1], - quantile_vec2[interval_index2 + 1]) + X_low1_low2[, c(var1, var2)] <- cbind( + quantile_vec1[interval_index1], + quantile_vec2[interval_index2] + ) + X_up1_low2[, c(var1, var2)] <- cbind( + 
quantile_vec1[interval_index1 + 1], + quantile_vec2[interval_index2] + ) + X_low1_up2[, c(var1, var2)] <- cbind( + quantile_vec1[interval_index1], + quantile_vec2[interval_index2 + 1] + ) + X_up1_up2[, c(var1, var2)] <- cbind( + quantile_vec1[interval_index1 + 1], + quantile_vec2[interval_index2 + 1] + ) y_hat_11 <- predict_survival_function(model = model, newdata = X_low1_low2, times = times) y_hat_21 <- predict_survival_function(model = model, newdata = X_up1_low2, times = times) @@ -289,16 +302,17 @@ surv_ale_2d_num_num <- function(model, yhat = c(t(prediction_deltas)) ) - deltas <- aggregate(`yhat`~., data = deltas, FUN=mean) + deltas <- aggregate(`yhat` ~ ., data = deltas, FUN = mean) interval_grid <- expand.grid( interval1 = c(0, sort(unique(deltas$interval1))), interval2 = c(0, sort(unique(deltas$interval2))), time = times ) deltas <- merge(deltas, - interval_grid, - on = c("interval1", "interval2"), - all.y = TRUE) + interval_grid, + on = c("interval1", "interval2"), + all.y = TRUE + ) deltas$yhat_cumsum <- ave(deltas$yhat, deltas$time, deltas$interval1, FUN = function(x) cumsum(ifelse(is.na(x), 0, x))) deltas$yhat_cumsum <- ave(deltas$yhat_cumsum, deltas$time, deltas$interval2, FUN = function(x) cumsum(ifelse(is.na(x), 0, x))) @@ -316,18 +330,21 @@ surv_ale_2d_num_num <- function(model, # Computing the first-order effect of feature 1 res <- ale res$yhat_diff <- ave(ale$yhat_cumsum, - list(ale$interval2, ale$time), - FUN = function(x) c(x[1], diff(x))) + list(ale$interval2, ale$time), + FUN = function(x) c(x[1], diff(x)) + ) - ale1 <- do.call("rbind", lapply(sort(unique(res$interval1)), function(x){ + ale1 <- do.call("rbind", lapply(sort(unique(res$interval1)), function(x) { counts <- res[res$interval1 == x & res$time == times[1], "count"] - aggregate(yhat_diff~time, data=res[res$interval1==x,], - FUN = function(vals){ - sum(counts[-1] * (vals[-length(vals)] + vals[-1]) /2 / sum(counts[-1])) - }) + aggregate(yhat_diff ~ time, + data = res[res$interval1 == 
x, ], + FUN = function(vals) { + sum(counts[-1] * (vals[-length(vals)] + vals[-1]) / 2 / sum(counts[-1])) + } + ) })) - ale1$interval1 <- rep(sort(unique(res$interval1)), each=length(times)) + ale1$interval1 <- rep(sort(unique(res$interval1)), each = length(times)) ale1$yhat_diff[is.na(ale1$yhat_diff)] <- 0 ale1$ale1 <- ave(ale1$yhat_diff, ale1$time, FUN = cumsum) @@ -335,41 +352,45 @@ surv_ale_2d_num_num <- function(model, # Computing the first-order effect of feature 2 res <- ale res$yhat_diff <- ave(ale$yhat_cumsum, - list(ale$interval1, ale$time), - FUN = function(x) c(x[1], diff(x))) + list(ale$interval1, ale$time), + FUN = function(x) c(x[1], diff(x)) + ) - ale2 <- do.call("rbind", lapply(sort(unique(res$interval2)), function(x){ + ale2 <- do.call("rbind", lapply(sort(unique(res$interval2)), function(x) { counts <- res[res$interval2 == x & res$time == times[1], "count"] - aggregate(yhat_diff~time, data=res[res$interval2==x,], - FUN = function(vals){ - sum(counts[-1] * (vals[-length(vals)] + vals[-1]) /2 / sum(counts[-1])) - }) + aggregate(yhat_diff ~ time, + data = res[res$interval2 == x, ], + FUN = function(vals) { + sum(counts[-1] * (vals[-length(vals)] + vals[-1]) / 2 / sum(counts[-1])) + } + ) })) - ale2$interval2 <- rep(sort(unique(res$interval2)), each=length(times)) + ale2$interval2 <- rep(sort(unique(res$interval2)), each = length(times)) ale2$yhat_diff[is.na(ale2$yhat_diff)] <- 0 ale2$ale2 <- ave(ale2$yhat_diff, ale2$time, FUN = cumsum) fJ0 <- unlist(lapply(times, function(time) { - ale_time <- ale[ale$time == time,] - ale1_time <- ale1[ale1$time == time,] - ale2_time <- ale2[ale2$time == time,] + ale_time <- ale[ale$time == time, ] + ale1_time <- ale1[ale1$time == time, ] + ale2_time <- ale2[ale2$time == time, ] ale_time <- ale_time[c("interval1", "interval2", "yhat_cumsum")] dd <- reshape(ale_time, - idvar = "interval1", - timevar = "interval2", - direction = "wide")[,-1] + idvar = "interval1", + timevar = "interval2", + direction = "wide" + )[, 
-1] rownames(dd) <- unique(ale_time$interval1) colnames(dd) <- unique(ale_time$interval2) dd <- dd - outer(ale1_time$ale1, rep(1, nrow(ale2_time))) - outer(rep(1, nrow(ale1_time)), ale2_time$ale2) sum(cell_counts * (dd[1:(nrow(dd) - 1), 1:(ncol(dd) - 1)] + - dd[1:(nrow(dd) - 1), 2:ncol(dd)] + - dd[2:nrow(dd), 1:(ncol(dd) - 1)] + - dd[2:nrow(dd), 2:ncol(dd)]) / 4, na.rm = TRUE) / sum(cell_counts) + dd[1:(nrow(dd) - 1), 2:ncol(dd)] + + dd[2:nrow(dd), 1:(ncol(dd) - 1)] + + dd[2:nrow(dd), 2:ncol(dd)]) / 4, na.rm = TRUE) / sum(cell_counts) })) fJ0 <- data.frame("fJ0" = fJ0, time = times) @@ -377,9 +398,9 @@ surv_ale_2d_num_num <- function(model, ale <- merge(ale, ale1, by = c("time", "interval1")) ale <- merge(ale, ale2, by = c("time", "interval2")) ale$ale <- ale$yhat_cumsum - ale$ale1 - ale$ale2 - ale$fJ0 - ale <- ale[order(ale$interval1, ale$interval2, ale$time),] + ale <- ale[order(ale$interval1, ale$interval2, ale$time), ] - if (!center){ + if (!center) { ale$ale <- ale$ale + mean_pred } @@ -413,7 +434,6 @@ surv_ale_2d_num_num <- function(model, "_bottom_" = ale$bottom, "_count_" = ifelse(is.na(ale$count), 0, ale$count), "_label_" = label, - check.names = FALSE) + check.names = FALSE + ) } - - diff --git a/R/model_survshap.R b/R/model_survshap.R index 256804eb..9b0581e5 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -16,28 +16,32 @@ #' \donttest{ #' veteran <- survival::veteran #' rsf_ranger <- ranger::ranger( -#' survival::Surv(time, status) ~ ., -#' data = veteran, -#' respect.unordered.factors = TRUE, -#' num.trees = 100, -#' mtry = 3, -#' max.depth = 5 +#' survival::Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 #' ) #' rsf_ranger_exp <- explain( -#' rsf_ranger, -#' data = veteran[, -c(3, 4)], -#' y = survival::Surv(veteran$time, veteran$status), -#' verbose = FALSE +#' rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = survival::Surv(veteran$time, 
veteran$status), +#' verbose = FALSE #' ) #' #' ranger_global_survshap <- model_survshap( -#' explainer = rsf_ranger_exp, -#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), -#' !colnames(veteran) %in% c("time", "status")], -#' y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], -#' veteran$status[c(1:4, 17:20, 110:113, 126:129)]), -#' aggregation_method = "integral", -#' calculation_method = "kernelshap", +#' explainer = rsf_ranger_exp, +#' new_observation = veteran[ +#' c(1:4, 17:20, 110:113, 126:129), +#' !colnames(veteran) %in% c("time", "status") +#' ], +#' y_true = survival::Surv( +#' veteran$time[c(1:4, 17:20, 110:113, 126:129)], +#' veteran$status[c(1:4, 17:20, 110:113, 126:129)] +#' ), +#' aggregation_method = "integral", +#' calculation_method = "kernelshap", #' ) #' plot(ranger_global_survshap) #' plot(ranger_global_survshap, geom = "beeswarm") @@ -46,9 +50,9 @@ #' #' @rdname model_survshap.surv_explainer #' @export -model_survshap <- - function(explainer, ...) +model_survshap <- function(explainer, ...) { UseMethod("model_survshap", explainer) + } #' @rdname model_survshap.surv_explainer #' @export @@ -58,7 +62,6 @@ model_survshap.surv_explainer <- function(explainer, calculation_method = "kernelshap", aggregation_method = "integral", ...) 
{ - stopifnot( "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( !is.null(y_true), @@ -68,7 +71,8 @@ model_survshap.surv_explainer <- function(explainer, is.null(dim(y_true)) && length(y_true) == 2L ), TRUE - )) + ) + ) test_explainer( explainer, @@ -101,5 +105,4 @@ model_survshap.surv_explainer <- function(explainer, shap_values$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] shap_values$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] return(shap_values) - } diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index 362df684..fd192a55 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -25,14 +25,18 @@ #' cph_exp <- explain(cph) #' #' cph_model_profile_2d <- model_profile_2d(cph_exp, -#' variables = list(c("age", "celltype"), -#' c("age", "karno"))) +#' variables = list( +#' c("age", "celltype"), +#' c("age", "karno") +#' ) +#' ) #' head(cph_model_profile_2d$result) #' plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 103) #' #' cph_model_profile_2d_ale <- model_profile_2d(cph_exp, -#' variables = list(c("age", "karno")), -#' type = "accumulated") +#' variables = list(c("age", "karno")), +#' type = "accumulated" +#' ) #' head(cph_model_profile_2d_ale$result) #' plot(cph_model_profile_2d_ale, times = c(8, 103), marginalize_over_time = TRUE) #' } @@ -46,61 +50,65 @@ plot.model_profile_2d_survival <- function(x, facet_ncol = NULL, title = "default", subtitle = "default", - colors = NULL){ - + colors = NULL) { explanations_list <- c(list(x), list(...)) num_models <- length(explanations_list) - if (title == "default"){ - if (x$type == "partial") + if (title == "default") { + if (x$type == "partial") { title <- "2D partial dependence survival profiles" - if (x$type == "accumulated") + } + if (x$type == "accumulated") { title <- "2D accumulated local effects survival profiles" + } } if 
(!is.null(variables)) { variables <- intersect(x$variables, variables) - if (length(variables) == 0) + if (length(variables) == 0) { stop(paste0( "variables do not overlap with ", paste(x$variables, collapse = ", ") )) + } } else { variables <- x$variables } - if (is.null(colors)) + if (is.null(colors)) { colors <- c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3") + } - if (num_models == 1){ + if (num_models == 1) { result <- prepare_model_profile_2d_plots(x, - variables = variables, - times = times, - marginalize_over_time = marginalize_over_time, - facet_ncol = facet_ncol, - title = title, - subtitle = subtitle, - colors = colors - ) + variables = variables, + times = times, + marginalize_over_time = marginalize_over_time, + facet_ncol = facet_ncol, + title = title, + subtitle = subtitle, + colors = colors + ) return(result) } return_list <- list() labels <- list() - for (i in 1:num_models){ + for (i in 1:num_models) { this_title <- unique(explanations_list[[i]]$result$`_label_`) - return_list[[i]] <- prepare_model_profile_2d_plots(explanations_list[[i]], - variables = variables, - times = times, - marginalize_over_time = marginalize_over_time, - facet_ncol = 1, - title = this_title, - subtitle = subtitle, - colors = colors) - labels[[i]] <- c(this_title, rep("", length(variables)-1)) + return_list[[i]] <- prepare_model_profile_2d_plots(explanations_list[[i]], + variables = variables, + times = times, + marginalize_over_time = marginalize_over_time, + facet_ncol = 1, + title = this_title, + subtitle = subtitle, + colors = colors + ) + labels[[i]] <- c(this_title, rep("", length(variables) - 1)) } labels <- unlist(labels) - patchwork::wrap_plots(return_list, nrow = 1, tag_level="keep") + + patchwork::wrap_plots(return_list, nrow = 1, tag_level = "keep") + patchwork::plot_annotation(title, tag_levels = list(labels)) & theme_default_survex() } @@ -112,8 +120,7 @@ prepare_model_profile_2d_plots <- function(x, facet_ncol, title, subtitle, - colors -){ + colors) 
{ if (is.null(times)) { times <- quantile(x$eval_times, p = 0.5, type = 1) warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") @@ -135,70 +142,80 @@ prepare_model_profile_2d_plots <- function(x, all_profiles <- x$result df_time <- all_profiles[all_profiles$`_times_` %in% times, ] df_time$`_times_` <- NULL - if (marginalize_over_time){ - df_time <- aggregate(`_yhat_`~., data=df_time, FUN=mean) + if (marginalize_over_time) { + df_time <- aggregate(`_yhat_` ~ ., data = df_time, FUN = mean) } sf_range <- range(df_time$`_yhat_`) - pl <- lapply(seq_along(variables), function(i){ + pl <- lapply(seq_along(variables), function(i) { variable_pair <- variables[[i]] df <- df_time[df_time$`_v1name_` == variable_pair[1] & - df_time$`_v2name_` == variable_pair[2],] - if (any(df$`_v1type_` == "numerical")) + df_time$`_v2name_` == variable_pair[2], ] + if (any(df$`_v1type_` == "numerical")) { df$`_v1value_` <- as.numeric(as.character(df$`_v1value_`)) - else if (any(df$`_v1type_` == "categorical")) + } else if (any(df$`_v1type_` == "categorical")) { df$`_v1value_` <- as.character(df$`_v1value_`) - if (any(df$`_v2type_` == "numerical")) + } + if (any(df$`_v2type_` == "numerical")) { df$`_v2value_` <- as.numeric(as.character(df$`_v2value_`)) - else if (any(df$`_v2type_` == "categorical")) + } else if (any(df$`_v2type_` == "categorical")) { df$`_v2value_` <- as.character(df$`_v2value_`) + } xlabel <- unique(df$`_v1name_`) ylabel <- unique(df$`_v2name_`) - if (x$type == "partial"){ + if (x$type == "partial") { p <- with(df, { - ggplot(df, - aes(x = `_v1value_`, y = `_v2value_`, fill = `_yhat_`)) + + ggplot( + df, + aes(x = `_v1value_`, y = `_v2value_`, fill = `_yhat_`) + ) + geom_tile() + - scale_fill_gradientn(name = "PDP value", - colors = rev(grDevices::colorRampPalette(colors)(10)), - limits = sf_range) + + scale_fill_gradientn( + name = "PDP value", + colors = 
rev(grDevices::colorRampPalette(colors)(10)), + limits = sf_range + ) + labs(x = xlabel, y = ylabel) + theme(legend.position = "top") + - facet_wrap(~paste(`_v1name_`, `_v2name_`, sep = " : ")) + facet_wrap(~ paste(`_v1name_`, `_v2name_`, sep = " : ")) }) } else { - p <- with(df, { + p <- with(df, { ggplot(df, aes(x = `_v1value_`, y = `_v2value_`, fill = `_yhat_`)) + - geom_rect(aes(ymin = `_bottom_`, ymax = `_top_`, - xmin = `_left_`, xmax = `_right_`)) + - scale_fill_gradientn(name = "ALE value", - colors = rev(grDevices::colorRampPalette(colors)(10)), - limits = sf_range) + + geom_rect(aes( + ymin = `_bottom_`, ymax = `_top_`, + xmin = `_left_`, xmax = `_right_` + )) + + scale_fill_gradientn( + name = "ALE value", + colors = rev(grDevices::colorRampPalette(colors)(10)), + limits = sf_range + ) + labs(x = xlabel, y = ylabel) + theme(legend.position = "top") + - facet_wrap(~paste(`_v1name_`, `_v2name_`, sep = " : ")) - }) + facet_wrap(~ paste(`_v1name_`, `_v2name_`, sep = " : ")) + }) } - if (i != length(variables)) + if (i != length(variables)) { p <- p + guides(fill = "none") + } return(p) }) if (!is.null(subtitle) && subtitle == "default") { labels <- paste0(unique(all_profiles$`_label_`), collapse = ", ") subtitle <- paste0("created for the ", labels, " model") - if (!marginalize_over_time) + if (!marginalize_over_time) { subtitle <- paste0(subtitle, " and t = ", times) + } } patchwork::wrap_plots(pl, ncol = facet_ncol) & - patchwork::plot_annotation(title = title, - subtitle = subtitle) & theme_default_survex() & + patchwork::plot_annotation( + title = title, + subtitle = subtitle + ) & theme_default_survex() & plot_layout(guides = "collect") } - - - - diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index c3121859..2dd0c587 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -60,7 +60,6 @@ plot.model_profile_survival <- function(x, colors = NULL, rug = "all", rug_colors = c("#dd0000", 
"#222222")) { - if (!geom %in% c("time", "variable")) { stop("`geom` must be one of 'time' or 'survival'.") } @@ -69,7 +68,7 @@ plot.model_profile_survival <- function(x, if (x$type == "partial") { title <- "Partial dependence survival profiles" if (geom == "variable") { - title <- "default" + title <- "default" } } if (x$type == "accumulated") { @@ -78,7 +77,6 @@ plot.model_profile_survival <- function(x, } if (geom == "variable") { - pl <- plot2( x = x, variable = variables, @@ -106,7 +104,8 @@ plot.model_profile_survival <- function(x, num_models <- length(explanations_list) if (num_models == 1) { - result <- prepare_model_profile_plots(x, + result <- prepare_model_profile_plots( + x, variables = variables, variable_type = variable_type, facet_ncol = facet_ncol, @@ -124,7 +123,8 @@ plot.model_profile_survival <- function(x, labels <- list() for (i in 1:num_models) { this_title <- unique(explanations_list[[i]]$result$`_label_`) - return_list[[i]] <- prepare_model_profile_plots(explanations_list[[i]], + return_list[[i]] <- prepare_model_profile_plots( + explanations_list[[i]], variables = variables, variable_type = variable_type, facet_ncol = 1, @@ -166,7 +166,7 @@ plot2 <- function(x, title <- "Individual conditional expectation survival profiles" } else if (plot_type == "pdp+ice") { title <- "Partial dependence with individual conditional expectation survival profiles" - } else if (plot_type == "pdp"){ + } else if (plot_type == "pdp") { title <- "Partial dependence survival profiles" } @@ -206,7 +206,7 @@ plot2 <- function(x, single_timepoint <- ((length(times) == 1) || marginalize_over_time) if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") - if (single_timepoint && !marginalize_over_time){ + if (single_timepoint && !marginalize_over_time) { subtitle <- paste0(subtitle, " and time=", times) } } @@ -269,10 +269,13 @@ plot2 <- function(x, ice_df <- aggregate(predictions ~ ., data = ice_df, 
mean) color_scale <- generate_discrete_color_scale(1, colors) } else { - if (is.null(colors) | length(colors) < 3) - color_scale <- c(low = "#9fe5bd", - mid = "#46bac2", - high = "#371ea3") + if (is.null(colors) || length(colors) < 3) { + color_scale <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) + } } if (is_categorical) { @@ -291,7 +294,7 @@ plot2 <- function(x, ) } else { pdp_df[, 1] <- as.numeric(as.character(pdp_df[, 1])) - x_width <- diff(range(pdp_df[,variable])) + x_width <- diff(range(pdp_df[, variable])) pl <- plot_pdp_num( pdp_dt = pdp_df, ice_dt = ice_df, @@ -333,7 +336,7 @@ plot_pdp_num <- function(pdp_dt, if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(alpha = 0.2, mapping = aes(group = id), color = colors_discrete_drwhy(1)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width = 0.01 * x_width)) + ylim(y_floor_ice, y_ceiling_ice) } # PDP + ICE @@ -341,30 +344,32 @@ plot_pdp_num <- function(pdp_dt, ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(mapping = aes(group = id), alpha = 0.2) + geom_line(data = pdp_dt, aes(x = !!feature_name_sym, y = pd), linewidth = 2, color = colors_discrete_drwhy(1)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width = 0.01 * x_width)) + ylim(y_floor_ice, y_ceiling_ice) } # PDP else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + geom_line(color = colors_discrete_drwhy(1)) + - geom_rug(data = 
data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width = 0.01 * x_width)) + ylim(y_floor_pd, y_ceiling_pd) } } else { ## multiple timepoints pdp_dt$time <- as.numeric(as.character(pdp_dt$time)) - if (!is.null(ice_dt)) + if (!is.null(ice_dt)) { ice_dt$time <- as.numeric(as.character(ice_dt$time)) + } if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_sym, y = predictions)) + geom_line(alpha = 0.2, mapping = aes(group = interaction(id, time), color = time)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = position_jitter(width = 0.01 * x_width)) + scale_colour_gradient2( low = colors[1], mid = colors[2], high = colors[3], - midpoint = median(as.numeric(as.character(pdp_dt$time)))) + + midpoint = median(as.numeric(as.character(pdp_dt$time))) + ) + ylim(y_floor_ice, y_ceiling_ice) } # PDP + ICE @@ -373,24 +378,26 @@ plot_pdp_num <- function(pdp_dt, geom_line(data = ice_dt, aes(x = !!feature_name_sym, y = predictions, group = interaction(id, time), color = time), alpha = 0.1) + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, color = time, group = time), linewidth = 1.5, lineend = "round", linejoin = "round") + geom_path(data = pdp_dt, aes(x = !!feature_name_sym, y = pd, group = time), color = "black", linewidth = 0.5, linetype = "dashed", lineend = "round", linejoin = "round") + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_ice), sides = "b", 
alpha = 0.8, position = position_jitter(width = 0.01 * x_width)) + scale_colour_gradient2( low = colors[1], mid = colors[2], high = colors[3], - midpoint = median(as.numeric(as.character(pdp_dt$time)))) + + midpoint = median(as.numeric(as.character(pdp_dt$time))) + ) + ylim(y_floor_ice, y_ceiling_ice) } # PDP else if (plot_type == "pdp" || plot_type == "ale") { ggplot(data = pdp_dt, aes(x = !!feature_name_sym, y = pd)) + geom_line(aes(color = time, group = time)) + - geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width=0.01 * x_width)) + + geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width = 0.01 * x_width)) + scale_colour_gradient2( low = colors[1], mid = colors[2], high = colors[3], - midpoint = median(as.numeric(as.character(pdp_dt$time)))) + + midpoint = median(as.numeric(as.character(pdp_dt$time))) + ) + ylim(y_floor_pd, y_ceiling_pd) } } diff --git a/R/plot_predict_profile_survival.R b/R/plot_predict_profile_survival.R index 406b00e8..3da481f5 100644 --- a/R/plot_predict_profile_survival.R +++ b/R/plot_predict_profile_survival.R @@ -49,8 +49,6 @@ #' } #' @export plot.predict_profile_survival <- function(x, ...) { - class(x) <- class(x)[-1] plot(x, ...) 
- } diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index b432ba88..b98c0dd5 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -52,51 +52,54 @@ plot.surv_ceteris_paribus <- function(x, subtitle = "default", rug = "all", rug_colors = c("#dd0000", "#222222")) { - if (!is.null(variable_type)) + if (!is.null(variable_type)) { check_variable_type(variable_type) + } check_numerical_plot_type(numerical_plot_type) explanations_list <- c(list(x), list(...)) num_models <- length(explanations_list) - if (num_models == 1){ - result <- prepare_ceteris_paribus_plots(x, - colors, - variable_type, - facet_ncol, - variables, - numerical_plot_type, - title, - subtitle, - rug, - rug_colors) + if (num_models == 1) { + result <- prepare_ceteris_paribus_plots( + x, + colors, + variable_type, + facet_ncol, + variables, + numerical_plot_type, + title, + subtitle, + rug, + rug_colors + ) return(result) } return_list <- list() labels <- list() - for (i in 1:num_models){ + for (i in 1:num_models) { this_title <- unique(explanations_list[[i]]$result$`_label_`) - return_list[[i]] <- prepare_ceteris_paribus_plots(explanations_list[[i]], - colors, - variable_type, - 1, - variables, - numerical_plot_type, - this_title, - NULL, - rug, - rug_colors) - labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches)-2)) + return_list[[i]] <- prepare_ceteris_paribus_plots( + explanations_list[[i]], + colors, + variable_type, + 1, + variables, + numerical_plot_type, + this_title, + NULL, + rug, + rug_colors + ) + labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches) - 2)) } labels <- unlist(labels) - return_plot <- patchwork::wrap_plots(return_list, nrow = 1, tag_level="keep") + - patchwork::plot_annotation(title, tag_levels = list(labels)) & theme_default_survex() + return_plot <- patchwork::wrap_plots(return_list, nrow = 1, tag_level = "keep") + + patchwork::plot_annotation(title, tag_levels = list(labels)) & 
theme_default_survex() return(return_plot) - - } @@ -109,7 +112,7 @@ prepare_ceteris_paribus_plots <- function(x, title = "Ceteris paribus survival profile", subtitle = "default", rug = "all", - rug_colors = c("#dd0000", "#222222")){ + rug_colors = c("#dd0000", "#222222")) { rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses), label = unique(x$result$`_label_`)) obs <- as.data.frame(x$variable_values) center <- x$center @@ -128,15 +131,15 @@ prepare_ceteris_paribus_plots <- function(x, } # variables to use - all_variables <- - na.omit(as.character(unique(all_profiles$`_vname_`))) + all_variables <- na.omit(as.character(unique(all_profiles$`_vname_`))) if (!is.null(variables)) { variables <- intersect(all_variables, variables) - if (length(variables) == 0) + if (length(variables) == 0) { stop(paste0( "variables do not overlap with ", paste(all_variables, collapse = ", ") )) + } all_variables <- variables } @@ -152,14 +155,18 @@ prepare_ceteris_paribus_plots <- function(x, all_variables <- intersect(all_variables, unique(x$`_vname_`)) lsc <- lapply(all_variables, function(sv) { - tmp <- x[x$`_vname_` == sv, - c(sv, - "_times_", - "_vname_", - "_vtype_", - "_yhat_", - "_label_", - "_ids_")] + tmp <- x[ + x$`_vname_` == sv, + c( + sv, + "_times_", + "_vname_", + "_vtype_", + "_yhat_", + "_label_", + "_ids_" + ) + ] key <- obs[, sv, drop = FALSE] tmp$`_real_point_` <- tmp[, sv] == key[, sv] @@ -179,12 +186,14 @@ prepare_ceteris_paribus_plots <- function(x, rug_df = rug_df, rug = rug, rug_colors = rug_colors, - center = center) + center = center + ) patchwork::wrap_plots(pl, ncol = facet_ncol) + - patchwork::plot_annotation(title = title, - subtitle = subtitle) & theme_default_survex() - + patchwork::plot_annotation( + title = title, + subtitle = subtitle + ) & theme_default_survex() } #' @import ggplot2 @@ -197,85 +206,98 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, rug, rug_colors, center) { - pl <- 
lapply(variables, function(var) { df <- all_profiles[all_profiles$`_vname_` == var, ] if (unique(df$`_vtype_`) == "numerical") { - if (!is.null(colors)) + if (!is.null(colors)) { scale_cont <- colors - else - scale_cont <- c(low = "#9fe5bd", - mid = "#46bac2", - high = "#371ea3") + } else { + scale_cont <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) + } if (numerical_plot_type == "lines") { base_plot <- with(df, { - ggplot( - df, - aes( - x = `_times_`, - y = `_yhat_`, - group = `_x_`, - color = as.numeric(as.character(`_x_`)) - ) - ) + - geom_line() + - scale_colour_gradient2( - name = paste0(unique(df$`_vname_`), " value"), - low = scale_cont[1], - high = scale_cont[3], - mid = scale_cont[2], - midpoint = median(as.numeric(as.character(df$`_x_`))) + ggplot( + df, + aes( + x = `_times_`, + y = `_yhat_`, + group = `_x_`, + color = as.numeric(as.character(`_x_`)) + ) ) + - geom_line(data = df[df$`_real_point_`, ], color = - "red", linewidth = 0.8) + - labs(x = "time", y = "centered profile value") + - xlim(c(0,NA)) + - theme_default_survex() + - facet_wrap(~`_vname_`) + geom_line() + + scale_colour_gradient2( + name = paste0(unique(df$`_vname_`), " value"), + low = scale_cont[1], + high = scale_cont[3], + mid = scale_cont[2], + midpoint = median(as.numeric(as.character(df$`_x_`))) + ) + + geom_line( + data = df[df$`_real_point_`, ], color = + "red", linewidth = 0.8 + ) + + labs(x = "time", y = "centered profile value") + + xlim(c(0, NA)) + + theme_default_survex() + + facet_wrap(~`_vname_`) }) if (!center) { base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") } } else { base_plot <- with(df, { - ggplot( - df, - aes( - x = `_times_`, - y = as.numeric(as.character(`_x_`)), - z = `_yhat_` + ggplot( + df, + aes( + x = `_times_`, + y = as.numeric(as.character(`_x_`)), + z = `_yhat_` + ) ) - ) }) - if (!center){ + }) + if (!center) { base_plot <- base_plot + - geom_contour_filled(binwidth=0.1) + - scale_fill_manual(name = "SF 
value", - values = grDevices::colorRampPalette(c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3"))(10), - drop = FALSE) + - guides(fill = guide_colorsteps(direction = "horizontal", - barwidth = 0.5*unit(par("pin")[1], "in"), - barheight = 0.02*unit(par("pin")[2], "in"), - reverse = TRUE, - show.limits = TRUE)) + + geom_contour_filled(binwidth = 0.1) + + scale_fill_manual( + name = "SF value", + values = grDevices::colorRampPalette(c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3"))(10), + drop = FALSE + ) + + guides(fill = guide_colorsteps( + direction = "horizontal", + barwidth = 0.5 * unit(par("pin")[1], "in"), + barheight = 0.02 * unit(par("pin")[2], "in"), + reverse = TRUE, + show.limits = TRUE + )) + labs(x = "time", y = "variable value") + - xlim(c(0,NA)) + + xlim(c(0, NA)) + theme_default_survex() + facet_wrap(~`_vname_`) } else { base_plot <- base_plot + - geom_contour_filled(bins=10) + - scale_fill_manual(name = "centered profile value", - values = grDevices::colorRampPalette(c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3"))(10), - drop = FALSE) + - guides(fill = guide_colorsteps(direction = "horizontal", - barwidth = 0.5*unit(par("pin")[1], "in"), - barheight = 0.02*unit(par("pin")[2], "in"), - reverse = TRUE, - show.limits = TRUE, - label.theme = element_text(size=7))) + + geom_contour_filled(bins = 10) + + scale_fill_manual( + name = "centered profile value", + values = grDevices::colorRampPalette(c("#c7f5bf", "#8bdcbe", "#46bac2", "#4378bf", "#371ea3"))(10), + drop = FALSE + ) + + guides(fill = guide_colorsteps( + direction = "horizontal", + barwidth = 0.5 * unit(par("pin")[1], "in"), + barheight = 0.02 * unit(par("pin")[2], "in"), + reverse = TRUE, + show.limits = TRUE, + label.theme = element_text(size = 7) + )) + labs(x = "time", y = "variable value") + - xlim(c(0,NA)) + + xlim(c(0, NA)) + theme_default_survex() + facet_wrap(~`_vname_`) } @@ -283,9 +305,9 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, 
range_time <- range(df["_times_"]) var_val <- as.numeric(unique(df[df$`_real_point_`, "_x_"])) base_plot <- base_plot + geom_segment(aes(x = range_time[1], y = var_val, xend = range_time[2], yend = var_val), color = "red") - } - base_plot } + base_plot + } } else { n_colors <- length(unique(df$`_x_`)) base_plot <- with(df, { @@ -298,16 +320,23 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, color = `_x_` ) ) + - geom_line(data = df[!df$`_real_point_`, ], - linewidth = 0.8) + - geom_line(data = df[df$`_real_point_`, ], - linewidth = 0.8, linetype = "longdash") + - scale_color_manual(name = paste0(unique(df$`_vname_`), " value"), - values = generate_discrete_color_scale(n_colors, colors)) + - theme_default_survex() + - labs(x = "time", y = "centered profile value") + - xlim(c(0,NA))+ - facet_wrap(~`_vname_`) }) + geom_line( + data = df[!df$`_real_point_`, ], + linewidth = 0.8 + ) + + geom_line( + data = df[df$`_real_point_`, ], + linewidth = 0.8, linetype = "longdash" + ) + + scale_color_manual( + name = paste0(unique(df$`_vname_`), " value"), + values = generate_discrete_color_scale(n_colors, colors) + ) + + theme_default_survex() + + labs(x = "time", y = "centered profile value") + + xlim(c(0, NA)) + + facet_wrap(~`_vname_`) + }) if (!center) { base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") } @@ -323,11 +352,13 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, check_variable_type <- function(variable_type) { - if (!(variable_type %in% c("numerical", "categorical"))) + if (!(variable_type %in% c("numerical", "categorical"))) { stop("variable_type needs to be 'numerical' or 'categorical'") + } } check_numerical_plot_type <- function(numerical_plot_type) { - if (!(numerical_plot_type %in% c("lines", "contours"))) + if (!(numerical_plot_type %in% c("lines", "contours"))) { stop("numerical_plot_type needs to be 'lines' or 'contours'") + } } diff --git a/R/plot_surv_feature_importance.R 
b/R/plot_surv_feature_importance.R index 3c3a98c4..97365594 100644 --- a/R/plot_surv_feature_importance.R +++ b/R/plot_surv_feature_importance.R @@ -41,7 +41,6 @@ plot.surv_feature_importance <- function(x, ..., colors = NULL, rug = "all", rug_colors = c("#dd0000", "#222222")) { - df_list <- c(list(x), list(...)) transformed_dfs <- lapply(df_list, function(x) { @@ -51,7 +50,7 @@ plot.surv_feature_importance <- function(x, ..., plotting_df <- with(x, cbind(x[1], stack(x, select = -`_times_`), label, row.names = NULL)) }) - transformed_rug_dfs <- lapply(df_list, function(x){ + transformed_rug_dfs <- lapply(df_list, function(x) { rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses), label = unique(x$result$label)) }) @@ -68,9 +67,10 @@ plot.surv_feature_importance <- function(x, ..., num_vars <- length(unique(plotting_df$ind)) - 1 # remove full_model; note that num_vars <= max_vars additional_info <- switch(attr(x, "type"), - "raw" = "", - "ratio" = "\ndivided by the loss of full model", - "difference" = "\nwith loss of full model subtracted") + "raw" = "", + "ratio" = "\ndivided by the loss of full model", + "difference" = "\nwith loss of full model subtracted" + ) if (!is.null(attr(x, "loss_name"))) { y_lab <- paste0(paste(attr(x, "loss_name")[1], "loss after permutations"), additional_info) @@ -84,14 +84,13 @@ plot.surv_feature_importance <- function(x, ..., } base_plot <- with(plotting_df, { - - ggplot(data = plotting_df, aes(x = `_times_`, y = values, color = ind, label = ind)) + - geom_line(linewidth = 0.8) + - theme_default_survex() + - labs(x = "time", y = y_lab, title = title, subtitle = subtitle) + - xlim(c(0,NA)) + - scale_color_manual(name = "Variable", values = c("#000000", generate_discrete_color_scale(num_vars, colors))) + - facet_wrap(~label) + ggplot(data = plotting_df, aes(x = `_times_`, y = values, color = ind, label = ind)) + + geom_line(linewidth = 0.8) + + theme_default_survex() + + labs(x = "time", y = y_lab, 
title = title, subtitle = subtitle) + + xlim(c(0, NA)) + + scale_color_manual(name = "Variable", values = c("#000000", generate_discrete_color_scale(num_vars, colors))) + + facet_wrap(~label) }) return_plot <- add_rug_to_plot(base_plot, rug_df, rug, rug_colors) diff --git a/R/plot_surv_lime.R b/R/plot_surv_lime.R index 99f958c0..c00eccfe 100644 --- a/R/plot_surv_lime.R +++ b/R/plot_surv_lime.R @@ -36,16 +36,19 @@ plot.surv_lime <- function(x, subtitle = "default", max_vars = 7, colors = NULL) { - if (!type %in% c("coefficients", "local_importance")) + if (!type %in% c("coefficients", "local_importance")) { stop("Type should be one of `coefficients`, `local_importance`") + } local_importance <- as.numeric(x$result) * as.numeric(x$variable_values) - df <- data.frame(variable_names = names(x$variable_values), - variable_values = as.numeric(x$variable_values), - beta = as.numeric(x$result), - sign_beta = as.factor(sign(as.numeric(x$result))), - sign_local_importance = as.factor(sign(local_importance)), - local_importance = local_importance) + df <- data.frame( + variable_names = names(x$variable_values), + variable_values = as.numeric(x$variable_values), + beta = as.numeric(x$result), + sign_beta = as.factor(sign(as.numeric(x$result))), + sign_local_importance = as.factor(sign(local_importance)), + local_importance = local_importance + ) if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", attr(x, "label"), " model") @@ -54,11 +57,11 @@ plot.surv_lime <- function(x, if (type == "coefficients") { x_lab <- "SurvLIME coefficients" y_lab <- "" - df <- df[head(order(abs(df$beta), decreasing=TRUE), max_vars),] + df <- df[head(order(abs(df$beta), decreasing = TRUE), max_vars), ] pl <- with(df, { - ggplot(data = df, aes(x = beta, y = reorder(variable_names, beta, abs), fill = sign_beta)) + - geom_col() + - scale_fill_manual("", values = c("-1"="#f05a71", "0"="#ffffff", "1"="#8bdcbe")) + ggplot(data = df, aes(x = beta, y = 
reorder(variable_names, beta, abs), fill = sign_beta)) + + geom_col() + + scale_fill_manual("", values = c("-1" = "#f05a71", "0" = "#ffffff", "1" = "#8bdcbe")) }) } @@ -66,13 +69,11 @@ plot.surv_lime <- function(x, if (type == "local_importance") { x_lab <- "SurvLIME local importance" y_lab <- "" - df <- df[head(order(abs(df$local_importance), decreasing=TRUE), max_vars),] - pl <- with(df,{ - + df <- df[head(order(abs(df$local_importance), decreasing = TRUE), max_vars), ] + pl <- with(df, { ggplot(data = df, aes(x = local_importance, y = reorder(variable_names, local_importance, abs), fill = sign_local_importance)) + - geom_col() + - scale_fill_manual("", values = c("-1"="#f05a71", "0"="#ffffff", "1"="#8bdcbe")) - + geom_col() + + scale_fill_manual("", values = c("-1" = "#f05a71", "0" = "#ffffff", "1" = "#8bdcbe")) }) } pl <- pl + theme_vertical_default_survex() + @@ -81,21 +82,20 @@ plot.surv_lime <- function(x, ylab(y_lab) + theme(legend.position = "none") - sf_df <- data.frame(times = c(x$black_box_sf_times, x$expl_sf_times), - sfs = c(x$black_box_sf, x$expl_sf), - type = c(rep("black box survival function", length(x$black_box_sf)), rep("SurvLIME explanation survival function", length(x$expl_sf)))) + sf_df <- data.frame( + times = c(x$black_box_sf_times, x$expl_sf_times), + sfs = c(x$black_box_sf, x$expl_sf), + type = c(rep("black box survival function", length(x$black_box_sf)), rep("SurvLIME explanation survival function", length(x$expl_sf))) + ) if (show_survival_function) { - pl2 <- with(sf_df,{ - + pl2 <- with(sf_df, { ggplot(data = sf_df, aes(x = times, y = sfs, group = type, color = type)) + - geom_line(linewidth = 0.8) + - theme_default_survex() + - labs(x = "time", y = "survival function value") + - xlim(c(0,NA)) + - scale_color_manual("", values = generate_discrete_color_scale(2, colors)) + geom_line(linewidth = 0.8) + + theme_default_survex() + + labs(x = "time", y = "survival function value") + + xlim(c(0, NA)) + + scale_color_manual("", values = 
generate_discrete_color_scale(2, colors)) }) return(patchwork::wrap_plots(pl, pl2, nrow = 1, widths = c(3, 5))) } - - } diff --git a/R/plot_surv_model_performance.R b/R/plot_surv_model_performance.R index 92cc0f58..7b14df42 100644 --- a/R/plot_surv_model_performance.R +++ b/R/plot_surv_model_performance.R @@ -45,20 +45,37 @@ plot.surv_model_performance <- function(x, rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses), label = attr(x, "label")) if (metrics_type %in% c("time_dependent", "functional")) { - pl <- plot_td_surv_model_performance(x, ..., metrics = metrics, title = title, subtitle = subtitle, facet_ncol = facet_ncol, colors = colors, rug_df = rug_df, rug = rug, rug_colors = rug_colors) + pl <- plot_td_surv_model_performance( + x, + ..., + metrics = metrics, + title = title, + subtitle = subtitle, + facet_ncol = facet_ncol, + colors = colors, + rug_df = rug_df, + rug = rug, + rug_colors = rug_colors + ) } else if (metrics_type == "scalar") { - pl <- plot_scalar_surv_model_performance(x, ..., metrics = metrics, title = title, subtitle = subtitle, facet_ncol = facet_ncol, colors = colors) + pl <- plot_scalar_surv_model_performance( + x, + ..., + metrics = metrics, + title = title, + subtitle = subtitle, + facet_ncol = facet_ncol, + colors = colors + ) } else { stop("`metrics_type` should be one of `time_dependent`, `functional` or `scalar`") } pl - } plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, subtitle = "default", facet_ncol = NULL, colors = NULL, rug_df = rug_df, rug = rug, rug_colors = rug_colors) { - df <- concatenate_td_dfs(x, ...) 
if (!is.null(subtitle) && subtitle == "default") { @@ -69,14 +86,14 @@ plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, num_colors <- length(unique(df$label)) - base_plot <- with(df,{ - ggplot(data = df[df$ind %in% metrics, ], aes(x = times, y = values, group = label, color = label)) + - geom_line(linewidth = 0.8) + - theme_default_survex() + - labs(x = "time", y = "metric value", title = title, subtitle = subtitle) + - xlim(c(0,NA)) + - scale_color_manual("", values = generate_discrete_color_scale(num_colors, colors)) + - facet_wrap(~ind, ncol = facet_ncol, scales = "free_y") + base_plot <- with(df, { + ggplot(data = df[df$ind %in% metrics, ], aes(x = times, y = values, group = label, color = label)) + + geom_line(linewidth = 0.8) + + theme_default_survex() + + labs(x = "time", y = "metric value", title = title, subtitle = subtitle) + + xlim(c(0, NA)) + + scale_color_manual("", values = generate_discrete_color_scale(num_colors, colors)) + + facet_wrap(~ind, ncol = facet_ncol, scales = "free_y") }) return_plot <- add_rug_to_plot(base_plot, rug_df, rug, rug_colors) @@ -97,15 +114,13 @@ plot_scalar_surv_model_performance <- function(x, ..., metrics = NULL, title = N num_colors <- length(unique(df$label)) with(df, { - ggplot(data = df, aes(x = label, y = values, fill = label)) + - geom_col() + - theme_default_survex() + - labs(x = "model", y = "metric value", title = title, subtitle = subtitle) + - scale_fill_manual("", values = generate_discrete_color_scale(num_colors, colors)) + - facet_wrap(~ind, ncol = facet_ncol, scales = "free_y") + ggplot(data = df, aes(x = label, y = values, fill = label)) + + geom_col() + + theme_default_survex() + + labs(x = "model", y = "metric value", title = title, subtitle = subtitle) + + scale_fill_manual("", values = generate_discrete_color_scale(num_colors, colors)) + + facet_wrap(~ind, ncol = facet_ncol, scales = "free_y") }) - - } @@ -113,26 +128,26 @@ concatenate_td_dfs <- function(x, ...) 
{ all_things <- c(list(x), list(...)) all_dfs <- lapply(all_things, function(x) { - tmp_list <- lapply(x$result, function(metric) { - if(!is.null(attr(metric, "loss_type"))){ - if(attr(metric, "loss_type") == "time-dependent"){ + if (!is.null(attr(metric, "loss_type"))) { + if (attr(metric, "loss_type") == "time-dependent") { attr(metric, "loss_type") <- NULL - metric} + metric + } } }) tmp_list[sapply(tmp_list, is.null)] <- NULL df <- data.frame(tmp_list, - check.names = FALSE) + check.names = FALSE + ) df <- stack(df) - times <- rep(x$eval_times, length(tmp_list)) - label <- attr(x, "label") + times <- rep(x$eval_times, length(tmp_list)) + label <- attr(x, "label") df <- cbind(times, df, label) }) do.call(rbind, all_dfs) - } @@ -141,14 +156,16 @@ concatenate_dfs <- function(x, ...) { all_dfs <- lapply(all_things, function(x) { tmp_list <- lapply(x$result, function(metric) { - if(!is.null(attr(metric, "loss_type"))){ - if(attr(metric, "loss_type") != "time-dependent"){ - metric[1]} + if (!is.null(attr(metric, "loss_type"))) { + if (attr(metric, "loss_type") != "time-dependent") { + metric[1] + } } - }) + }) tmp_list[sapply(tmp_list, is.null)] <- NULL df <- data.frame(tmp_list, - check.names = FALSE) + check.names = FALSE + ) df <- stack(df) label <- attr(x, "label") df <- cbind(df, label) diff --git a/R/plot_surv_model_performance_rocs.R b/R/plot_surv_model_performance_rocs.R index 6752c3ac..33806e80 100644 --- a/R/plot_surv_model_performance_rocs.R +++ b/R/plot_surv_model_performance_rocs.R @@ -34,7 +34,6 @@ plot.surv_model_performance_rocs <- function(x, auc = TRUE, colors = NULL, facet_ncol = NULL) { - dfl <- c(list(x), list(...)) alldfs <- lapply(dfl, function(x) { @@ -51,30 +50,38 @@ plot.surv_model_performance_rocs <- function(x, num_colors <- length(unique(df$label)) - base_plot <- with(df, {ggplot(data = df, aes(x = FPR, y = TPR, group = label, color = label)) + + base_plot <- with(df, { + ggplot(data = df, aes(x = FPR, y = TPR, group = label, color = 
label)) + geom_line(linewidth = 0.8) + theme_default_survex() + xlab("1 - specificity (FPR)") + ylab("sensitivity (TPR)") + coord_fixed() + - theme(panel.grid.major.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1), - panel.grid.minor.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1)) + + theme( + panel.grid.major.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1), + panel.grid.minor.x = element_line(color = "grey90", linewidth = 0.5, linetype = 1) + ) + labs(title = title, subtitle = subtitle) + scale_color_manual("", values = generate_discrete_color_scale(num_colors, colors)) + facet_wrap(~time, ncol = facet_ncol, labeller = function(x) lapply(x, function(x) paste0("t=", x))) }) - if (auc){ - auc_df <- unique(df[,c("label", "time", "AUC")]) + if (auc) { + auc_df <- unique(df[, c("label", "time", "AUC")]) auc_df$AUC <- round(auc_df$AUC, 3) - auc_df$y <- rep((0:(num_colors-1)) * 0.1, each=length(unique(auc_df$time))) + auc_df$y <- rep((0:(num_colors - 1)) * 0.1, each = length(unique(auc_df$time))) return_plot <- base_plot + - with(auc_df, { geom_text(auc_df, - mapping=aes(x=0.75, - y=y, - label=paste("AUC =", AUC), - color=label), - show.legend=FALSE)}) + with(auc_df, { + geom_text(auc_df, + mapping = aes( + x = 0.75, + y = y, + label = paste("AUC =", AUC), + color = label + ), + show.legend = FALSE + ) + }) } else { return_plot <- base_plot } diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index df1fc41b..5028dbd4 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -37,13 +37,12 @@ plot.surv_shap <- function(x, colors = NULL, rug = "all", rug_colors = c("#dd0000", "#222222")) { - dfl <- c(list(x), list(...)) long_df <- lapply(dfl, function(x) { label <- attr(x, "label") cols <- sort(head(order(x$aggregate, decreasing = TRUE), max_vars)) - sv <- x$result[,cols] + sv <- x$result[, cols] times <- x$eval_times transposed <- as.data.frame(cbind(times = times, sv)) rownames(transposed) <- NULL @@ -54,7 +53,7 @@ 
plot.surv_shap <- function(x, ) }) - transformed_rug_dfs <- lapply(dfl, function(x){ + transformed_rug_dfs <- lapply(dfl, function(x) { label <- attr(x, "label") rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses), label = label) }) @@ -71,19 +70,18 @@ plot.surv_shap <- function(x, n_colors <- length(unique(long_df$ind)) base_plot <- with(long_df, { - ggplot(data = long_df, aes(x = times, y = values, color = ind)) + - geom_line(linewidth = 0.8) + - labs(x = "time", y = "SurvSHAP(t) value", title = title, subtitle = subtitle) + - xlim(c(0,NA)) + - scale_color_manual("variable", values = generate_discrete_color_scale(n_colors, colors)) + - theme_default_survex() + - facet_wrap(~label, ncol = 1, scales = "free_y") + ggplot(data = long_df, aes(x = times, y = values, color = ind)) + + geom_line(linewidth = 0.8) + + labs(x = "time", y = "SurvSHAP(t) value", title = title, subtitle = subtitle) + + xlim(c(0, NA)) + + scale_color_manual("variable", values = generate_discrete_color_scale(n_colors, colors)) + + theme_default_survex() + + facet_wrap(~label, ncol = 1, scales = "free_y") }) return_plot <- add_rug_to_plot(base_plot, rug_df, rug, rug_colors) return(return_plot) - } @@ -126,62 +124,74 @@ plot.surv_shap <- function(x, #' \donttest{ #' veteran <- survival::veteran #' rsf_ranger <- ranger::ranger( -#' survival::Surv(time, status) ~ ., -#' data = veteran, -#' respect.unordered.factors = TRUE, -#' num.trees = 100, -#' mtry = 3, -#' max.depth = 5 +#' survival::Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 #' ) #' rsf_ranger_exp <- explain( -#' rsf_ranger, -#' data = veteran[, -c(3, 4)], -#' y = survival::Surv(veteran$time, veteran$status), -#' verbose = FALSE +#' rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = survival::Surv(veteran$time, veteran$status), +#' verbose = FALSE #' ) #' #' ranger_global_survshap <- model_survshap( -#' explainer = 
rsf_ranger_exp, -#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), -#' !colnames(veteran) %in% c("time", "status")], -#' y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], -#' veteran$status[c(1:4, 17:20, 110:113, 126:129)]), -#' aggregation_method = "integral", -#' calculation_method = "kernelshap", +#' explainer = rsf_ranger_exp, +#' new_observation = veteran[ +#' c(1:4, 17:20, 110:113, 126:129), +#' !colnames(veteran) %in% c("time", "status") +#' ], +#' y_true = survival::Surv( +#' veteran$time[c(1:4, 17:20, 110:113, 126:129)], +#' veteran$status[c(1:4, 17:20, 110:113, 126:129)] +#' ), +#' aggregation_method = "integral", +#' calculation_method = "kernelshap", #' ) #' plot(ranger_global_survshap) #' plot(ranger_global_survshap, geom = "beeswarm") #' plot(ranger_global_survshap, geom = "profile", color_variable = "karno") #' } #' -#'@export +#' @export plot.aggregated_surv_shap <- function(x, geom = "importance", ..., - title="default", - subtitle="default", - max_vars=7, - colors = NULL){ - if (is.null(colors)){ - colors <- c(low = "#9fe5bd", - mid = "#46bac2", - high = "#371ea3") + title = "default", + subtitle = "default", + max_vars = 7, + colors = NULL) { + if (is.null(colors)) { + colors <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) } - if (geom == "swarm") + if (geom == "swarm") { geom <- "beeswarm" + } - switch( - geom, - "importance" = plot_shap_global_importance(x = x, - ... = ..., - colors = colors), - "beeswarm" = plot_shap_global_beeswarm(x = x, - ... = ..., - colors = colors), - "profile" = plot_shap_global_profile(x = x, - ... = ..., - colors = colors), + switch(geom, + "importance" = plot_shap_global_importance( + x = x, + ... = ..., + colors = colors + ), + "beeswarm" = plot_shap_global_beeswarm( + x = x, + ... = ..., + colors = colors + ), + "profile" = plot_shap_global_profile( + x = x, + ... 
= ..., + colors = colors + ), stop("`geom` must be one of 'importance', 'beeswarm' or 'profile'") ) } @@ -195,26 +205,28 @@ plot_shap_global_importance <- function(x, rug = "all", rug_colors = c("#dd0000", "#222222"), xlab_left = "Average |aggregated SurvSHAP(t)| value", - ylab_right = "Average |SurvSHAP(t)| value"){ - + ylab_right = "Average |SurvSHAP(t)| value") { x$result <- aggregate_shap_multiple_observations(x$result, colnames(x$result[[1]]), function(x) mean(abs(x))) x$aggregate <- apply(do.call(rbind, x$aggregate), 2, function(x) mean(abs(x))) - right_plot <- plot.surv_shap(x = x, - title = NULL, - subtitle = NULL, - max_vars = max_vars, - colors = NULL, - rug = rug, - rug_colors = rug_colors) + + right_plot <- plot.surv_shap( + x = x, + title = NULL, + subtitle = NULL, + max_vars = max_vars, + colors = NULL, + rug = rug, + rug_colors = rug_colors + ) + labs(y = ylab_right) label <- attr(x, "label") long_df <- stack(x$aggregate) - long_df <- long_df[order(long_df$values, decreasing = TRUE),][1:min(max_vars, length(x$aggregate)), ] + long_df <- long_df[order(long_df$values, decreasing = TRUE), ][1:min(max_vars, length(x$aggregate)), ] - if (!is.null(subtitle) && subtitle == "default") + if (!is.null(subtitle) && subtitle == "default") { title <- "Feature importance according to aggregated |SurvSHAP(t)|" + } if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0( "created for the ", label, " model ", @@ -228,41 +240,42 @@ plot_shap_global_importance <- function(x, theme_default_survex() + labs(x = xlab_left) + theme(axis.title.y = element_blank()) - }) pl <- left_plot + right_plot + - patchwork::plot_layout(widths = c(3,5), guides = "collect") + + patchwork::plot_layout(widths = c(3, 5), guides = "collect") + patchwork::plot_annotation(title = title, subtitle = subtitle) & - theme(legend.position = "top", - plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), - plot.subtitle = element_text(color = "#371ea3", hjust = 0),) + theme( 
+ legend.position = "top", + plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), + plot.subtitle = element_text(color = "#371ea3", hjust = 0), + ) return(pl) } plot_shap_global_beeswarm <- function(x, - ..., - title = "default", - subtitle = "default", - max_vars = 7, - colors = NULL){ - + ..., + title = "default", + subtitle = "default", + max_vars = 7, + colors = NULL) { df <- as.data.frame(do.call(rbind, x$aggregate)) cols <- names(sort(colMeans(abs(df))))[1:min(max_vars, length(df))] - df <- df[,cols] + df <- df[, cols] df <- stack(df) colnames(df) <- c("shap_value", "variable") - original_values <- as.data.frame(x$variable_values)[,cols] + original_values <- as.data.frame(x$variable_values)[, cols] var_value <- preprocess_values_to_common_scale(original_values) df <- cbind(df, var_value) label <- attr(x, "label") - if (!is.null(subtitle) && subtitle == "default") + if (!is.null(subtitle) && subtitle == "default") { title <- "Aggregated SurvSHAP(t) values summary" + } if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0( "created for the ", label, " model ", @@ -270,26 +283,28 @@ plot_shap_global_beeswarm <- function(x, ) } with(df, { - ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + - geom_vline(xintercept = 0, color = "#ceced9", linetype="solid") + - geom_jitter(width=0, height=0.15) + - scale_color_gradient2( - name = "Variable value", - low = colors[1], - mid = colors[2], - high = colors[3], - midpoint = 0.5, - limits=c(0,1), - breaks = c(0, 1), - labels=c("", "")) + - labs(title = title, subtitle = subtitle, - x = "Aggregated SurvSHAP(t) value", - y = "Variable") + - theme_default_survex() + - theme(legend.position = "bottom") + - guides(color = guide_colorbar(title.position = "top", title.hjust = 0.5)) - } - ) + ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + + geom_vline(xintercept = 0, color = "#ceced9", linetype = "solid") + + geom_jitter(width = 0, height = 0.15) + + 
scale_color_gradient2( + name = "Variable value", + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = 0.5, + limits = c(0, 1), + breaks = c(0, 1), + labels = c("", "") + ) + + labs( + title = title, subtitle = subtitle, + x = "Aggregated SurvSHAP(t) value", + y = "Variable" + ) + + theme_default_survex() + + theme(legend.position = "bottom") + + guides(color = guide_colorbar(title.position = "top", title.hjust = 0.5)) + }) } plot_shap_global_profile <- function(x, @@ -299,28 +314,28 @@ plot_shap_global_profile <- function(x, title = "default", subtitle = "default", max_vars = 7, - colors = NULL){ - + colors = NULL) { df <- as.data.frame(do.call(rbind, x$aggregate)) - if (is.null(variable)){ + if (is.null(variable)) { variable <- colnames(df)[1] } - if (is.null(color_variable)){ + if (is.null(color_variable)) { color_variable <- variable } - shap_val <- df[,variable] + shap_val <- df[, variable] original_values <- as.data.frame(x$variable_values) - var_vals <- original_values[,c(variable, color_variable)] + var_vals <- original_values[, c(variable, color_variable)] df <- cbind(shap_val, var_vals) colnames(df) <- c("shap_val", "variable_val", "color_variable_val") label <- attr(x, "label") - if (!is.null(subtitle) && subtitle == "default") + if (!is.null(subtitle) && subtitle == "default") { title <- "Aggregated SurvSHAP(t) profile" + } if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0( "created for the ", label, " model ", @@ -329,16 +344,18 @@ plot_shap_global_profile <- function(x, } p <- with(df, { - ggplot(df, aes(x = variable_val, y = shap_val, color = color_variable_val)) + - geom_hline(yintercept = 0, color = "#ceced9", linetype="solid") + - geom_point() + - geom_rug(aes(x = df$variable_val), inherit.aes=F, color = "#ceced9") + - labs(x = paste(variable, "value"), - y = "Aggregated SurvSHAP(t) value", - title = title, - subtitle = subtitle) + - theme_default_survex() + - theme(legend.position = "bottom") + ggplot(df, 
aes(x = variable_val, y = shap_val, color = color_variable_val)) + + geom_hline(yintercept = 0, color = "#ceced9", linetype = "solid") + + geom_point() + + geom_rug(aes(x = df$variable_val), inherit.aes = F, color = "#ceced9") + + labs( + x = paste(variable, "value"), + y = "Aggregated SurvSHAP(t) value", + title = title, + subtitle = subtitle + ) + + theme_default_survex() + + theme(legend.position = "bottom") }) if (!is.factor(df$color_variable_val)) { @@ -347,10 +364,13 @@ plot_shap_global_profile <- function(x, low = colors[1], mid = colors[2], high = colors[3], - midpoint = median(df$color_variable_val)) + midpoint = median(df$color_variable_val) + ) } else { - p + scale_color_manual(name = paste(color_variable, "value"), - values = generate_discrete_color_scale(length(unique(df$color_variable_val)), colors)) + p + scale_color_manual( + name = paste(color_variable, "value"), + values = generate_discrete_color_scale(length(unique(df$color_variable_val)), colors) + ) } } @@ -367,7 +387,5 @@ preprocess_values_to_common_scale <- function(data) { }) res <- stack(data) colnames(res) <- c("var_value", "variable") - return(res[,1]) + return(res[, 1]) } - - diff --git a/R/predict_parts.R b/R/predict_parts.R index 8adda397..4ace7d83 100644 --- a/R/predict_parts.R +++ b/R/predict_parts.R @@ -48,8 +48,11 @@ #' head(cph_predict_parts_survshap$result) #' plot(cph_predict_parts_survshap) #' -#' cph_predict_parts_survlime <- predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)], -#' type = "survlime") +#' cph_predict_parts_survlime <- predict_parts( +#' cph_exp, +#' new_observation = veteran[1, -c(3, 4)], +#' type = "survlime" +#' ) #' head(cph_predict_parts_survlime$result) #' plot(cph_predict_parts_survlime, type = "local_importance") #' } @@ -62,29 +65,27 @@ predict_parts <- function(explainer, ...) 
UseMethod("predict_parts", explainer) #' @export predict_parts.surv_explainer <- function(explainer, new_observation, ..., N = NULL, type = "survshap", output_type = "survival", explanation_label = NULL) { - if (output_type == "risk") { - return (DALEX::predict_parts(explainer = explainer, - new_observation = new_observation, - ... = ..., - N = N, - type = type)) - } - else { - + return(DALEX::predict_parts( + explainer = explainer, + new_observation = new_observation, + ... = ..., + N = N, + type = type + )) + } else { res <- switch(type, - "survshap" = surv_shap(explainer, new_observation, ...), - "survlime" = surv_lime(explainer, new_observation, ...), - stop("Only `survshap` and `survlime` methods are implemented for now")) - + "survshap" = surv_shap(explainer, new_observation, ...), + "survlime" = surv_lime(explainer, new_observation, ...), + stop("Only `survshap` and `survlime` methods are implemented for now") + ) } attr(res, "label") <- ifelse(is.null(explanation_label), explainer$label, explanation_label) res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] res$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] - class(res) <- c('predict_parts_survival', class(res)) + class(res) <- c("predict_parts_survival", class(res)) res - } diff --git a/R/predict_profile.R b/R/predict_profile.R index 8ea935e6..98094455 100644 --- a/R/predict_profile.R +++ b/R/predict_profile.R @@ -27,8 +27,9 @@ #' rsf_src_exp <- explain(rsf_src) #' #' cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)], -#' variables = c("trt", "celltype", "karno", "age"), -#' categorical_variables = "trt") +#' variables = c("trt", "celltype", "karno", "age"), +#' categorical_variables = "trt" +#' ) #' plot(cph_predict_profile, facet_ncol = 2) #' #' @@ -45,8 +46,9 @@ predict_profile <- function(explainer, ..., type = "ceteris_paribus", variable_splits_type = "uniform", - center = FALSE) + center = FALSE) { UseMethod("predict_profile", 
explainer) +} #' @rdname predict_profile.surv_explainer #' @export @@ -58,41 +60,43 @@ predict_profile.surv_explainer <- function(explainer, type = "ceteris_paribus", output_type = "survival", variable_splits_type = "uniform", - center = FALSE) -{ - + center = FALSE) { variables <- unique(variables, categorical_variables) if (!type %in% "ceteris_paribus") stop("Type not supported") if (!output_type %in% c("risk", "survival")) stop("output_type not supported") - if (length(dim(new_observation)) != 2 | nrow(new_observation) != 1) + if (length(dim(new_observation)) != 2 || nrow(new_observation) != 1) { stop("new_observation should be a single row data.frame") + } if (output_type == "risk") { - return(DALEX::predict_profile(explainer = explainer, - new_observation = new_observation, - variables = variables, - ... = ..., - type = type, - variable_splits_type = variable_splits_type)) - + return(DALEX::predict_profile( + explainer = explainer, + new_observation = new_observation, + variables = variables, + ... = ..., + type = type, + variable_splits_type = variable_splits_type + )) } if (output_type == "survival") { if (type == "ceteris_paribus") { - res <- surv_ceteris_paribus(explainer, - new_observation = new_observation, - variables = variables, - categorical_variables = categorical_variables, - variable_splits_type = variable_splits_type, - center = center, - ...) + res <- surv_ceteris_paribus( + explainer, + new_observation = new_observation, + variables = variables, + categorical_variables = categorical_variables, + variable_splits_type = variable_splits_type, + center = center, + ... 
+ ) class(res) <- c("predict_profile_survival", class(res)) res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] res$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] return(res) + } else { + stop("For survival output only type=`ceteris_paribus` is implemented") } - else stop("For survival output only type=`ceteris_paribus` is implemented") } - } #' @export diff --git a/R/predict_surv_explainer.R b/R/predict_surv_explainer.R index 37bf67bc..65ea4c0f 100644 --- a/R/predict_surv_explainer.R +++ b/R/predict_surv_explainer.R @@ -17,16 +17,19 @@ #' #' cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) #' rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., -#' data = veteran, -#' respect.unordered.factors = TRUE, -#' num.trees = 100, -#' mtry = 3, -#' max.depth = 5) +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 +#' ) #' #' cph_exp <- explain(cph) #' -#' rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], -#' y = Surv(veteran$time, veteran$status)) +#' rsf_ranger_exp <- explain(rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = Surv(veteran$time, veteran$status) +#' ) #' #' #' predict(cph_exp, veteran[1, ], output_type = "survival")[, 1:10] @@ -36,8 +39,7 @@ #' predict(rsf_ranger_exp, veteran[1, ], output_type = "chf")[, 1:10] #' #' @export -predict.surv_explainer <- function(object, newdata = NULL, output_type = "survival", times = NULL, ...) { - +predict.surv_explainer <- function(object, newdata = NULL, output_type = "survival", times = NULL, ...) { if (is.null(newdata)) newdata <- object$data if (is.null(times)) times <- object$times diff --git a/R/print.R b/R/print.R index c915b83c..c042cd61 100644 --- a/R/print.R +++ b/R/print.R @@ -32,7 +32,6 @@ print.surv_ceteris_paribus <- function(x, ...) { #' @export print.surv_feature_importance <- function(x, ...) 
{ - res <- x$result text <- paste0("Permutational feature importance for the ", unique(res$label), " model:\n") cat(text) @@ -54,7 +53,6 @@ print.surv_lime <- function(x, ...) { #' @export print.surv_shap <- function(x, ...) { - res <- x$result cat("SurvSHAP(t) for observation:\n\n") print.data.frame(x$variable_values, row.names = FALSE) diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index f8dbd057..031a9b99 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -79,14 +79,17 @@ surv_ceteris_paribus.default <- function(x, factor_variables <- colnames(data)[sapply(data, is.factor)] categorical_variables <- unique(c(additional_categorical_variables, factor_variables)) - if (is.null(data)) + if (is.null(data)) { stop("The ceteris_paribus() function requires explainers created with specified 'data'.") + } # calculate splits if (is.null(variable_splits)) { - if (is.null(variables)) + if (is.null(variables)) { variables <- colnames(data) - variable_splits <- calculate_variable_split(data, + } + variable_splits <- calculate_variable_split( + data, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, @@ -189,7 +192,7 @@ calculate_variable_survival_profile.default <- function(data, variable_splits, m new_data[, variable] <- rep(split_points, nrow(data)) yhat <- c(t(predict_survival_function(model, new_data, times))) - if (center){ + if (center) { yhat <- yhat - mean_pred } diff --git a/R/surv_feature_importance.R b/R/surv_feature_importance.R index 1a70238d..f55ed2c1 100644 --- a/R/surv_feature_importance.R +++ b/R/surv_feature_importance.R @@ -35,7 +35,6 @@ surv_feature_importance.surv_explainer <- function(x, variable_groups = NULL, N = NULL, label = NULL) { - test_explainer(x, "feature_importance", has_data = TRUE, has_y = TRUE, has_survival = TRUE) model <- x$model @@ -50,21 +49,20 @@ surv_feature_importance.surv_explainer <- function(x, surv_feature_importance.default(model, - data, - y, - 
times, - predict_function = predict_function, - predict_survival_function = predict_survival_function, - loss_function = loss_function, - label = label, - type = type, - N = N, - B = B, - variables = variables, - variable_groups = variable_groups, - ... + data, + y, + times, + predict_function = predict_function, + predict_survival_function = predict_survival_function, + loss_function = loss_function, + label = label, + type = type, + N = N, + B = B, + variables = variables, + variable_groups = variable_groups, + ... ) - } @@ -82,10 +80,6 @@ surv_feature_importance.default <- function(x, variables = NULL, N = NULL, variable_groups = NULL) { - - - - if (!is.null(variable_groups)) { if (!inherits(variable_groups, "list")) stop("variable_groups should be of class list") @@ -144,7 +138,7 @@ surv_feature_importance.default <- function(x, loss_full <- loss_function(observed, risk_true, surv_true, times) prog() chosen <- sample(1:nrow(observed)) - loss_baseline <- loss_function(observed[chosen, ], risk_true, surv_true, times) + loss_baseline <- loss_function(observed[chosen, ], risk_true, surv_true, times) prog() # loss upon dropping a single variable (or a single group) loss_variables <- sapply(variables, function(variables_set) { @@ -173,13 +167,11 @@ surv_feature_importance.default <- function(x, res_full <- res[res$`_permutation_` == 0, c("_times_", "_full_model_")] colnames(res_full) <- c("_times_", "_reference_") res <- merge(res, res_full, by = "_times_") - res <- res[order(res$`_permutation_`, res$`_times_`),] + res <- res[order(res$`_permutation_`, res$`_times_`), ] } if (type == "ratio") { - res[, 2:(ncol(res) - 3)] <- res[, 2:(ncol(res) - 3)] / res[["_reference_"]] res$`_reference_` <- NULL - } if (type == "difference") { res[, 2:(ncol(res) - 3)] <- res[, 2:(ncol(res) - 3)] - res[["_reference_"]] diff --git a/R/surv_integrated_feature_importance.R b/R/surv_integrated_feature_importance.R index 8630a371..e14a1628 100644 --- 
a/R/surv_integrated_feature_importance.R +++ b/R/surv_integrated_feature_importance.R @@ -30,7 +30,6 @@ surv_integrated_feature_importance <- function(x, variable_groups = NULL, N = NULL, label = NULL) { - test_explainer(x, "feature_importance", has_data = TRUE, has_y = TRUE) # extracts model, data and predict function from the explainer @@ -109,7 +108,6 @@ surv_integrated_feature_importance <- function(x, predicted_survs <- predict_survival_function(x, ndf, times) prog() loss_function(observed, predicted_risks, predicted_survs, times) - }) c("_full_model_" = loss_full, loss_variables, "_baseline_" = loss_baseline) } @@ -129,10 +127,10 @@ surv_integrated_feature_importance <- function(x, row.names = NULL ) if (type == "ratio") { - res$dropout_loss = res$dropout_loss / res_full + res$dropout_loss <- res$dropout_loss / res_full } if (type == "difference") { - res$dropout_loss = res$dropout_loss - res_full + res$dropout_loss <- res$dropout_loss - res_full } @@ -149,10 +147,10 @@ surv_integrated_feature_importance <- function(x, # here mean full model is used (full model for given permutation is an option) if (type == "ratio") { - res_B$dropout_loss = res_B$dropout_loss / res_full + res_B$dropout_loss <- res_B$dropout_loss / res_full } if (type == "difference") { - res_B$dropout_loss = res_B$dropout_loss - res_full + res_B$dropout_loss <- res_B$dropout_loss - res_full } res <- rbind(res, res_B) @@ -164,5 +162,4 @@ surv_integrated_feature_importance <- function(x, attr(res, "loss_name") <- attr(loss_function, "loss_name") } res - } diff --git a/R/surv_lime.R b/R/surv_lime.R index 0d2fab36..cc06e64c 100644 --- a/R/surv_lime.R +++ b/R/surv_lime.R @@ -30,19 +30,20 @@ surv_lime <- function(explainer, new_observation, max_iter = 10000, categorical_variables = NULL, k = 1 + 1e-4) { - test_explainer(explainer, "surv_lime", has_data = TRUE, has_y = TRUE, has_chf = TRUE) new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)] if 
(ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") predicted_sf <- explainer$predict_survival_function(explainer$model, new_observation, explainer$times) - neighbourhood <- generate_neighbourhood(explainer$data, - new_observation, - N, - categorical_variables, - sampling_method, - sample_around_instance) + neighbourhood <- generate_neighbourhood( + explainer$data, + new_observation, + N, + categorical_variables, + sampling_method, + sample_around_instance + ) sc <- attr(neighbourhood$inverse, "scaled:scale") @@ -69,7 +70,6 @@ surv_lime <- function(explainer, new_observation, loss <- function(beta) { - multiplied <- as.matrix(neighbourhood$inverse_ohe) %*% as.matrix(beta) multiplied <- multiplied[, rep(1, times = ncol(model_chfs))] @@ -79,9 +79,9 @@ surv_lime <- function(explainer, new_observation, rowSums( (weights_v^2) * ((log_chfs - log_na_est - multiplied)^2) - * t_diffs) - ) - + * t_diffs + ) + ) } var_values <- neighbourhood$inverse_ohe[1, ] @@ -90,17 +90,17 @@ surv_lime <- function(explainer, new_observation, beta <- data.frame(t(res$par)) names(beta) <- colnames(neighbourhood$inverse_ohe) - ret_list <- list(result = beta, - variable_values = var_values, - black_box_sf_times = explainer$times, - black_box_sf = as.numeric(explainer$predict_survival_function(explainer$model, new_observation, explainer$times)), - expl_sf_times = na_est$time, - expl_sf = exp(-na_est$hazard * exp(sum(var_values * res$par)))) + ret_list <- list( + result = beta, + variable_values = var_values, + black_box_sf_times = explainer$times, + black_box_sf = as.numeric(explainer$predict_survival_function(explainer$model, new_observation, explainer$times)), + expl_sf_times = na_est$time, + expl_sf = exp(-na_est$hazard * exp(sum(var_values * res$par))) + ) class(ret_list) <- c("surv_lime", class(ret_list)) return(ret_list) - - } @@ -129,7 +129,6 @@ generate_neighbourhood <- function(data_org, feature_count <- 
summary(column) frequencies <- feature_count / sum(feature_count) feature_frequencies[[feature]] <- frequencies - } n_col <- ncol(data_org[, !colnames(data_org) %in% categorical_variables]) @@ -137,14 +136,14 @@ generate_neighbourhood <- function(data_org, me <- attr(scaled_data, "scaled:center") data <- switch(sampling_method, - "gaussian" = matrix(rnorm(n_samples * n_col), nrow = n_samples, ncol = n_col), - stop("Only `gaussian` sampling_method is implemented")) + "gaussian" = matrix(rnorm(n_samples * n_col), nrow = n_samples, ncol = n_col), + stop("Only `gaussian` sampling_method is implemented") + ) if (sample_around_instance) { - to_add <- data_row[, !colnames(data_row) %in% categorical_variables] + to_add <- data_row[, !colnames(data_row) %in% categorical_variables] data <- data %*% diag(sc) + to_add[col(data)] - } - else { + } else { data <- data %*% diag(sc) + me[col(data)] } @@ -175,19 +174,14 @@ generate_neighbourhood <- function(data_org, data <- data[, colnames(data_row)] - if (length(categorical_variables) > 0){ + if (length(categorical_variables) > 0) { expr <- paste0("~", paste(categorical_variables, collapse = "+")) categorical_matrix <- model.matrix(as.formula(expr), data = inverse)[, -1] inverse_ohe <- cbind(inverse, categorical_matrix) inverse_ohe[, factor_variables] <- NULL - } else{ + } else { inverse_ohe <- inverse } - - list(data = data, inverse = inverse, inverse_ohe = inverse_ohe) - - - } diff --git a/R/surv_model_performance.R b/R/surv_model_performance.R index 729dfe32..fc76d738 100644 --- a/R/surv_model_performance.R +++ b/R/surv_model_performance.R @@ -12,31 +12,32 @@ surv_model_performance <- function(explainer, ..., times = NULL, type = "metrics", metrics = NULL) { newdata <- explainer$data if (type == "metrics") { - if (is.null(times)) times <- explainer$times - sf <- explainer$predict_survival_function(explainer$model, newdata, times) - risk <- explainer$predict_function(explainer$model, newdata) - y <- explainer$y - ret_list <- 
lapply(metrics, function(x) { output <- x(y, risk, sf, times) - attr(output, "loss_type") <- attr(x, "loss_type") - output}) - - ret_list <- list(result = ret_list, eval_times = times) - ret_list$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] - ret_list$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] + if (is.null(times)) times <- explainer$times + + sf <- explainer$predict_survival_function(explainer$model, newdata, times) + risk <- explainer$predict_function(explainer$model, newdata) + y <- explainer$y + ret_list <- lapply(metrics, function(x) { + output <- x(y, risk, sf, times) + attr(output, "loss_type") <- attr(x, "loss_type") + output + }) - class(ret_list) <- c("surv_model_performance", class(ret_list)) - attr(ret_list, "label") <- explainer$label + ret_list <- list(result = ret_list, eval_times = times) + ret_list$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] + ret_list$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] - ret_list - } + class(ret_list) <- c("surv_model_performance", class(ret_list)) + attr(ret_list, "label") <- explainer$label - else { + ret_list + } else { if (is.null(times)) stop("Times cannot be NULL for type `roc`") rocs <- lapply(times, function(time) { censored_earlier_mask <- (explainer$y[, 1] < time & explainer$y[, 2] == 0) event_later_mask <- explainer$y[, 1] > time newdata_t <- newdata[!censored_earlier_mask, ] - labels <- explainer$y[,2] + labels <- explainer$y[, 2] labels[event_later_mask] <- 0 labels <- labels[!censored_earlier_mask] scores <- explainer$predict_survival_function(explainer$model, newdata_t, time) @@ -45,10 +46,12 @@ surv_model_performance <- function(explainer, ..., times = NULL, type = "metrics FPR <- cumsum(!labels) / sum(!labels) vals <- 1 - FPR n <- length(vals) - AUC <- sum((vals[1:(n-1)] + vals[2:n]) * diff(TPR) / 2) - cbind(time = time, data.frame(TPR = TPR, - FPR = FPR, - AUC = AUC)) + AUC <- 
sum((vals[1:(n - 1)] + vals[2:n]) * diff(TPR) / 2) + cbind(time = time, data.frame( + TPR = TPR, + FPR = FPR, + AUC = AUC + )) }) @@ -58,6 +61,4 @@ surv_model_performance <- function(explainer, ..., times = NULL, type = "metrics attr(rocs_df, "label") <- explainer$label rocs_df } - - } diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index 0bf2b8de..fea98745 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -15,25 +15,23 @@ surv_aggregate_profiles <- function(x, na.omit(as.character(unique(all_profiles$`_vname_`))) if (!is.null(variables)) { all_variables_intersect <- intersect(all_variables, variables) - if (length(all_variables_intersect) == 0) + if (length(all_variables_intersect) == 0) { stop(paste0( "parameter variables do not overlap with ", paste(all_variables, collapse = ", ") )) + } all_variables <- all_variables_intersect } - all_variables <- - intersect(all_variables, unique(all_profiles$`_vname_`)) + all_variables <- intersect(all_variables, unique(all_profiles$`_vname_`)) # select only suitable variables - all_profiles <- - all_profiles[all_profiles$`_vname_` %in% all_variables,] + all_profiles <- all_profiles[all_profiles$`_vname_` %in% all_variables, ] # create _x_ tmp <- as.character(all_profiles$`_vname_`) for (viname in unique(tmp)) { - all_profiles$`_x_`[tmp == viname] <- - all_profiles[tmp == viname, viname] + all_profiles$`_x_`[tmp == viname] <- all_profiles[tmp == viname, viname] } if (!inherits(class(all_profiles), "data.frame")) { @@ -42,53 +40,50 @@ surv_aggregate_profiles <- function(x, # change x column to proper character values for (variable in all_variables) { - if (variable %in% all_profiles[all_profiles$`_vtype_` == "categorical", "_vname_"]) - all_profiles[all_profiles$`_vname_` == variable,]$`_x_` <- - as.character(apply(all_profiles[all_profiles$`_vname_` == variable,], 1, function(all_profiles) - all_profiles[all_profiles["_vname_"]])) + if (variable %in% all_profiles[all_profiles$`_vtype_` == 
"categorical", "_vname_"]) { + all_profiles[all_profiles$`_vname_` == variable, ]$`_x_` <- + as.character(apply(all_profiles[all_profiles$`_vname_` == variable, ], 1, function(all_profiles) { + all_profiles[all_profiles["_vname_"]] + })) + } } - aggregated_profiles <- - surv_aggregate_profiles_partial(all_profiles) - class(aggregated_profiles) <- - c( - "aggregated_survival_profiles_explainer", - "partial_dependence_survival_explainer", - "data.frame" - ) + aggregated_profiles <- surv_aggregate_profiles_partial(all_profiles) + class(aggregated_profiles) <- c( + "aggregated_survival_profiles_explainer", + "partial_dependence_survival_explainer", + "data.frame" + ) return(aggregated_profiles) } surv_aggregate_profiles_partial <- function(all_profiles) { - tmp <- - all_profiles[, c("_vname_", "_vtype_", "_label_", "_x_", "_yhat_", "_times_")] - aggregated_profiles <- - aggregate( - tmp$`_yhat_`, - by = list( - tmp$`_vname_`, - tmp$`_vtype_`, - tmp$`_label_`, - tmp$`_x_`, - tmp$`_times_` - ), - FUN = mean, - na.rm = TRUE - ) - colnames(aggregated_profiles) <- - c("_vname_", "_vtype_", "_label_", "_x_", "_times_", "_yhat_") + tmp <- all_profiles[, c("_vname_", "_vtype_", "_label_", "_x_", "_yhat_", "_times_")] + aggregated_profiles <- aggregate( + tmp$`_yhat_`, + by = list( + tmp$`_vname_`, + tmp$`_vtype_`, + tmp$`_label_`, + tmp$`_x_`, + tmp$`_times_` + ), + FUN = mean, + na.rm = TRUE + ) + colnames(aggregated_profiles) <- c("_vname_", "_vtype_", "_label_", "_x_", "_times_", "_yhat_") aggregated_profiles$`_ids_` <- 0 # for factors, keep proper order # as in https://github.com/ModelOriented/ingredients/issues/82 if (!is.numeric(all_profiles$`_x_`)) { - aggregated_profiles$`_x_` <- - factor(aggregated_profiles$`_x_`, - levels = unique(all_profiles$`_x_`)) - aggregated_profiles <- - aggregated_profiles[order(aggregated_profiles$`_x_`),] + aggregated_profiles$`_x_` <- factor( + aggregated_profiles$`_x_`, + levels = unique(all_profiles$`_x_`) + ) + aggregated_profiles <- 
aggregated_profiles[order(aggregated_profiles$`_x_`), ] } aggregated_profiles @@ -103,12 +98,14 @@ surv_ale <- function(x, categorical_variables, grid_points, center = FALSE) { - if (is.null(variables)) + if (is.null(variables)) { variables <- colnames(data) + } # change categorical_features to column names - if (is.numeric(categorical_variables)) + if (is.numeric(categorical_variables)) { categorical_variables <- colnames(data)[categorical_variables] + } additional_categorical_variables <- categorical_variables factor_variables <- colnames(data)[sapply(data, is.factor)] character_variables <- colnames(data)[sapply(data, is.character)] @@ -128,7 +125,7 @@ surv_ale <- function(x, X_lower <- X_upper <- data variable_values <- data[, variable] if (variable %in% categorical_variables) { - if (!is.factor(variable_values)){ + if (!is.factor(variable_values)) { is_numeric <- is.numeric(variable_values) is_factorized <- TRUE variable_values <- as.factor(variable_values) @@ -148,107 +145,104 @@ surv_ale <- function(x, levels_ordered <- levels_original[level_order] # The feature with the levels in the new order - x_ordered <- - order(level_order)[as.numeric(droplevels(variable_values))] + x_ordered <- order(level_order)[as.numeric(droplevels(variable_values))] # Filter rows which are not already at maximum or minimum level values row_ind_increase <- (1:nrow(data))[x_ordered < levels_n] row_ind_decrease <- (1:nrow(data))[x_ordered > 1] - if (is_factorized){ + if (is_factorized) { levels_ordered <- as.character(levels_ordered) - if (is_numeric){ + if (is_numeric) { levels_ordered <- as.numeric(levels_ordered) } } - X_lower[row_ind_decrease, variable] <- - levels_ordered[x_ordered[row_ind_decrease] - 1] - X_upper[row_ind_increase, variable] <- - levels_ordered[x_ordered[row_ind_increase] + 1] + X_lower[row_ind_decrease, variable] <- levels_ordered[x_ordered[row_ind_decrease] - 1] + X_upper[row_ind_increase, variable] <- levels_ordered[x_ordered[row_ind_increase] + 1] # Make 
predictions for decreased levels (excluding minimum levels) - predictions_lower <- - predict_survival_function(model = model, - newdata = X_lower[row_ind_decrease,], - times = times) + predictions_lower <- predict_survival_function( + model = model, + newdata = X_lower[row_ind_decrease, ], + times = times + ) # Make predictions for increased levels (excluding maximum levels) - predictions_upper <- - predict_survival_function(model = model, - newdata = X_upper[row_ind_increase,], - times = times) - - d_increase <- - predictions_upper - predictions_original[row_ind_increase,] - d_decrease <- - predictions_original[row_ind_decrease,] - predictions_lower + predictions_upper <- predict_survival_function( + model = model, + newdata = X_upper[row_ind_increase, ], + times = times + ) + + d_increase <- predictions_upper - predictions_original[row_ind_increase, ] + d_decrease <- predictions_original[row_ind_decrease, ] - predictions_lower prediction_deltas <- rbind(d_increase, d_decrease) colnames(prediction_deltas) <- times deltas <- data.frame( - interval = rep(c(x_ordered[row_ind_increase], - x_ordered[row_ind_decrease] - 1), - each = length(times)), + interval = rep(c( + x_ordered[row_ind_increase], + x_ordered[row_ind_decrease] - 1 + ), + each = length(times) + ), time = rep(times, times = nrow(prediction_deltas)), yhat = c(t(prediction_deltas)) ) - deltas <- - aggregate(yhat ~ interval + time, - data = deltas, - FUN = mean) - deltas1 <- deltas[deltas$interval == 1,] + deltas <- aggregate( + yhat ~ interval + time, + data = deltas, + FUN = mean + ) + deltas1 <- deltas[deltas$interval == 1, ] deltas1$yhat <- 0 deltas$interval <- deltas$interval + 1 deltas <- rbind(deltas, deltas1) - deltas <- deltas[order(deltas$time, deltas$interval),] + deltas <- deltas[order(deltas$time, deltas$interval), ] rownames(deltas) <- NULL - deltas$yhat_cumsum <- - ave(deltas$yhat, deltas$time, FUN = cumsum) + deltas$yhat_cumsum <- ave(deltas$yhat, deltas$time, FUN = cumsum) x_count <- 
as.numeric(table(variable_values)) x_prob <- x_count / sum(x_count) - ale_means <- - aggregate( - yhat_cumsum ~ time, - data = deltas, - FUN = function(x) { - sum(x * x_prob[level_order]) - } - ) + ale_means <- aggregate( + yhat_cumsum ~ time, + data = deltas, + FUN = function(x) { + sum(x * x_prob[level_order]) + } + ) colnames(ale_means)[2] <- "ale0" - ale_values <- merge(deltas, - ale_means, - all.x = TRUE, - by = "time") + ale_values <- merge( + deltas, + ale_means, + all.x = TRUE, + by = "time" + ) - ale_values$ale <- - ale_values$yhat_cumsum - ale_values$ale0 + ale_values$ale <- ale_values$yhat_cumsum - ale_values$ale0 ale_values$level <- levels_ordered[ale_values$interval] - ale_values <- - ale_values[order(ale_values$interval, ale_values$time),] - if (!center){ + ale_values <- ale_values[order(ale_values$interval, ale_values$time), ] + + if (!center) { ale_values$ale <- ale_values$ale + mean_pred } - return( - data.frame( - `_vname_` = variable, - `_vtype_` = "categorical", - `_label_` = label, - `_x_` = ale_values$level, - `_times_` = ale_values$time, - `_yhat_` = ale_values$ale, - `_ids_` = 0, - check.names = FALSE - ) - ) - + return(data.frame( + `_vname_` = variable, + `_vtype_` = "categorical", + `_label_` = label, + `_x_` = ale_values$level, + `_times_` = ale_values$time, + `_yhat_` = ale_values$ale, + `_ids_` = 0, + check.names = FALSE + )) } else { # Number of quantile points for determined by grid length quantile_vals <- as.numeric(quantile( @@ -260,13 +254,13 @@ surv_ale <- function(x, # Quantile points vector quantile_vec <- c(min(variable_values), quantile_vals) quantile_vec <- unique(quantile_vec) - quantile_df <- - data.frame(id = 1:length(quantile_vec), - value = quantile_vec) + quantile_df <- data.frame( + id = 1:length(quantile_vec), + value = quantile_vec + ) # Match feature instances to quantile intervals - interval_index <- - findInterval(variable_values, quantile_vec, left.open = TRUE) + interval_index <- findInterval(variable_values, 
quantile_vec, left.open = TRUE) # Points in interval 0 should be in interval 1 interval_index[interval_index == 0] <- 1 @@ -275,19 +269,20 @@ surv_ale <- function(x, X_lower[, variable] <- quantile_vec[interval_index] X_upper[, variable] <- quantile_vec[interval_index + 1] # Get survival predictions for instances of upper and lower interval limits - predictions_lower <- - predict_survival_function(model = model, - newdata = X_lower, - times = times) + predictions_lower <- predict_survival_function( + model = model, + newdata = X_lower, + times = times + ) - predictions_upper <- - predict_survival_function(model = model, - newdata = X_upper, - times = times) + predictions_upper <- predict_survival_function( + model = model, + newdata = X_upper, + times = times + ) # First order finite differences - prediction_deltas <- - predictions_upper - predictions_lower + prediction_deltas <- predictions_upper - predictions_lower # Rename columns to timepoints for which predictions were made colnames(prediction_deltas) <- times @@ -298,49 +293,48 @@ surv_ale <- function(x, yhat = c(t(prediction_deltas)) ) - deltas <- - aggregate(yhat ~ interval + time, - data = deltas, - FUN = mean) - deltas$yhat_cumsum <- - ave(deltas$yhat, deltas$time, FUN = cumsum) + deltas <- aggregate( + yhat ~ interval + time, + data = deltas, + FUN = mean + ) + deltas$yhat_cumsum <- ave(deltas$yhat, deltas$time, FUN = cumsum) interval_n <- as.numeric(table(interval_index)) n <- sum(interval_n) - ale_means <- - aggregate( - yhat_cumsum ~ time, - data = deltas, - FUN = function(x) { - sum(((c( - 0, x[1:(length(x) - 1)] - ) + x) / 2) * interval_n / n) - } - ) + ale_means <- aggregate( + yhat_cumsum ~ time, + data = deltas, + FUN = function(x) { + sum(((c( + 0, x[1:(length(x) - 1)] + ) + x) / 2) * interval_n / n) + } + ) colnames(ale_means)[2] <- "ale0" # Centering the ALEs to obtain final ALE values ale_values <- merge(deltas, - ale_means, - all.x = TRUE, - by = "time") + ale_means, + all.x = TRUE, + by = 
"time" + ) - ale_values$ale <- - ale_values$yhat_cumsum - ale_values$ale0 + ale_values$ale <- ale_values$yhat_cumsum - ale_values$ale0 ale_values$interval <- ale_values$interval + 1 - ale_values1 <- - ale_values[seq(1, nrow(ale_values), length(quantile_vec) - 1), ] + ale_values1 <- ale_values[seq(1, nrow(ale_values), length(quantile_vec) - 1), ] ale_values1$interval <- 1 ale_values <- rbind(ale_values, ale_values1) - ale_values <- merge(ale_values, - quantile_df, - by.x = "interval", - by.y = "id") - ale_values <- - ale_values[order(ale_values$interval, ale_values$time),] + ale_values <- merge( + ale_values, + quantile_df, + by.x = "interval", + by.y = "id" + ) + ale_values <- ale_values[order(ale_values$interval, ale_values$time), ] - if (!center){ + if (!center) { ale_values$ale <- ale_values$ale + mean_pred } @@ -357,7 +351,6 @@ surv_ale <- function(x, ) ) } - }) profiles <- do.call(rbind, profiles) diff --git a/R/surv_shap.R b/R/surv_shap.R index 54f24a02..7970e751 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -19,9 +19,7 @@ surv_shap <- function(explainer, ..., y_true = NULL, calculation_method = "kernelshap", - aggregation_method = "integral") -{ - # make this code work for multiple observations + aggregation_method = "integral") { stopifnot( "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( !is.null(y_true), @@ -31,7 +29,8 @@ surv_shap <- function(explainer, is.null(dim(y_true)) && length(y_true) == 2L ), TRUE - )) + ) + ) test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) @@ -43,7 +42,9 @@ surv_shap <- function(explainer, new_observation <- new_observation[, col_index] } - if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") + if (ncol(explainer$data) != ncol(new_observation)) { + stop("New observation and data have different number of columns (variables)") + } if 
(!is.null(y_true)) { if (is.matrix(y_true)) { # above, we have already checked that nrows of observations are @@ -62,15 +63,16 @@ surv_shap <- function(explainer, # to display final object correctly, when is.matrix(new_observation) == TRUE res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, - "exact_kernel" = use_exact_shap(explainer, new_observation, ...), - "kernelshap" = use_kernelshap(explainer, new_observation, ...), - stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")) + "exact_kernel" = use_exact_shap(explainer, new_observation, ...), + "kernelshap" = use_kernelshap(explainer, new_observation, ...), + stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented") + ) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) res$aggregate <- lapply(res$result, aggregate_surv_shap, method = aggregation_method, times = res$eval_times) - if(nrow(new_observation) > 1){ + if (nrow(new_observation) > 1) { class(res) <- "aggregated_surv_shap" # res$aggregation_method <- aggregation_method res$n_observations <- nrow(new_observation) @@ -83,25 +85,21 @@ surv_shap <- function(explainer, return(res) } -use_exact_shap <- function(explainer, new_observation, observation_aggregation_method, ...){ - +use_exact_shap <- function(explainer, new_observation, observation_aggregation_method, ...) { shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { - as.data.frame(shap_kernel(explainer, new_observation[as.integer(i),], ...)) + as.data.frame(shap_kernel(explainer, new_observation[as.integer(i), ], ...)) }, USE.NAMES = TRUE, simplify = FALSE ) return(shap_values) - } shap_kernel <- function(explainer, new_observation, ...) { - - timestamps <- explainer$times p <- ncol(explainer$data) @@ -113,32 +111,39 @@ shap_kernel <- function(explainer, new_observation, ...) 
{ permutations <- expand.grid(rep(list(0:1), p)) kernel_weights <- generate_shap_kernel_weights(permutations, p) - shap_values <- calculate_shap_values(explainer, explainer$model, baseline_sf, as.data.frame(explainer$data), permutations, kernel_weights, as.data.frame(new_observation), timestamps) + shap_values <- calculate_shap_values( + explainer, + explainer$model, + baseline_sf, + as.data.frame(explainer$data), + permutations, kernel_weights, + as.data.frame(new_observation), + timestamps + ) shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data)) colnames(shap_values) <- paste("t=", timestamps, sep = "") - return (t(shap_values)) + return(t(shap_values)) } generate_shap_kernel_weights <- function(permutations, p) { - apply(permutations, 1, function(row) { row <- as.numeric(row) - num_available_variables = sum(row != 0) + num_available_variables <- sum(row != 0) - if (num_available_variables == 0 || num_available_variables == p) 1e12 - else { + if (num_available_variables == 0 || num_available_variables == p) { + 1e12 + } else { (p - 1) / (choose(p, num_available_variables) * num_available_variables * (p - num_available_variables)) } - }) + }) } calculate_shap_values <- function(explainer, model, avg_survival_function, data, simplified_inputs, shap_kernel_weights, new_observation, timestamps) { - w <- shap_kernel_weights X <- as.matrix(simplified_inputs) @@ -146,16 +151,16 @@ calculate_shap_values <- function(explainer, model, avg_survival_function, data, y <- make_prediction_for_simplified_input(explainer, model, data, simplified_inputs, new_observation, timestamps) - y <- sweep(y, - 2, - avg_survival_function) + y <- sweep( + y, + 2, + avg_survival_function + ) R %*% y - } make_prediction_for_simplified_input <- function(explainer, model, data, simplified_inputs, new_observation, timestamps) { - preds <- apply(simplified_inputs, 1, function(row) { row <- as.logical(row) @@ -164,30 +169,25 @@ make_prediction_for_simplified_input <- 
function(explainer, model, data, simplif colnames(X_tmp) <- colnames(data) colMeans(explainer$predict_survival_function(model, X_tmp, timestamps)) - }) return(t(preds)) - - } aggregate_surv_shap <- function(survshap, times, method, ...) { - switch( - method, + switch(method, "sum_of_squares" = return(apply(survshap, 2, function(x) sum(x^2))), "mean_absolute" = return(apply(survshap, 2, function(x) mean(abs(x)))), "max_absolute" = return(apply(survshap, 2, function(x) max(abs(x)))), "integral" = return(apply(survshap, 2, function(x) calculate_integral(x, times, normalization = "t_max"))), "integral_absolute" = return(apply(survshap, 2, function(x) calculate_integral(abs(x), times, normalization = "t_max"))), stop("aggregation_method has to be one of 'integral', 'integral_absolute', 'mean_absolute', 'max_absolute', or 'sum_of_squares'") - ) + ) } -use_kernelshap <- function(explainer, new_observation, observation_aggregation_method, ...){ - - predfun <- function(model, newdata){ +use_kernelshap <- function(explainer, new_observation, observation_aggregation_method, ...) 
{ + predfun <- function(model, newdata) { explainer$predict_survival_function( model, newdata, @@ -215,12 +215,10 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m ) return(shap_values) - } -#'@keywords internal +#' @keywords internal aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { - if (length(shap_res_list) > 1) { shap_res_list <- lapply(shap_res_list, function(x) { x$rn <- rownames(x) @@ -234,12 +232,13 @@ aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, a # multiple observations tmp_res <- aggregate(full_survshap_results[, !colnames(full_survshap_results) %in% c("rn")], - by = list(full_survshap_results$rn), - FUN = aggregation_function) + by = list(full_survshap_results$rn), + FUN = aggregation_function + ) rownames(tmp_res) <- tmp_res$Group.1 - ordering <- order(as.numeric(substring(rownames(tmp_res),3))) + ordering <- order(as.numeric(substring(rownames(tmp_res), 3))) - tmp_res <- tmp_res[ordering, !colnames(tmp_res) %in% c("rn","Group.1")] + tmp_res <- tmp_res[ordering, !colnames(tmp_res) %in% c("rn", "Group.1")] } else { # no aggregation required tmp_res <- shap_res_list[[1]] diff --git a/R/utils.R b/R/utils.R index 146290b4..a195a0c0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -18,7 +18,7 @@ #' #' @export cumulative_hazard_to_survival <- function(hazard_functions) { - return(exp(-hazard_functions)) + return(exp(-hazard_functions)) } #' Transform Survival to Cumulative Hazard @@ -41,7 +41,7 @@ cumulative_hazard_to_survival <- function(hazard_functions) { #' #' @export survival_to_cumulative_hazard <- function(survival_functions, epsilon = 0) { - return(-log(survival_functions)) + return(-log(survival_functions)) } # tests if the explainer has all the required fields @@ -51,29 +51,35 @@ test_explainer <- function(explainer, has_y = FALSE, has_survival = FALSE, has_chf = FALSE, - has_predict = FALSE) -{ - if (!("surv_explainer" %in% 
class(explainer))) - stop(paste0("The ", function_name, " function requires an object created with survex::explain() function.")) - if (has_data && is.null(explainer$data)) - stop(paste0("The ", function_name, " function requires explainers with specified `data` parameter")) - if (has_y && is.null(explainer$y)) + has_predict = FALSE) { + if (!("surv_explainer" %in% class(explainer))) { + stop(paste0("The ", function_name, " function requires an object created with survex::explain() function.")) + } + if (has_data && is.null(explainer$data)) { + stop(paste0("The ", function_name, " function requires explainers with specified `data` parameter")) + } + if (has_y && is.null(explainer$y)) { stop(paste0("The ", function_name, " function requires explainers with specified `y` parameter")) - if (has_survival && is.null(explainer$predict_survival_function)) + } + if (has_survival && is.null(explainer$predict_survival_function)) { stop(paste0("The ", function_name, " function requires explainers with specified `predict_survival_function` parameter")) - if (has_chf && is.null(explainer$predict_cumulative_hazard_function)) + } + if (has_chf && is.null(explainer$predict_cumulative_hazard_function)) { stop(paste0("The ", function_name, " function requires explainers with specified `predict_cumulative_hazard_function` parameter")) - if (has_predict && is.null(explainer$predict_function)) + } + if (has_predict && is.null(explainer$predict_function)) { stop(paste0("The ", function_name, " function requires explainers with specified `predict_risk` parameter")) + } } #' @importFrom DALEX colors_discrete_drwhy generate_discrete_color_scale <- function(n, colors = NULL) { - - if (is.null(colors) || length(colors) < n) return(colors_discrete_drwhy(n)) - else return(colors[(0:(n - 1) %% length(colors)) + 1]) - + if (is.null(colors) || length(colors) < n) { + return(colors_discrete_drwhy(n)) + } else { + return(colors[(0:(n - 1) %% length(colors)) + 1]) + } } #' Transform Fixed Point 
Prediction into a Stepfunction @@ -96,46 +102,45 @@ generate_discrete_color_scale <- function(n, colors = NULL) { #' rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) #' #' chf_function <- transform_to_stepfunction(predict, -#' type = "chf", -#' prediction_element = "chf", -#' times_element = "time.interest") +#' type = "chf", +#' prediction_element = "chf", +#' times_element = "time.interest" +#' ) #' #' explainer <- explain(rsf_src, predict_cumulative_hazard_function = chf_function) #' #' @export transform_to_stepfunction <- function(predict_function, eval_times = NULL, ..., type = NULL, prediction_element = NULL, times_element = NULL) { - - function(model, newdata, times) { - raw_prediction <- predict_function(model, newdata, ...) - if (!is.null(times_element)) eval_times <- raw_prediction[[times_element]] - if (!is.null(prediction_element)) prediction <- raw_prediction[[prediction_element]] - n_rows <- ifelse(is.null(dim(prediction)), 1, nrow(prediction)) - return_matrix <- matrix(nrow = n_rows, ncol = length(times)) + function(model, newdata, times) { + raw_prediction <- predict_function(model, newdata, ...) 
+ if (!is.null(times_element)) eval_times <- raw_prediction[[times_element]] + if (!is.null(prediction_element)) prediction <- raw_prediction[[prediction_element]] + n_rows <- ifelse(is.null(dim(prediction)), 1, nrow(prediction)) + return_matrix <- matrix(nrow = n_rows, ncol = length(times)) - if (is.null(dim(prediction))) { + if (is.null(dim(prediction))) { + padding <- switch(type, + "survival" = 1, + "chf" = 0, + prediction[1] + ) + stepfunction <- stepfun(eval_times, c(padding, prediction)) + return_matrix[1, ] <- stepfunction(times) + } else { + for (i in 1:n_rows) { padding <- switch(type, - "survival" = 1, - "chf" = 0, - prediction[1]) - stepfunction <- stepfun(eval_times, c(padding, prediction)) - return_matrix[1, ] <- stepfunction(times) - - } - else { - for (i in 1:n_rows) { - padding <- switch(type, - "survival" = 1, - "chf" = 0, - prediction[i, 1]) - stepfunction <- stepfun(eval_times, c(padding, prediction[i, ])) - return_matrix[i, ] <- stepfunction(times) - } + "survival" = 1, + "chf" = 0, + prediction[i, 1] + ) + stepfunction <- stepfun(eval_times, c(padding, prediction[i, ])) + return_matrix[i, ] <- stepfunction(times) } - - return_matrix } + return_matrix + } } #' Generate Risk Prediction based on the Survival Function @@ -154,14 +159,16 @@ transform_to_stepfunction <- function(predict_function, eval_times = NULL, ..., #' rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) #' #' chf_function <- transform_to_stepfunction(predict, -#' type = "chf", -#' prediction_element = "chf", -#' times_element = "time.interest") +#' type = "chf", +#' prediction_element = "chf", +#' times_element = "time.interest" +#' ) #' risk_function <- risk_from_chf(chf_function, unique(veteran$time)) #' #' explainer <- explain(rsf_src, -#' predict_cumulative_hazard_function = chf_function, -#' predict_function = risk_function) +#' predict_cumulative_hazard_function = chf_function, +#' predict_function = risk_function +#' ) #' #' @export risk_from_chf <- 
function(predict_cumulative_hazard_function, times) { @@ -179,44 +186,51 @@ risk_from_chf <- function(predict_cumulative_hazard_function, times) { #' @return An object of classes `c("predict_parts_survival", "surv_shap")`. It is a list with the element `result` containing the results of the explanation. #' #' @examples +#' \donttest{ #' veteran <- survival::veteran #' rsf_ranger <- ranger::ranger( -#' survival::Surv(time, status) ~ ., -#' data = veteran, -#' respect.unordered.factors = TRUE, -#' num.trees = 100, -#' mtry = 3, -#' max.depth = 5 +#' survival::Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 #' ) #' rsf_ranger_exp <- explain( -#' rsf_ranger, -#' data = veteran[, -c(3, 4)], -#' y = survival::Surv(veteran$time, veteran$status), -#' verbose = FALSE +#' rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = survival::Surv(veteran$time, veteran$status), +#' verbose = FALSE #' ) #' #' ranger_global_survshap <- model_survshap( -#' explainer = rsf_ranger_exp, -#' new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), -#' !colnames(veteran) %in% c("time", "status")]) +#' explainer = rsf_ranger_exp, +#' new_observation = veteran[ +#' c(1:4, 17:20, 110:113, 126:129), +#' !colnames(veteran) %in% c("time", "status") +#' ] +#' ) #' #' local_survshap_1 <- extract_predict_survshap(ranger_global_survshap, index = 1) #' plot(local_survshap_1) +#' } #' #' @export -extract_predict_survshap <- function(aggregated_survshap, index){ - if (!inherits(aggregated_survshap, "aggregated_surv_shap")) +extract_predict_survshap <- function(aggregated_survshap, index) { + if (!inherits(aggregated_survshap, "aggregated_surv_shap")) { stop("`aggregated_survshap` object must be of class 'aggregated_surv_shap'") + } - if (index > aggregated_survshap$n_observations) + if (index > aggregated_survshap$n_observations) { stop(paste("Incorrect `index`, number of observations in `aggregated_survshap` is", 
aggregated_survshap$n_observations)) + } res <- list() res$eval_times <- aggregated_survshap$eval_times res$event_times <- aggregated_survshap$event_times res$event_statuses <- aggregated_survshap$event_statuses - res$variable_values <- aggregated_survshap$variable_values[index,] + res$variable_values <- aggregated_survshap$variable_values[index, ] res$result <- aggregated_survshap$result[[index]] res$aggregate <- aggregated_survshap$aggregate[[index]] class(res) <- c("predict_parts_survival", "surv_shap") @@ -227,17 +241,23 @@ extract_predict_survshap <- function(aggregated_survshap, index){ #' @keywords internal -add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ - if (rug == "all"){ - return_plot <- with(rug_df, { base_plot + - geom_rug(data = rug_df[rug_df$statuses == 1,], mapping = aes(x=times, color = statuses), inherit.aes=F, color = rug_colors[1]) + - geom_rug(data = rug_df[rug_df$statuses == 0,], mapping = aes(x=times, color = statuses), inherit.aes=F, color = rug_colors[2]) }) +add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors) { + if (rug == "all") { + return_plot <- with(rug_df, { + base_plot + + geom_rug(data = rug_df[rug_df$statuses == 1, ], mapping = aes(x = times, color = statuses), inherit.aes = F, color = rug_colors[1]) + + geom_rug(data = rug_df[rug_df$statuses == 0, ], mapping = aes(x = times, color = statuses), inherit.aes = F, color = rug_colors[2]) + }) } else if (rug == "events") { - return_plot <- with(rug_df, { base_plot + - geom_rug(data = rug_df[rug_df$statuses == 1,], mapping = aes(x=times, color = statuses), inherit.aes=F, color = rug_colors[1]) }) + return_plot <- with(rug_df, { + base_plot + + geom_rug(data = rug_df[rug_df$statuses == 1, ], mapping = aes(x = times, color = statuses), inherit.aes = F, color = rug_colors[1]) + }) } else if (rug == "censors") { - return_plot <- with(rug_df, { base_plot + - geom_rug(data = rug_df[rug_df$statuses == 0,], mapping = aes(x=times, color = statuses), inherit.aes=F, 
color = rug_colors[2]) }) + return_plot <- with(rug_df, { + base_plot + + geom_rug(data = rug_df[rug_df$statuses == 0, ], mapping = aes(x = times, color = statuses), inherit.aes = F, color = rug_colors[2]) + }) } else { return_plot <- base_plot } @@ -245,19 +265,18 @@ add_rug_to_plot <- function(base_plot, rug_df, rug, rug_colors){ #' @keywords internal -calculate_integral <- function(values, times, normalization = "t_max", ...){ +calculate_integral <- function(values, times, normalization = "t_max", ...) { n <- length(values) - if (is.null(normalization)){ + if (is.null(normalization)) { tmp <- (values[1:(n - 1)] + values[2:n]) * diff(times) / 2 integrated_metric <- sum(tmp) / (max(times) - min(times)) return(integrated_metric) - } - else if (normalization == "t_max") { + } else if (normalization == "t_max") { tmp <- (values[1:(n - 1)] + values[2:n]) * diff(times) / 2 integrated_metric <- sum(tmp) - return(integrated_metric/max(times)) - } else if (normalization == "survival"){ + return(integrated_metric / max(times)) + } else if (normalization == "survival") { y_true <- list(...)$y_true km <- survival::survfit(y_true ~ 1) estimator <- stepfun(km$time, c(1, km$surv)) @@ -266,7 +285,7 @@ calculate_integral <- function(values, times, normalization = "t_max", ...){ tmp <- (values[1:(n - 1)] + values[2:n]) * diff(dwt) / 2 integrated_metric <- sum(tmp) - return(integrated_metric/(1 - estimator(max(times)))) + return(integrated_metric / (1 - estimator(max(times)))) } } @@ -306,4 +325,3 @@ order_levels <- function(data, variable_values, variable_name) { scaled <- cmdscale(dists.cumulated, k = 1) order(scaled) } - diff --git a/man/c_index.Rd b/man/c_index.Rd index 83d0d02b..3c346222 100644 --- a/man/c_index.Rd +++ b/man/c_index.Rd @@ -37,7 +37,8 @@ rotterdam <- survival::rotterdam rotterdam$year <- NULL cox_rotterdam_rec <- coxph(Surv(rtime, recur) ~ ., data = rotterdam, - model = TRUE, x = TRUE, y = TRUE) + model = TRUE, x = TRUE, y = TRUE +) coxph_explainer <- 
explain(cox_rotterdam_rec) risk <- coxph_explainer$predict_function(coxph_explainer$model, coxph_explainer$data) diff --git a/man/explain_survival.Rd b/man/explain_survival.Rd index 3ec1fcc5..afdd90e1 100644 --- a/man/explain_survival.Rd +++ b/man/explain_survival.Rd @@ -125,14 +125,20 @@ with \code{pec::predictSurvProb()} method. library(survival) library(survex) -cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, - model = TRUE, x = TRUE) +cph <- survival::coxph(survival::Surv(time, status) ~ ., + data = veteran, + model = TRUE, x = TRUE +) cph_exp <- explain(cph) -rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, - respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) -rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], - y = Surv(veteran$time, veteran$status)) +rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5 +) +rsf_ranger_exp <- explain(rsf_ranger, + data = veteran[, -c(3, 4)], + y = Surv(veteran$time, veteran$status) +) rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) rsf_src_exp <- explain(rsf_src) @@ -140,28 +146,30 @@ rsf_src_exp <- explain(rsf_src) library(censored, quietly = TRUE) bt <- parsnip::boost_tree() \%>\% - parsnip::set_engine("mboost") \%>\% - parsnip::set_mode("censored regression") \%>\% - generics::fit(survival::Surv(time, status) ~ ., data = veteran) + parsnip::set_engine("mboost") \%>\% + parsnip::set_mode("censored regression") \%>\% + generics::fit(survival::Surv(time, status) ~ ., data = veteran) bt_exp <- explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status)) ###### explain_survival() ###### -cph <- coxph(Surv(time, status) ~ ., data=veteran) +cph <- coxph(Surv(time, status) ~ ., data = veteran) -veteran_data <- veteran[, -c(3,4)] +veteran_data <- veteran[, -c(3, 4)] veteran_y <- 
Surv(veteran$time, veteran$status) risk_pred <- function(model, newdata) predict(model, newdata, type = "risk") surv_pred <- function(model, newdata, times) pec::predictSurvProb(model, newdata, times) chf_pred <- function(model, newdata, times) -log(surv_pred(model, newdata, times)) -manual_cph_explainer <- explain_survival(model = cph, - data = veteran_data, - y = veteran_y, - predict_function = risk_pred, - predict_survival_function = surv_pred, - predict_cumulative_hazard_function = chf_pred, - label = "manual coxph") +manual_cph_explainer <- explain_survival( + model = cph, + data = veteran_data, + y = veteran_y, + predict_function = risk_pred, + predict_survival_function = surv_pred, + predict_cumulative_hazard_function = chf_pred, + label = "manual coxph" +) } } diff --git a/man/extract_predict_survshap.Rd b/man/extract_predict_survshap.Rd index 855cbdc1..7a1c988a 100644 --- a/man/extract_predict_survshap.Rd +++ b/man/extract_predict_survshap.Rd @@ -19,28 +19,33 @@ Helper function to extract local SurvSHAP(t) explanation from global one. Can be can be useful for creating SurvSHAP(t) plots for single observations. 
} \examples{ +\donttest{ veteran <- survival::veteran rsf_ranger <- ranger::ranger( - survival::Surv(time, status) ~ ., - data = veteran, - respect.unordered.factors = TRUE, - num.trees = 100, - mtry = 3, - max.depth = 5 + survival::Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 ) rsf_ranger_exp <- explain( - rsf_ranger, - data = veteran[, -c(3, 4)], - y = survival::Surv(veteran$time, veteran$status), - verbose = FALSE + rsf_ranger, + data = veteran[, -c(3, 4)], + y = survival::Surv(veteran$time, veteran$status), + verbose = FALSE ) ranger_global_survshap <- model_survshap( - explainer = rsf_ranger_exp, - new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), - !colnames(veteran) \%in\% c("time", "status")]) + explainer = rsf_ranger_exp, + new_observation = veteran[ + c(1:4, 17:20, 110:113, 126:129), + !colnames(veteran) \%in\% c("time", "status") + ] +) local_survshap_1 <- extract_predict_survshap(ranger_global_survshap, index = 1) plot(local_survshap_1) +} } diff --git a/man/loss_one_minus_c_index.Rd b/man/loss_one_minus_c_index.Rd index ce0a587d..87b04457 100644 --- a/man/loss_one_minus_c_index.Rd +++ b/man/loss_one_minus_c_index.Rd @@ -37,7 +37,8 @@ rotterdam <- survival::rotterdam rotterdam$year <- NULL cox_rotterdam_rec <- coxph(Surv(rtime, recur) ~ ., data = rotterdam, - model = TRUE, x = TRUE, y = TRUE) + model = TRUE, x = TRUE, y = TRUE +) coxph_explainer <- explain(cox_rotterdam_rec) risk <- coxph_explainer$predict_function(coxph_explainer$model, coxph_explainer$data) diff --git a/man/model_performance.surv_explainer.Rd b/man/model_performance.surv_explainer.Rd index 91313d4d..3b885366 100644 --- a/man/model_performance.surv_explainer.Rd +++ b/man/model_performance.surv_explainer.Rd @@ -59,18 +59,22 @@ library(survex) cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., - data = veteran, - 
respect.unordered.factors = TRUE, - num.trees = 100, - mtry = 3, - max.depth = 5) + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 +) rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., - data = veteran) + data = veteran +) cph_exp <- explain(cph) -rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], - y = Surv(veteran$time, veteran$status)) +rsf_ranger_exp <- explain(rsf_ranger, + data = veteran[, -c(3, 4)], + y = Surv(veteran$time, veteran$status) +) rsf_src_exp <- explain(rsf_src) @@ -81,7 +85,9 @@ rsf_src_model_performance <- model_performance(rsf_src_exp) print(cph_model_performance) plot(rsf_ranger_model_performance, cph_model_performance, - rsf_src_model_performance, metrics_type = "scalar") + rsf_src_model_performance, + metrics_type = "scalar" +) plot(rsf_ranger_model_performance, cph_model_performance, rsf_src_model_performance) diff --git a/man/model_profile.surv_explainer.Rd b/man/model_profile.surv_explainer.Rd index 6d56bdef..bd2bea6c 100644 --- a/man/model_profile.surv_explainer.Rd +++ b/man/model_profile.surv_explainer.Rd @@ -75,16 +75,20 @@ rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) cph_exp <- explain(cph) rsf_src_exp <- explain(rsf_src) -cph_model_profile <- model_profile(cph_exp, output_type = "survival", - variables = c("age")) +cph_model_profile <- model_profile(cph_exp, + output_type = "survival", + variables = c("age") +) head(cph_model_profile$result) plot(cph_model_profile) -rsf_model_profile <- model_profile(rsf_src_exp, output_type = "survival", - variables = c("age", "celltype"), - type = "accumulated") +rsf_model_profile <- model_profile(rsf_src_exp, + output_type = "survival", + variables = c("age", "celltype"), + type = "accumulated" +) head(rsf_model_profile$result) diff --git a/man/model_profile_2d.surv_explainer.Rd b/man/model_profile_2d.surv_explainer.Rd index 56cf0e44..2204f706 100644 --- a/man/model_profile_2d.surv_explainer.Rd 
+++ b/man/model_profile_2d.surv_explainer.Rd @@ -64,13 +64,15 @@ cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = cph_exp <- explain(cph) cph_model_profile_2d <- model_profile_2d(cph_exp, - variables = list(c("age", "celltype"))) + variables = list(c("age", "celltype")) +) head(cph_model_profile_2d$result) plot(cph_model_profile_2d) cph_model_profile_2d_ale <- model_profile_2d(cph_exp, - variables = list(c("age", "karno")), - type = "accumulated") + variables = list(c("age", "karno")), + type = "accumulated" +) head(cph_model_profile_2d_ale$result) plot(cph_model_profile_2d_ale) } diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 491fe3e7..9a51ced0 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -44,28 +44,32 @@ global SHAP values are computed for the data, the \code{explainer} was trained w \donttest{ veteran <- survival::veteran rsf_ranger <- ranger::ranger( - survival::Surv(time, status) ~ ., - data = veteran, - respect.unordered.factors = TRUE, - num.trees = 100, - mtry = 3, - max.depth = 5 + survival::Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 ) rsf_ranger_exp <- explain( - rsf_ranger, - data = veteran[, -c(3, 4)], - y = survival::Surv(veteran$time, veteran$status), - verbose = FALSE + rsf_ranger, + data = veteran[, -c(3, 4)], + y = survival::Surv(veteran$time, veteran$status), + verbose = FALSE ) ranger_global_survshap <- model_survshap( - explainer = rsf_ranger_exp, - new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), - !colnames(veteran) \%in\% c("time", "status")], - y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], - veteran$status[c(1:4, 17:20, 110:113, 126:129)]), - aggregation_method = "integral", - calculation_method = "kernelshap", + explainer = rsf_ranger_exp, + new_observation = veteran[ + c(1:4, 17:20, 110:113, 
126:129), + !colnames(veteran) \%in\% c("time", "status") + ], + y_true = survival::Surv( + veteran$time[c(1:4, 17:20, 110:113, 126:129)], + veteran$status[c(1:4, 17:20, 110:113, 126:129)] + ), + aggregation_method = "integral", + calculation_method = "kernelshap", ) plot(ranger_global_survshap) plot(ranger_global_survshap, geom = "beeswarm") diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd index b1e0bbb1..3e9a836e 100644 --- a/man/plot.aggregated_surv_shap.Rd +++ b/man/plot.aggregated_surv_shap.Rd @@ -64,28 +64,32 @@ explanations of survival models created using the \code{model_survshap()} functi \donttest{ veteran <- survival::veteran rsf_ranger <- ranger::ranger( - survival::Surv(time, status) ~ ., - data = veteran, - respect.unordered.factors = TRUE, - num.trees = 100, - mtry = 3, - max.depth = 5 + survival::Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 ) rsf_ranger_exp <- explain( - rsf_ranger, - data = veteran[, -c(3, 4)], - y = survival::Surv(veteran$time, veteran$status), - verbose = FALSE + rsf_ranger, + data = veteran[, -c(3, 4)], + y = survival::Surv(veteran$time, veteran$status), + verbose = FALSE ) ranger_global_survshap <- model_survshap( - explainer = rsf_ranger_exp, - new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), - !colnames(veteran) \%in\% c("time", "status")], - y_true = survival::Surv(veteran$time[c(1:4, 17:20, 110:113, 126:129)], - veteran$status[c(1:4, 17:20, 110:113, 126:129)]), - aggregation_method = "integral", - calculation_method = "kernelshap", + explainer = rsf_ranger_exp, + new_observation = veteran[ + c(1:4, 17:20, 110:113, 126:129), + !colnames(veteran) \%in\% c("time", "status") + ], + y_true = survival::Surv( + veteran$time[c(1:4, 17:20, 110:113, 126:129)], + veteran$status[c(1:4, 17:20, 110:113, 126:129)] + ), + aggregation_method = "integral", + calculation_method = "kernelshap", ) 
plot(ranger_global_survshap) plot(ranger_global_survshap, geom = "beeswarm") diff --git a/man/plot.model_profile_2d_survival.Rd b/man/plot.model_profile_2d_survival.Rd index 36bad02f..a6b7d30f 100644 --- a/man/plot.model_profile_2d_survival.Rd +++ b/man/plot.model_profile_2d_survival.Rd @@ -51,14 +51,18 @@ cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = cph_exp <- explain(cph) cph_model_profile_2d <- model_profile_2d(cph_exp, - variables = list(c("age", "celltype"), - c("age", "karno"))) + variables = list( + c("age", "celltype"), + c("age", "karno") + ) +) head(cph_model_profile_2d$result) plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 103) cph_model_profile_2d_ale <- model_profile_2d(cph_exp, - variables = list(c("age", "karno")), - type = "accumulated") + variables = list(c("age", "karno")), + type = "accumulated" +) head(cph_model_profile_2d_ale$result) plot(cph_model_profile_2d_ale, times = c(8, 103), marginalize_over_time = TRUE) } diff --git a/man/predict.surv_explainer.Rd b/man/predict.surv_explainer.Rd index f8189bf6..f45ae437 100644 --- a/man/predict.surv_explainer.Rd +++ b/man/predict.surv_explainer.Rd @@ -30,16 +30,19 @@ library(survex) cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., - data = veteran, - respect.unordered.factors = TRUE, - num.trees = 100, - mtry = 3, - max.depth = 5) + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 +) cph_exp <- explain(cph) -rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], - y = Surv(veteran$time, veteran$status)) +rsf_ranger_exp <- explain(rsf_ranger, + data = veteran[, -c(3, 4)], + y = Surv(veteran$time, veteran$status) +) predict(cph_exp, veteran[1, ], output_type = "survival")[, 1:10] diff --git a/man/predict_parts.surv_explainer.Rd b/man/predict_parts.surv_explainer.Rd index 9c67cb7a..2b705be7 
100644 --- a/man/predict_parts.surv_explainer.Rd +++ b/man/predict_parts.surv_explainer.Rd @@ -84,8 +84,11 @@ cph_predict_parts_survshap <- predict_parts(cph_exp, new_observation = veteran[1 head(cph_predict_parts_survshap$result) plot(cph_predict_parts_survshap) -cph_predict_parts_survlime <- predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)], - type = "survlime") +cph_predict_parts_survlime <- predict_parts( + cph_exp, + new_observation = veteran[1, -c(3, 4)], + type = "survlime" +) head(cph_predict_parts_survlime$result) plot(cph_predict_parts_survlime, type = "local_importance") } diff --git a/man/predict_profile.surv_explainer.Rd b/man/predict_profile.surv_explainer.Rd index 9818a637..b9905d10 100644 --- a/man/predict_profile.surv_explainer.Rd +++ b/man/predict_profile.surv_explainer.Rd @@ -65,8 +65,9 @@ cph_exp <- explain(cph) rsf_src_exp <- explain(rsf_src) cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)], - variables = c("trt", "celltype", "karno", "age"), - categorical_variables = "trt") + variables = c("trt", "celltype", "karno", "age"), + categorical_variables = "trt" +) plot(cph_predict_profile, facet_ncol = 2) diff --git a/man/risk_from_chf.Rd b/man/risk_from_chf.Rd index 5c312686..f70e43b0 100644 --- a/man/risk_from_chf.Rd +++ b/man/risk_from_chf.Rd @@ -24,13 +24,15 @@ library(survival) rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) chf_function <- transform_to_stepfunction(predict, - type = "chf", - prediction_element = "chf", - times_element = "time.interest") + type = "chf", + prediction_element = "chf", + times_element = "time.interest" +) risk_function <- risk_from_chf(chf_function, unique(veteran$time)) explainer <- explain(rsf_src, - predict_cumulative_hazard_function = chf_function, - predict_function = risk_function) + predict_cumulative_hazard_function = chf_function, + predict_function = risk_function +) } diff --git a/man/surv_model_info.Rd b/man/surv_model_info.Rd index 
42de5972..6c0ed901 100644 --- a/man/surv_model_info.Rd +++ b/man/surv_model_info.Rd @@ -56,14 +56,18 @@ Currently supported packages are: \examples{ library(survival) library(survex) -cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, - model = TRUE, x = TRUE, y = TRUE) +cph <- survival::coxph(survival::Surv(time, status) ~ ., + data = veteran, + model = TRUE, x = TRUE, y = TRUE +) surv_model_info(cph) \donttest{ library(ranger) -rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, - num.trees = 50, mtry = 3, max.depth = 5) +rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., + data = veteran, + num.trees = 50, mtry = 3, max.depth = 5 +) surv_model_info(rsf_ranger) } diff --git a/man/transform_to_stepfunction.Rd b/man/transform_to_stepfunction.Rd index e0d1fb56..970b3d8d 100644 --- a/man/transform_to_stepfunction.Rd +++ b/man/transform_to_stepfunction.Rd @@ -39,9 +39,10 @@ library(survival) rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) chf_function <- transform_to_stepfunction(predict, - type = "chf", - prediction_element = "chf", - times_element = "time.interest") + type = "chf", + prediction_element = "chf", + times_element = "time.interest" +) explainer <- explain(rsf_src, predict_cumulative_hazard_function = chf_function) From 746d9bd7908bed1aa111668a0416db519ac85d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 22 Aug 2023 11:22:21 +0200 Subject: [PATCH 153/207] Request censored version (>=0.2.0) because of changed argument in predict function --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 545b8b09..a5a29e88 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Imports: survival, patchwork Suggests: - censored, + censored (>= 0.2.0), covr, gbm, generics, From 0ef84261d08c9462bfe9ac752bc794e342f8bb13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: 
Tue, 22 Aug 2023 15:08:45 +0200 Subject: [PATCH 154/207] Increase test coverage --- tests/testthat/test-model_parts.R | 4 +++- tests/testthat/test-model_performance.R | 5 +++++ tests/testthat/test-model_profile.R | 12 +++++++++++- tests/testthat/test-predict_profile.R | 1 + 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-model_parts.R b/tests/testthat/test-model_parts.R index 101c9ca0..1a5c94f0 100644 --- a/tests/testthat/test-model_parts.R +++ b/tests/testthat/test-model_parts.R @@ -33,6 +33,8 @@ test_that("C-index fpi works", { y = survival::Surv(rotterdam$rtime, rotterdam$recur), verbose = FALSE) + + expect_error(model_parts(coxph_explainer, output_type = "nonexistent")) mp_cph_cind <- model_parts(coxph_explainer, loss = loss_one_minus_c_index, type = "variable_importance", output_type = "risk") mp_rsf_cind <- model_parts(forest_explainer, loss = loss_one_minus_c_index, output_type = "risk") @@ -202,6 +204,6 @@ test_that("integrated metrics fpi works", { rsf_ranger_exp$data <- NULL expect_error(model_parts(rsf_ranger_exp, loss = loss_integrated_brier_score, B = 10, type = "raw")) - + }) diff --git a/tests/testthat/test-model_performance.R b/tests/testthat/test-model_performance.R index eeae53e6..24203b51 100644 --- a/tests/testthat/test-model_performance.R +++ b/tests/testthat/test-model_performance.R @@ -44,6 +44,9 @@ test_that("model_performance works", { plot(cph_rot_perf, rsf_rot_perf, rug = "censors") plot(cph_rot_perf, rsf_rot_perf, rug = "none") plot(cph_rot_perf, rsf_rot_perf, metrics_type = "scalar") + plot(cph_rot_perf, metrics = "Integrated Brier score") + + expect_error(plot(cph_rot_perf, metrics_type = "nonexistent")) cph_rot_perf_roc <- model_performance(cph_exp_rot, type = "roc", times = c(100, 200)) rsf_rot_perf_roc <- model_performance(rsf_exp_rot, type = "roc", times = c(100, 200)) @@ -51,6 +54,8 @@ test_that("model_performance works", { plot(cph_rot_perf_roc) plot(rsf_rot_perf_roc) + + 
expect_error(model_performance(rsf_exp_rot, type = "roc", times = NULL)) }) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index c781a672..0530404f 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -37,6 +37,12 @@ test_that("model_profile with type = 'partial' works", { plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "pdp") plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "ice") + expect_error(plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "nonexistent", times = cph_exp$times[1])) + expect_error(plot(mp_cph_num, geom = "variable", variables = 1, plot_type = "pdp+ice", times = cph_exp$times[1])) + expect_error(plot(mp_cph_num, geom = "variable", variables = c("karno", "diagtime"), plot_type = "pdp+ice", times = cph_exp$times[1])) + expect_error(plot(mp_cph_num, geom = "variable", variables = "nonexistent", plot_type = "pdp+ice", times = cph_exp$times[1])) + + expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) expect_equal(ncol(mp_cph_num$result), 7) @@ -83,7 +89,11 @@ test_that("model_profile with type = 'partial' works", { expect_error(plot(mp_rsf_num, geom = "variable", variables = "age", times = -1)) expect_error(plot(mp_rsf_num, geom = "nonexistent")) expect_error(plot(mp_rsf_num, nonsense_argument = "character")) - }) + + centered_mp <- model_profile(rsf_src_exp, variables = "karno", center = TRUE) + plot(centered_mp, geom = "variable", variables = "karno", times = rsf_src_exp$times[1]) + +}) test_that("model_profile with type = 'accumulated' works", { veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] diff --git a/tests/testthat/test-predict_profile.R b/tests/testthat/test-predict_profile.R index e28b4fdd..b66fe9be 100644 --- a/tests/testthat/test-predict_profile.R +++ 
b/tests/testthat/test-predict_profile.R @@ -20,6 +20,7 @@ test_that("ceteris_paribus works", { plot(cph_pp, rug = "censors", variable_type = "numerical") plot(cph_pp, rug = "none") + expect_error(plot(cph_pp, variables = "aaa")) expect_error(plot(cph_pp, variable_type = "nonexistent")) expect_error(plot(cph_pp, numerical_plot_type = "nonexistent")) expect_error(predict_profile(cph_exp, veteran[2, -c(3, 4)], output_type = "nonexistent")) From 2c139577ab04f2f32aa5b983d66968d994aa7999 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 22 Aug 2023 17:11:40 +0200 Subject: [PATCH 155/207] fix plot for survlime --- R/plot_surv_lime.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/plot_surv_lime.R b/R/plot_surv_lime.R index c00eccfe..ecefd562 100644 --- a/R/plot_surv_lime.R +++ b/R/plot_surv_lime.R @@ -98,4 +98,5 @@ plot.surv_lime <- function(x, }) return(patchwork::wrap_plots(pl, pl2, nrow = 1, widths = c(3, 5))) } + pl } From b3b466b471713ad1402b69aa3f7b13be7805b38d Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 16:38:02 +0200 Subject: [PATCH 156/207] fix model_survshap plots --- R/plot_surv_shap.R | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 5028dbd4..945ea117 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -180,16 +180,24 @@ plot.aggregated_surv_shap <- function(x, "importance" = plot_shap_global_importance( x = x, ... = ..., + title = title, + subtitle = subtitle, + max_vars = max_vars, colors = colors ), "beeswarm" = plot_shap_global_beeswarm( x = x, ... = ..., + title = title, + subtitle = subtitle, + max_vars = max_vars, colors = colors ), "profile" = plot_shap_global_profile( x = x, ... 
= ..., + title = title, + subtitle = subtitle, colors = colors ), stop("`geom` must be one of 'importance', 'beeswarm' or 'profile'") @@ -224,7 +232,7 @@ plot_shap_global_importance <- function(x, long_df <- stack(x$aggregate) long_df <- long_df[order(long_df$values, decreasing = TRUE), ][1:min(max_vars, length(x$aggregate)), ] - if (!is.null(subtitle) && subtitle == "default") { + if (!is.null(title) && title == "default") { title <- "Feature importance according to aggregated |SurvSHAP(t)|" } if (!is.null(subtitle) && subtitle == "default") { @@ -273,7 +281,7 @@ plot_shap_global_beeswarm <- function(x, df <- cbind(df, var_value) label <- attr(x, "label") - if (!is.null(subtitle) && subtitle == "default") { + if (!is.null(title) && title == "default") { title <- "Aggregated SurvSHAP(t) values summary" } if (!is.null(subtitle) && subtitle == "default") { @@ -333,7 +341,7 @@ plot_shap_global_profile <- function(x, colnames(df) <- c("shap_val", "variable_val", "color_variable_val") label <- attr(x, "label") - if (!is.null(subtitle) && subtitle == "default") { + if (!is.null(title) && title == "default") { title <- "Aggregated SurvSHAP(t) profile" } if (!is.null(subtitle) && subtitle == "default") { From 0b1790da24fa6cb9245f5a72965aaa8b80554c8f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:39:57 +0200 Subject: [PATCH 157/207] change wording to be consistent --- vignettes/pdp.Rmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vignettes/pdp.Rmd b/vignettes/pdp.Rmd index 26247d00..4484b8eb 100644 --- a/vignettes/pdp.Rmd +++ b/vignettes/pdp.Rmd @@ -73,7 +73,7 @@ The same plot can be generated for the categorical `celltype` variable. In this plot(pdp, geom = "variable", variables = "celltype", times = 80) ``` -Of course, the plots can be prepared for multiple timepoints, at the same time and presented on one plot. +Of course, the plots can be prepared for multiple time points, at the same time and presented on one plot. 
```{r} plot(pdp, geom = "variable", variables = "karno", times = c(1, 80, 151.72)) From 3a2ff246fa222c395cbe7bcfdc1f0f5112926c2a Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:41:21 +0200 Subject: [PATCH 158/207] move plotting cp to plot_predict_profile_survival, add plot with `geom = "variable"` --- R/plot_predict_profile_survival.R | 295 ++++++++++++++++++++++++--- R/plot_surv_ceteris_paribus.R | 121 +---------- man/plot.predict_profile_survival.Rd | 70 ++++--- man/plot.surv_ceteris_paribus.Rd | 77 ------- 4 files changed, 311 insertions(+), 252 deletions(-) delete mode 100644 man/plot.surv_ceteris_paribus.Rd diff --git a/R/plot_predict_profile_survival.R b/R/plot_predict_profile_survival.R index 3da481f5..8bbf930b 100644 --- a/R/plot_predict_profile_survival.R +++ b/R/plot_predict_profile_survival.R @@ -1,32 +1,25 @@ #' Plot Predict Profile for Survival Models #' -#' This function plots objects of class `"predict_profile_survival"` - local explanations -#' for survival models created using the `predict_profile()` function. +#' This function plots objects of class `"predict_profile_survival"` created using +#' the `predict_profile()` function. #' -#' @param x an object of class `"predict_profile_survival"` to be plotted -#' @param ... additional parameters passed to the `plot.surv_ceteris_paribus` function +#' @param x an object of class `predict_profile_survival` to be plotted +#' @param ... additional objects of class `"predict_profile_survival"` to be plotted together. Only available for `geom = "time"`. +#' @param geom character, either `"time"` or `"variable"`. Selects the type of plot to be prepared. If `"time"` then the x-axis represents survival times, and variable is denoted by colors, if `"variable"` then the x-axis represents the variable values, and y-axis represents the predictions at selected time points. +#' @param variables character, names of the variables to be plotted. 
When `geom = "variable"` it needs to be a name of a single variable, when `geom = "time"` it can be a vector of variable names. If `NULL` (default) then first variable (for `geom = "variable"`) or all variables (for `geom = "time"`) are plotted. +#' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all. Only used when `geom = "time"`. +#' @param facet_ncol number of columns for arranging subplots. Only used when `geom = "time"`. +#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots. Only used when `geom = "time"`. +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If `NULL` (default) then the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE`. +#' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately. Only used when `geom = "variable"`. +#' @param title character, title of the plot +#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels +#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") +#' @param rug character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. +#' @param rug_colors character vector containing two colors (containing either hex codes `"#FF69B4"`, or names `"blue"`). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. 
#' #' @return A collection of `ggplot` objects arranged with the `patchwork` package. #' -#' @section Plot options: -#' -#' ## `plot.surv_ceteris_paribus` -#' -#' * `x` - an object of class `predict_profile_survival` to be plotted -#' * `...` - additional parameters, unused, currently ignored -#' * `colors` - character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") -#' * `variable_type` - character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all -#' * `facet_ncol` - number of columns for arranging subplots -#' * `variables` - character, names of the variables to be plotted -#' * `numerical_plot_type` - character, either `"lines"`, or `"contours"` selects the type of numerical variable plots -#' * `title` - character, title of the plot -#' * `subtitle` - character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels -#' * `rug` - character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. -#' * `rug_colors` - character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. -#' -#' -#' @family functions for plotting 'predict_profile_survival' objects -#' #' @examples #' \donttest{ #' library(survival) @@ -47,8 +40,258 @@ #' #' plot(p_profile_with_cat) #' } +#' #' @export -plot.predict_profile_survival <- function(x, ...) { - class(x) <- class(x)[-1] - plot(x, ...) 
+plot.predict_profile_survival <- function(x, + ..., + geom = "time", + variables = NULL, + variable_type = NULL, + facet_ncol = NULL, + numerical_plot_type = "lines", + times = NULL, + marginalize_over_time = FALSE, + title = "default", + subtitle = "default", + colors = NULL, + rug = "all", + rug_colors = c("#dd0000", "#222222")) { + + if (!geom %in% c("time", "variable")) { + stop("`geom` needs to be one of 'time' or 'variable'.") + } + + if (!(numerical_plot_type %in% c("lines", "contours"))) { + stop("`numerical_plot_type` needs to be 'lines' or 'contours'") + } + + if (!is.null(variable_type) && + !(variable_type %in% c("numerical", "categorical"))) { + stop("`variable_type` needs to be 'numerical' or 'categorical'") + } + + if (title == "default"){ + title <- "Ceteris paribus survival profile" + } + + if (geom == "variable") { + pl <- plot2_cp( + x = x, + variable = variables, + times = times, + marginalize_over_time = marginalize_over_time, + ... = ..., + title = title, + subtitle = subtitle, + colors = colors + ) + if (x$center) { + pl <- pl + labs(y = "centered profile value") + } + return(pl) + } + + lapply(list(x, ...), function(x) { + if (!inherits(x, "predict_profile_survival")) { + stop("All ... 
must be objects of class `predict_profile_survival`.") + } + }) + explanations_list <- c(list(x), list(...)) + num_models <- length(explanations_list) + + if (num_models == 1) { + result <- prepare_ceteris_paribus_plots( + x, + colors, + variable_type, + facet_ncol, + variables, + numerical_plot_type, + title, + subtitle, + rug, + rug_colors + ) + return(result) + } + + return_list <- list() + labels <- list() + for (i in 1:num_models) { + this_title <- unique(explanations_list[[i]]$result$`_label_`) + return_list[[i]] <- prepare_ceteris_paribus_plots( + explanations_list[[i]], + colors, + variable_type, + 1, + variables, + numerical_plot_type, + this_title, + NULL, + rug, + rug_colors + ) + labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches) - 2)) + } + + labels <- unlist(labels) + return_plot <- patchwork::wrap_plots(return_list, nrow = 1, tag_level = "keep") + + patchwork::plot_annotation(title, tag_levels = list(labels)) & theme_default_survex() + + return(return_plot) +} + +#' @keywords internal +plot2_cp <- function(x, + variable, + times = NULL, + marginalize_over_time = FALSE, + ..., + title = "default", + subtitle = "default", + colors = NULL) { + if (is.null(variable)) { + variable <- unique(x$result$`_vname_`)[1] + warning("Plot will be prepared for the first variable from the explanation `result`. \nFor another variable, set the value of `variable`.") + } + + if (!is.character(variable)) { + stop("The variable must be specified by name") + } + + if (length(variable) > 1) { + stop("Only one variable can be specified for `geom`='variable'") + } + + if (!variable %in% x$result$`_vname_`) { + stop(paste0("Variable ", variable, " not found")) + } + + if (is.null(times)) { + if (marginalize_over_time){ + times <- x$eval_times + warning("Plot will be prepared with marginalization over all time points from the explainer's `times` vector. 
\nFor subset of time points, set the value of `times`.") + } else{ + times <- quantile(x$eval_times, p = 0.5, type = 1) + warning("Plot will be prepared for the median time point from the explainer's `times` vector. \nFor another time point, set the value of `times`.") + } + } + + if (!all(times %in% x$eval_times)) { + stop(paste0( + "For one of the provided times the explanations has not been calculated not found. + Please modify the times argument in your explainer or use only values from the following: ", + paste(x$eval_times, collapse = ", ") + )) + } + + single_timepoint <- ((length(times) == 1) || marginalize_over_time) + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0("created for the ", unique(variable), " variable") + if (single_timepoint && !marginalize_over_time) { + subtitle <- paste0(subtitle, " and time =", times) + } + } + + is_categorical <- (unique(x$result[x$result$`_vname_` == variable, "_vtype_"]) == "categorical") + + if (single_timepoint) { + ice_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c(variable, "_yhat_")] + colnames(ice_df) <- c(variable, "ice") + } else { + ice_df <- x$result[(x$result$`_vname_` == variable) & (x$result$`_times_` %in% times), c(variable, "_times_", "_yhat_")] + colnames(ice_df) <- c(variable, "time", "ice") + ice_df$time <- as.factor(ice_df$time) + } + + if (is_categorical) { + ice_df[, variable] <- as.factor(ice_df[, variable]) + } + + feature_name_sym <- sym(variable) + + if (marginalize_over_time) { + ice_df <- aggregate(ice ~ ., data = ice_df, mean) + color_scale <- generate_discrete_color_scale(1, colors) + } else { + if (is.null(colors) || length(colors) < 3) { + color_scale <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) + } + } + + if (is_categorical) { + pl <- plot_ice_cat( + ice_df = ice_df, + feature_name_sym = feature_name_sym, + single_timepoint = single_timepoint, + colors = color_scale + ) + } else { + ice_df[, 1] <- 
as.numeric(as.character(ice_df[, 1])) + x_width <- diff(range(ice_df[, variable])) + pl <- plot_ice_num( + ice_df = ice_df, + feature_name_sym = feature_name_sym, + x_width = x_width, + single_timepoint = single_timepoint, + colors = color_scale + ) + } + + pl + + labs( + title = title, + subtitle = subtitle + ) + + theme_default_survex() +} + + + +plot_ice_num <-function(ice_df, + feature_name_sym, + x_width = x_width, + single_timepoint, + colors){ + with(ice_df, { + if (single_timepoint == TRUE) { + ggplot(data = ice_df, aes(x = !!feature_name_sym, y = ice)) + + geom_line(color = colors_discrete_drwhy(1), linewidth=1) + } else { + ice_df$time <- as.numeric(as.character(ice_df$time)) + ggplot(data = ice_df, aes(x = !!feature_name_sym, y = ice)) + + geom_line(aes(color = time, group = time), linewidth=1) + + scale_colour_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(pdp_dt$time))) + ) + } + }) +} + + +plot_ice_cat <- function(ice_df, + feature_name_sym, + single_timepoint, + colors){ + with(ice_df, { + if (single_timepoint == TRUE) { + ggplot(data = ice_df, aes(x = !!feature_name_sym, y = ice), ) + + geom_bar(stat = "identity", width = 0.5, fill = colors_discrete_drwhy(1)) + + scale_y_continuous() + + scale_fill_manual(name = "time", values = colors) + + geom_hline(yintercept = 0, linetype="dashed") + } else { + ggplot(data = ice_df, aes(x = !!feature_name_sym, y = ice, fill = time)) + + geom_bar(stat = "identity", width = 0.5, position = "dodge") + + scale_y_continuous() + + scale_fill_manual(name = "time", values = colors) + } + }) } diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index b98c0dd5..f148a604 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -1,115 +1,10 @@ -#' Plot Predict Profile for Survival Models -#' -#' This function plots objects of class `"predict_profile_survival"` created using -#' the `predict_profile()` function. 
-#' -#' @param x an object of class `predict_profile_survival` to be plotted -#' @param ... additional objects of class `"predict_profile_survival"` to be plotted together -#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") -#' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all -#' @param facet_ncol number of columns for arranging subplots -#' @param variables character, names of the variables to be plotted -#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots -#' @param title character, title of the plot -#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels -#' @param rug character, one of `"all"`, `"events"`, `"censors"`, `"none"` or `NULL`. Which times to mark on the x axis in `geom_rug()`. -#' @param rug_colors character vector containing two colors (containing either hex codes `"#FF69B4"`, or names `"blue"`). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. -#' -#' @return A collection of `ggplot` objects arranged with the `patchwork` package. 
-#' -#' @family functions for plotting 'predict_profile_survival' objects -#' -#' @examples -#' \donttest{ -#' library(survival) -#' library(survex) -#' -#' model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) -#' exp <- explain(model) -#' -#' p_profile <- predict_profile(exp, veteran[1, -c(3, 4)]) -#' -#' plot(p_profile) -#' -#' p_profile_with_cat <- predict_profile( -#' exp, -#' veteran[1, -c(3, 4)], -#' categorical_variables = c("trt", "prior") -#' ) -#' -#' plot(p_profile_with_cat) -#' } -#' -#' @export -plot.surv_ceteris_paribus <- function(x, - ..., - colors = NULL, - variable_type = NULL, - facet_ncol = NULL, - variables = NULL, - numerical_plot_type = "lines", - title = "Ceteris paribus survival profile", - subtitle = "default", - rug = "all", - rug_colors = c("#dd0000", "#222222")) { - if (!is.null(variable_type)) { - check_variable_type(variable_type) - } - check_numerical_plot_type(numerical_plot_type) - - explanations_list <- c(list(x), list(...)) - num_models <- length(explanations_list) - - if (num_models == 1) { - result <- prepare_ceteris_paribus_plots( - x, - colors, - variable_type, - facet_ncol, - variables, - numerical_plot_type, - title, - subtitle, - rug, - rug_colors - ) - return(result) - } - - return_list <- list() - labels <- list() - for (i in 1:num_models) { - this_title <- unique(explanations_list[[i]]$result$`_label_`) - return_list[[i]] <- prepare_ceteris_paribus_plots( - explanations_list[[i]], - colors, - variable_type, - 1, - variables, - numerical_plot_type, - this_title, - NULL, - rug, - rug_colors - ) - labels[[i]] <- c(this_title, rep("", length(return_list[[i]]$patches) - 2)) - } - - labels <- unlist(labels) - return_plot <- patchwork::wrap_plots(return_list, nrow = 1, tag_level = "keep") + - patchwork::plot_annotation(title, tag_levels = list(labels)) & theme_default_survex() - - return(return_plot) -} - - prepare_ceteris_paribus_plots <- function(x, colors = NULL, variable_type = NULL, facet_ncol = NULL, 
variables = NULL, numerical_plot_type = "lines", - title = "Ceteris paribus survival profile", + title = "default", subtitle = "default", rug = "all", rug_colors = c("#dd0000", "#222222")) { @@ -118,7 +13,6 @@ prepare_ceteris_paribus_plots <- function(x, center <- x$center x <- x$result - all_profiles <- x class(all_profiles) <- "data.frame" @@ -349,16 +243,3 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, pl } - - -check_variable_type <- function(variable_type) { - if (!(variable_type %in% c("numerical", "categorical"))) { - stop("variable_type needs to be 'numerical' or 'categorical'") - } -} - -check_numerical_plot_type <- function(numerical_plot_type) { - if (!(numerical_plot_type %in% c("lines", "contours"))) { - stop("numerical_plot_type needs to be 'lines' or 'contours'") - } -} diff --git a/man/plot.predict_profile_survival.Rd b/man/plot.predict_profile_survival.Rd index 760123ca..d9e07c58 100644 --- a/man/plot.predict_profile_survival.Rd +++ b/man/plot.predict_profile_survival.Rd @@ -4,39 +4,55 @@ \alias{plot.predict_profile_survival} \title{Plot Predict Profile for Survival Models} \usage{ -\method{plot}{predict_profile_survival}(x, ...) +\method{plot}{predict_profile_survival}( + x, + ..., + geom = "time", + variables = NULL, + variable_type = NULL, + facet_ncol = NULL, + numerical_plot_type = "lines", + times = NULL, + marginalize_over_time = FALSE, + title = "default", + subtitle = "default", + colors = NULL, + rug = "all", + rug_colors = c("#dd0000", "#222222") +) } \arguments{ -\item{x}{an object of class \code{"predict_profile_survival"} to be plotted} +\item{x}{an object of class \code{predict_profile_survival} to be plotted} + +\item{...}{additional objects of class \code{"predict_profile_survival"} to be plotted together. Only available for \code{geom = "time"}.} + +\item{geom}{character, either \code{"time"} or \code{"variable"}. Selects the type of plot to be prepared. 
If \code{"time"} then the x-axis represents survival times, and variable is denoted by colors, if \code{"variable"} then the x-axis represents the variable values, and y-axis represents the predictions at selected time points.} + +\item{variables}{character, names of the variables to be plotted. When \code{geom = "variable"} it needs to be a name of a single variable, when \code{geom = "time"} it can be a vector of variable names. If \code{NULL} (default) then first variable (for \code{geom = "variable"}) or all variables (for \code{geom = "time"}) are plotted.} + +\item{variable_type}{character, either \code{"numerical"}, \code{"categorical"} or \code{NULL} (default), select only one type of variable for plotting, or leave \code{NULL} for all. Only used when \code{geom = "time"}.} + +\item{facet_ncol}{number of columns for arranging subplots. Only used when \code{geom = "time"}.} + +\item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots. Only used when \code{geom = "time"}.} + +\item{title}{character, title of the plot} + +\item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} -\item{...}{additional parameters passed to the \code{plot.surv_ceteris_paribus} function} +\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} + +\item{rug}{character, one of \code{"all"}, \code{"events"}, \code{"censors"}, \code{"none"} or \code{NULL}. Which times to mark on the x axis in \code{geom_rug()}.} + +\item{rug_colors}{character vector containing two colors (containing either hex codes \code{"#FF69B4"}, or names \code{"blue"}). 
The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times.} } \value{ A collection of \code{ggplot} objects arranged with the \code{patchwork} package. } \description{ -This function plots objects of class \code{"predict_profile_survival"} - local explanations -for survival models created using the \code{predict_profile()} function. -} -\section{Plot options}{ - -\subsection{\code{plot.surv_ceteris_paribus}}{ -\itemize{ -\item \code{x} - an object of class \code{predict_profile_survival} to be plotted -\item \code{...} - additional parameters, unused, currently ignored -\item \code{colors} - character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") -\item \code{variable_type} - character, either \code{"numerical"}, \code{"categorical"} or \code{NULL} (default), select only one type of variable for plotting, or leave \code{NULL} for all -\item \code{facet_ncol} - number of columns for arranging subplots -\item \code{variables} - character, names of the variables to be plotted -\item \code{numerical_plot_type} - character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots -\item \code{title} - character, title of the plot -\item \code{subtitle} - character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels -\item \code{rug} - character, one of \code{"all"}, \code{"events"}, \code{"censors"}, \code{"none"} or \code{NULL}. Which times to mark on the x axis in \code{geom_rug()}. -\item \code{rug_colors} - character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. 
-} +This function plots objects of class \code{"predict_profile_survival"} created using +the \code{predict_profile()} function. } -} - \examples{ \donttest{ library(survival) @@ -57,9 +73,5 @@ p_profile_with_cat <- predict_profile( plot(p_profile_with_cat) } + } -\seealso{ -Other functions for plotting 'predict_profile_survival' objects: -\code{\link{plot.surv_ceteris_paribus}()} -} -\concept{functions for plotting 'predict_profile_survival' objects} diff --git a/man/plot.surv_ceteris_paribus.Rd b/man/plot.surv_ceteris_paribus.Rd deleted file mode 100644 index b675ba7c..00000000 --- a/man/plot.surv_ceteris_paribus.Rd +++ /dev/null @@ -1,77 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot_surv_ceteris_paribus.R -\name{plot.surv_ceteris_paribus} -\alias{plot.surv_ceteris_paribus} -\title{Plot Predict Profile for Survival Models} -\usage{ -\method{plot}{surv_ceteris_paribus}( - x, - ..., - colors = NULL, - variable_type = NULL, - facet_ncol = NULL, - variables = NULL, - numerical_plot_type = "lines", - title = "Ceteris paribus survival profile", - subtitle = "default", - rug = "all", - rug_colors = c("#dd0000", "#222222") -) -} -\arguments{ -\item{x}{an object of class \code{predict_profile_survival} to be plotted} - -\item{...}{additional objects of class \code{"predict_profile_survival"} to be plotted together} - -\item{colors}{character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")} - -\item{variable_type}{character, either \code{"numerical"}, \code{"categorical"} or \code{NULL} (default), select only one type of variable for plotting, or leave \code{NULL} for all} - -\item{facet_ncol}{number of columns for arranging subplots} - -\item{variables}{character, names of the variables to be plotted} - -\item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots} - 
-\item{title}{character, title of the plot} - -\item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} - -\item{rug}{character, one of \code{"all"}, \code{"events"}, \code{"censors"}, \code{"none"} or \code{NULL}. Which times to mark on the x axis in \code{geom_rug()}.} - -\item{rug_colors}{character vector containing two colors (containing either hex codes \code{"#FF69B4"}, or names \code{"blue"}). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times.} -} -\value{ -A collection of \code{ggplot} objects arranged with the \code{patchwork} package. -} -\description{ -This function plots objects of class \code{"predict_profile_survival"} created using -the \code{predict_profile()} function. -} -\examples{ -\donttest{ -library(survival) -library(survex) - -model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) -exp <- explain(model) - -p_profile <- predict_profile(exp, veteran[1, -c(3, 4)]) - -plot(p_profile) - -p_profile_with_cat <- predict_profile( - exp, - veteran[1, -c(3, 4)], - categorical_variables = c("trt", "prior") -) - -plot(p_profile_with_cat) -} - -} -\seealso{ -Other functions for plotting 'predict_profile_survival' objects: -\code{\link{plot.predict_profile_survival}()} -} -\concept{functions for plotting 'predict_profile_survival' objects} From 21e00dc291cd0acca4f9708d262086f10e4264f1 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:44:05 +0200 Subject: [PATCH 159/207] fix typos, reorder parameters to be consistent with predict --- R/plot_model_profile_survival.R | 54 ++++++++++++++++++++---------- man/plot.model_profile_survival.Rd | 18 +++++----- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 2dd0c587..5c56b8a2 100644 --- 
a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -4,15 +4,15 @@ #' using the `model_profile()` function. #' #' @param x an object of class `model_profile_survival` to be plotted -#' @param ... additional objects of class `model_profile_survival` to be plotted together. Only available for `geom = "time"` -#' @param geom character, either "time" or "variable". Selects the type of plot to be prepared. If `"time"` then the x-axis represents survival times, and variable is denoted by colors, if `"variable"` then the x-axis represents the variable values, and mean predictions at selected timepoints. -#' @param variables character, names of the variables to be plotted. When `geom = "variable"` it needs to be a name of a single variable, when `geom = "time"` it can be a vector of variable names. If `NULL` (default) then all variables are plotted. -#' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all. Only used when `geom = "time"` -#' @param facet_ncol number of columns for arranging subplots. Only used when `geom = "time"` -#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots. Only used when `geom = "time"` -#' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately. Only used when `geom = "variable"` -#' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"`, or `NULL` (default). If `NULL` then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when `geom = "variable"` -#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. 
If `NULL` (default) then the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE +#' @param ... additional objects of class `model_profile_survival` to be plotted together. Only available for `geom = "time"`. +#' @param geom character, either `"time"` or `"variable"`. Selects the type of plot to be prepared. If `"time"` then the x-axis represents survival times, and variable is denoted by colors, if `"variable"` then the x-axis represents the variable values, and y-axis represents the predictions at selected time points. +#' @param variables character, names of the variables to be plotted. When `geom = "variable"` it needs to be a name of a single variable, when `geom = "time"` it can be a vector of variable names. If `NULL` (default) then first variable (for `geom = "variable"`) or all variables (for `geom = "time"`) are plotted. +#' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all. Only used when `geom = "time"`. +#' @param facet_ncol number of columns for arranging subplots. Only used when `geom = "time"`. +#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots. Only used when `geom = "time"`. +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If `NULL` (default) then the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE`. +#' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately. Only used when `geom = "variable"`. +#' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"`, or `NULL` (default). 
If `NULL` then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when `geom = "variable"`. #' @param title character, title of the plot #' @param subtitle character, subtitle of the plot, `"default"` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels #' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). @@ -61,7 +61,16 @@ plot.model_profile_survival <- function(x, rug = "all", rug_colors = c("#dd0000", "#222222")) { if (!geom %in% c("time", "variable")) { - stop("`geom` must be one of 'time' or 'survival'.") + stop("`geom` needs to be one of 'time' or 'variable'.") + } + + if (!(numerical_plot_type %in% c("lines", "contours"))) { + stop("`numerical_plot_type` needs to be 'lines' or 'contours'") + } + + if (!is.null(variable_type) && + !(variable_type %in% c("numerical", "categorical"))) { + stop("`variable_type` needs to be 'numerical' or 'categorical'") } if (title == "default") { @@ -77,7 +86,7 @@ plot.model_profile_survival <- function(x, } if (geom == "variable") { - pl <- plot2( + pl <- plot2_mp( x = x, variable = variables, times = times, @@ -100,7 +109,6 @@ plot.model_profile_survival <- function(x, } }) explanations_list <- c(list(x), list(...)) - num_models <- length(explanations_list) if (num_models == 1) { @@ -147,7 +155,7 @@ plot.model_profile_survival <- function(x, #' @keywords internal -plot2 <- function(x, +plot2_mp <- function(x, variable, times = NULL, marginalize_over_time = FALSE, @@ -178,7 +186,12 @@ plot2 <- function(x, stop("plot_type must be one of 'pdp', 'ice', 'pdp+ice'") } - if (is.null(variable) || !is.character(variable)) { + if (is.null(variable)) { + variable <- unique(x$result$`_vname_`)[1] + warning("Plot will be prepared for the first variable from the explanation `result`. 
\nFor another variable, set the value of `variable`.") + } + + if (!is.character(variable)) { stop("The variable must be specified by name") } @@ -191,8 +204,13 @@ plot2 <- function(x, } if (is.null(times)) { - times <- quantile(x$eval_times, p = 0.5, type = 1) - warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") + if (marginalize_over_time){ + times <- x$eval_times + warning("Plot will be prepared with marginalization over all time points from the explainer's `times` vector. \nFor subset of time points, set the value of `times`.") + } else{ + times <- quantile(x$eval_times, p = 0.5, type = 1) + warning("Plot will be prepared for the median time point from the explainer's `times` vector. \nFor another time point, set the value of `times`.") + } } if (!all(times %in% x$eval_times)) { @@ -207,7 +225,7 @@ plot2 <- function(x, if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") if (single_timepoint && !marginalize_over_time) { - subtitle <- paste0(subtitle, " and time=", times) + subtitle <- paste0(subtitle, " and time =", times) } } @@ -354,7 +372,7 @@ plot_pdp_num <- function(pdp_dt, geom_rug(data = data_dt, aes(x = !!feature_name_sym, y = y_ceiling_pd), sides = "b", alpha = 0.8, position = position_jitter(width = 0.01 * x_width)) + ylim(y_floor_pd, y_ceiling_pd) } - } else { ## multiple timepoints + } else { ## multiple time points pdp_dt$time <- as.numeric(as.character(pdp_dt$time)) if (!is.null(ice_dt)) { ice_dt$time <- as.numeric(as.character(ice_dt$time)) diff --git a/man/plot.model_profile_survival.Rd b/man/plot.model_profile_survival.Rd index 9c548ab9..db8d8865 100644 --- a/man/plot.model_profile_survival.Rd +++ b/man/plot.model_profile_survival.Rd @@ -25,23 +25,23 @@ \arguments{ \item{x}{an object of class \code{model_profile_survival} to be plotted} -\item{...}{additional objects of class 
\code{model_profile_survival} to be plotted together. Only available for \code{geom = "time"}} +\item{...}{additional objects of class \code{model_profile_survival} to be plotted together. Only available for \code{geom = "time"}.} -\item{geom}{character, either "time" or "variable". Selects the type of plot to be prepared. If \code{"time"} then the x-axis represents survival times, and variable is denoted by colors, if \code{"variable"} then the x-axis represents the variable values, and mean predictions at selected timepoints.} +\item{geom}{character, either \code{"time"} or \code{"variable"}. Selects the type of plot to be prepared. If \code{"time"} then the x-axis represents survival times, and variable is denoted by colors, if \code{"variable"} then the x-axis represents the variable values, and y-axis represents the predictions at selected time points.} -\item{variables}{character, names of the variables to be plotted. When \code{geom = "variable"} it needs to be a name of a single variable, when \code{geom = "time"} it can be a vector of variable names. If \code{NULL} (default) then all variables are plotted.} +\item{variables}{character, names of the variables to be plotted. When \code{geom = "variable"} it needs to be a name of a single variable, when \code{geom = "time"} it can be a vector of variable names. If \code{NULL} (default) then first variable (for \code{geom = "variable"}) or all variables (for \code{geom = "time"}) are plotted.} -\item{variable_type}{character, either \code{"numerical"}, \code{"categorical"} or \code{NULL} (default), select only one type of variable for plotting, or leave \code{NULL} for all. Only used when \code{geom = "time"}} +\item{variable_type}{character, either \code{"numerical"}, \code{"categorical"} or \code{NULL} (default), select only one type of variable for plotting, or leave \code{NULL} for all. Only used when \code{geom = "time"}.} -\item{facet_ncol}{number of columns for arranging subplots. 
Only used when \code{geom = "time"}} +\item{facet_ncol}{number of columns for arranging subplots. Only used when \code{geom = "time"}.} -\item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots. Only used when \code{geom = "time"}} +\item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots. Only used when \code{geom = "time"}.} -\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used. Only used when \code{geom = "variable"} and `marginalize_over_time = FALSE} +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used. Only used when \code{geom = "variable"} and \code{marginalize_over_time = FALSE}.} -\item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately. Only used when \code{geom = "variable"}} +\item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately. Only used when \code{geom = "variable"}.} -\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"}, or \code{NULL} (default). If \code{NULL} then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when \code{geom = "variable"}} +\item{plot_type}{character, one of \code{"pdp"}, \code{"ice"}, \code{"pdp+ice"}, or \code{NULL} (default). 
If \code{NULL} then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when \code{geom = "variable"}.} \item{title}{character, title of the plot} From 323328034bdb95e96204bfbbce0d5b4b7472b189 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:44:23 +0200 Subject: [PATCH 160/207] remove redundant method --- NAMESPACE | 1 - 1 file changed, 1 deletion(-) diff --git a/NAMESPACE b/NAMESPACE index 5e354fb4..d8e202b5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -21,7 +21,6 @@ S3method(plot,model_profile_2d_survival) S3method(plot,model_profile_survival) S3method(plot,predict_parts_survival) S3method(plot,predict_profile_survival) -S3method(plot,surv_ceteris_paribus) S3method(plot,surv_feature_importance) S3method(plot,surv_lime) S3method(plot,surv_model_performance) From bcc670afb1e9237e2c648c2b89673f87759a39a3 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:44:42 +0200 Subject: [PATCH 161/207] add warnings --- R/plot_surv_shap.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 945ea117..a47351e0 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -327,9 +327,11 @@ plot_shap_global_profile <- function(x, if (is.null(variable)) { variable <- colnames(df)[1] + warning("`variable` was not specified, the first from the result will be used.") } if (is.null(color_variable)) { color_variable <- variable + warning("`color_variable` was not specified, the first from the result will be used.") } shap_val <- df[, variable] From 4d50c732c0d41254099629911a418a6a817476e2 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:47:08 +0200 Subject: [PATCH 162/207] timepoints -> time points --- R/plot_surv_model_performance_rocs.R | 2 +- R/surv_model_profiles.R | 2 +- man/plot.surv_model_performance_rocs.Rd | 2 +- tests/testthat/test-model_profile.R | 16 ++++++++-------- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git 
a/R/plot_surv_model_performance_rocs.R b/R/plot_surv_model_performance_rocs.R index 33806e80..b1aa0a2b 100644 --- a/R/plot_surv_model_performance_rocs.R +++ b/R/plot_surv_model_performance_rocs.R @@ -29,7 +29,7 @@ #' @export plot.surv_model_performance_rocs <- function(x, ..., - title = "ROC curves for selected timepoints", + title = "ROC curves for selected time points", subtitle = "default", auc = TRUE, colors = NULL, diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index fea98745..6a958bc2 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -283,7 +283,7 @@ surv_ale <- function(x, # First order finite differences prediction_deltas <- predictions_upper - predictions_lower - # Rename columns to timepoints for which predictions were made + # Rename columns to time points for which predictions were made colnames(prediction_deltas) <- times deltas <- data.frame( diff --git a/man/plot.surv_model_performance_rocs.Rd b/man/plot.surv_model_performance_rocs.Rd index d76175f6..4b528253 100644 --- a/man/plot.surv_model_performance_rocs.Rd +++ b/man/plot.surv_model_performance_rocs.Rd @@ -7,7 +7,7 @@ \method{plot}{surv_model_performance_rocs}( x, ..., - title = "ROC curves for selected timepoints", + title = "ROC curves for selected time points", subtitle = "default", auc = TRUE, colors = NULL, diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 0530404f..4bc7a998 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -28,11 +28,11 @@ test_that("model_profile with type = 'partial' works", { plot(mp_cph_num, numerical_plot_type = "contours") ### Add tests for plot for numerical PDP - # single timepoint + # single time point plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "pdp+ice", times = cph_exp$times[1]) plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "pdp", times = cph_exp$times[1]) plot(mp_cph_num, geom = "variable", 
variables = "karno", plot_type = "ice", times = cph_exp$times[1]) - # multiple timepoints + # multiple time points plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "pdp+ice") plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "pdp") plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "ice") @@ -55,11 +55,11 @@ test_that("model_profile with type = 'partial' works", { plot(mp_cph_cat, mp_rsf_cat) ### Add tests for plot for categorical PDP - # single timepoint + # single time point plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp+ice", times = rsf_ranger_exp$times[1]) plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp", times = rsf_ranger_exp$times[1]) plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "ice", times = rsf_ranger_exp$times[1]) - # multiple timepoints + # multiple time points plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "pdp+ice") plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), marginalize_over_time = T, variables = "celltype", plot_type = "pdp+ice") plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "pdp") @@ -115,9 +115,9 @@ test_that("model_profile with type = 'accumulated' works", { plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") ### Add tests for plot for categorical ALE - # single timepoint + # single time point plot(mp_cph_cat, geom = "variable", variables = "celltype", times=cph_exp$times[1]) - # multiple timepoints + # multiple time points plot(mp_cph_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "ale") expect_s3_class(mp_cph_cat, "model_profile_survival") @@ -139,9 +139,9 @@ test_that("model_profile with type = 'accumulated' works", { plot(mp_cph_num, numerical_plot_type = "contours") ### Add tests for plot 
for numerical ALE - # single timepoint + # single time point plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "ale", times=cph_exp$times[1]) - # multiple timepoints + # multiple time points plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "ale") expect_s3_class(mp_cph_num, "model_profile_survival") From 48e0e8610ffbe052a2608b2384cb5a729b8104de Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:54:55 +0200 Subject: [PATCH 163/207] fix pkgdown reference --- _pkgdown.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/_pkgdown.yml b/_pkgdown.yml index 7ee2eff2..70ffca02 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -40,9 +40,7 @@ reference: - plot.surv_shap - plot.surv_lime - subtitle: Predict Profile -- contents: - - plot.predict_profile_survival - - plot.surv_ceteris_paribus +- contents: plot.predict_profile_survival - title: Utility functions - contents: - extract_predict_survshap From 81be892b4c8d4a6d22be11c5ead3cb0f149f6cdb Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:55:14 +0200 Subject: [PATCH 164/207] add missing description --- man/plot.predict_profile_survival.Rd | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/man/plot.predict_profile_survival.Rd b/man/plot.predict_profile_survival.Rd index d9e07c58..2720321b 100644 --- a/man/plot.predict_profile_survival.Rd +++ b/man/plot.predict_profile_survival.Rd @@ -36,6 +36,10 @@ \item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots. Only used when \code{geom = "time"}.} +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used. 
Only used when \code{geom = "variable"} and \code{marginalize_over_time = FALSE}.} + +\item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately. Only used when \code{geom = "variable"}.} + \item{title}{character, title of the plot} \item{subtitle}{character, subtitle of the plot, \code{'default'} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} From fb1f0ea399bbc90774dee08f6b0a53a4a9c68407 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 22:55:28 +0200 Subject: [PATCH 165/207] update tests --- tests/testthat/test-model_survshap.R | 2 -- tests/testthat/test-predict_profile.R | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index 2508f4d1..1fe1f73f 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -23,7 +23,6 @@ test_that("global survshap explanations with kernelshap work for ranger, using n ) plot(ranger_global_survshap) plot(ranger_global_survshap, geom = "beeswarm") - plot(ranger_global_survshap, geom = "profile") plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "celltype") plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "age") expect_error(plot(ranger_global_survshap, geom = "nonexistent")) @@ -48,7 +47,6 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex ) plot(cph_global_survshap) plot(cph_global_survshap, geom = "beeswarm") - plot(cph_global_survshap, geom = "profile") plot(cph_global_survshap, geom = "profile", variable = "karno", color_variable = "celltype") plot(cph_global_survshap, geom = "profile", variable = "karno", color_variable = "age") diff --git a/tests/testthat/test-predict_profile.R 
b/tests/testthat/test-predict_profile.R index b66fe9be..11795915 100644 --- a/tests/testthat/test-predict_profile.R +++ b/tests/testthat/test-predict_profile.R @@ -19,6 +19,8 @@ test_that("ceteris_paribus works", { plot(cph_pp, ranger_pp, rug = "events", variables = c("karno", "age")) plot(cph_pp, rug = "censors", variable_type = "numerical") plot(cph_pp, rug = "none") + plot(cph_pp, geom = "variable", variables = "karno", times=cph_exp$times[1]) + plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "karno", marginalize_over_time = TRUE) expect_error(plot(cph_pp, variables = "aaa")) expect_error(plot(cph_pp, variable_type = "nonexistent")) From 10221ca20afe3ac08213afef2a0492cd8a559ffa Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 23:09:36 +0200 Subject: [PATCH 166/207] add tests --- tests/testthat/test-predict_profile.R | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test-predict_profile.R b/tests/testthat/test-predict_profile.R index 11795915..308f58a7 100644 --- a/tests/testthat/test-predict_profile.R +++ b/tests/testthat/test-predict_profile.R @@ -20,11 +20,20 @@ test_that("ceteris_paribus works", { plot(cph_pp, rug = "censors", variable_type = "numerical") plot(cph_pp, rug = "none") plot(cph_pp, geom = "variable", variables = "karno", times=cph_exp$times[1]) - plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "karno", marginalize_over_time = TRUE) + plot(cph_pp, geom = "variable", times=cph_exp$times[1]) + plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "karno") + plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "celltype") + plot(cph_pp, geom = "variable", variables = "karno", marginalize_over_time = TRUE) + plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "celltype", marginalize_over_time = TRUE) expect_error(plot(cph_pp, variables = "aaa")) 
expect_error(plot(cph_pp, variable_type = "nonexistent")) expect_error(plot(cph_pp, numerical_plot_type = "nonexistent")) + expect_error(plot(cph_pp, geom = "variable", variables = "nonexistent")) + expect_error(plot(cph_pp, geom = "variable", variables = "age", times = -1)) + expect_error(plot(cph_pp, geom = "nonexistent")) + expect_error(plot(cph_pp, geom = "variable", variables = 1, plot_type = "pdp+ice", times = cph_exp$times[1])) + expect_error(plot(cph_pp, geom = "variable", variables = c("karno", "diagtime"), times = cph_exp$times[1])) expect_error(predict_profile(cph_exp, veteran[2, -c(3, 4)], output_type = "nonexistent")) expect_error(predict_profile(cph_exp, veteran[2:3, -c(3, 4)])) expect_error(predict_profile(cph_exp, veteran[2, -c(3, 4)], type = "nonexistent")) @@ -32,6 +41,7 @@ test_that("ceteris_paribus works", { cph_pp_centered <- predict_profile(cph_exp, veteran[2, -c(3, 4)], center = TRUE) plot(cph_pp_centered) plot(cph_pp_centered, numerical_plot_type = "contours") + plot(cph_pp_centered, geom = "variable", variables = "karno") cph_pp_cat <- predict_profile(cph_exp, veteran[2, -c(3, 4)], variables = c("celltype")) plot(predict_profile(cph_exp, veteran[2, -c(3, 4)], categorical_variables = 1)) From 190cc4331a5648a275297d68c56e119048237669 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 23 Aug 2023 23:15:46 +0200 Subject: [PATCH 167/207] fix errors --- tests/interactive_tests/test-new_explainer.R | 2 +- vignettes/survex-usage.Rmd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/interactive_tests/test-new_explainer.R b/tests/interactive_tests/test-new_explainer.R index e5ef600f..e13738a7 100644 --- a/tests/interactive_tests/test-new_explainer.R +++ b/tests/interactive_tests/test-new_explainer.R @@ -105,7 +105,7 @@ test_that("all functionality of a new explainer works correctly", { plot(m_profile) plot(m_profile, variable_type = "categorical") - plot(m_profile, variables = c("diagtime", "prior"), variable_type = 
"numerical", numerical_plot_type = "contour", facet_ncol = 1) + plot(m_profile, variables = c("diagtime", "prior"), variable_type = "numerical", numerical_plot_type = "contours", facet_ncol = 1) plot(m_profile_subset) diff --git a/vignettes/survex-usage.Rmd b/vignettes/survex-usage.Rmd index aaad1a7a..a5ca0821 100644 --- a/vignettes/survex-usage.Rmd +++ b/vignettes/survex-usage.Rmd @@ -145,7 +145,7 @@ dev.off() ``` ``` {r, fig.height=18} model_profile_rsf <- model_profile(rsf_exp, categorical_variables=c("trt", "prior")) -plot(model_profile_rsf, facet_ncol = 1, numerical_plot_type = "contour") +plot(model_profile_rsf, facet_ncol = 1, numerical_plot_type = "contours") ``` This type of plot also gives us valuable insight that is easy to overlook in the other type. For example, we see a sharp drop in survival function values around `diagtime=25`. We also observe that the most significant influence of the `karno` variable is consistent across the proportional hazards and random survival forest. 
From 5155d1fba35f03c8aefc0aa0c4dd8c1024fc733f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 00:19:38 +0200 Subject: [PATCH 168/207] add explainer for scikit-survival models --- DESCRIPTION | 1 + NAMESPACE | 2 + R/explain.R | 147 ++++++++++++++++++++++++++++++++++++----- R/model_info.R | 19 ++++-- man/surv_model_info.Rd | 3 + 5 files changed, 150 insertions(+), 22 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index a5a29e88..36289b7f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -41,6 +41,7 @@ Suggests: progressr, randomForestSRC, ranger, + reticulate, rmarkdown, rms, testthat (>= 3.0.0), diff --git a/NAMESPACE b/NAMESPACE index d8e202b5..9f35a13e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,7 @@ S3method(explain,default) S3method(explain,model_fit) S3method(explain,ranger) S3method(explain,rfsrc) +S3method(explain,sksurv) S3method(model_parts,default) S3method(model_parts,surv_explainer) S3method(model_performance,default) @@ -44,6 +45,7 @@ S3method(surv_model_info,default) S3method(surv_model_info,model_fit) S3method(surv_model_info,ranger) S3method(surv_model_info,rfsrc) +S3method(surv_model_info,sksurv) export(brier_score) export(c_index) export(cd_auc) diff --git a/R/explain.R b/R/explain.R index 049d4549..7cb44e8a 100644 --- a/R/explain.R +++ b/R/explain.R @@ -190,7 +190,11 @@ explain_survival <- if (!is.null(attr(data, "verbose_info")) && attr(data, "verbose_info") == "extracted") { verbose_cat(" -> data : ", n, " rows ", ncol(data), " cols", "(", color_codes$yellow_start, "extracted from the model", color_codes$yellow_end, ")", verbose = verbose) attr(data, "verbose_info") <- NULL - } else { + } else if (!is.null(attr(data, "verbose_info")) && attr(data, "verbose_info") == "colnames_changed") { + verbose_cat(" -> data : ", n, " rows ", ncol(data), " cols", "(", color_codes$yellow_start, "colnames changed to comply with the model", color_codes$yellow_end, ")", verbose = verbose) + attr(data, "verbose_info") <- NULL + } + else { 
verbose_cat(" -> data : ", n, " rows ", ncol(data), " cols", verbose = verbose) } } @@ -254,7 +258,7 @@ explain_survival <- # verbose predict function if (is.null(predict_function)) { if (!is.null(predict_cumulative_hazard_function)) { - predict_function <- function(model, newdata) risk_from_chf(predict_cumulative_hazard_function(model, newdata, times = times)) + predict_function <- risk_from_chf(predict_cumulative_hazard_function, times) verbose_cat(" -> predict function : ", "sum over the predict_cumulative_hazard_function will be used", is.default = TRUE, verbose = verbose) } else { verbose_cat(" -> predict function : not specified! (", color_codes$red_start, "WARNING", color_codes$red_end, ")", verbose = verbose) @@ -411,22 +415,41 @@ explain.default <- function(model, predict_survival_function = pec::predictSurvProb ) ) - } else { - DALEX::explain(model, - data = data, - y = y, - predict_function = predict_function, - predict_function_target_column = predict_function_target_column, - residual_function = residual_function, - weights = weights, - ... = ..., - label = label, - verbose = verbose, - colorize = !isTRUE(getOption("knitr.in.progress")), - model_info = model_info, - type = type + } + if (inherits(model, "sksurv.base.SurvivalAnalysisMixin")){ + return( + explain.sksurv(model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type + ) ) } + + DALEX::explain(model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... 
= ..., + label = label, + verbose = verbose, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = model_info, + type = type + ) + } #' @export @@ -845,6 +868,98 @@ explain.LearnerSurv <- function(model, ) } +#' @export +explain.sksurv <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL, + times = NULL, + times_generation = "quantiles", + predict_survival_function = NULL, + predict_cumulative_hazard_function = NULL){ + if (is.null(label)) { + label <- class(model)[1] + attr(label, "verbose_info") <- "default" + } + + if (is.null(predict_survival_function)) { + if (reticulate::py_has_attr(model, "predict_survival_function")) { + predict_survival_function <- function(model, newdata, times){ + raw_preds <- model$predict_survival_function(newdata) + t(sapply(raw_preds, function(sf) as.vector(sf(times)))) + } + attr(predict_survival_function, "verbose_info") <- "predict_survival_function from scikit-survival will be used" + attr(predict_survival_function, "is.default") <- TRUE + } + } else { + attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) + } + + if (is.null(predict_cumulative_hazard_function)) { + if (reticulate::py_has_attr(model, "predict_cumulative_hazard_function")) { + predict_cumulative_hazard_function <- function(model, newdata, times){ + raw_preds <- model$predict_cumulative_hazard_function(newdata) + t(sapply(raw_preds, function(chf) as.vector(chf(times)))) + } + attr(predict_cumulative_hazard_function, "verbose_info") <- "predict_cumulative_hazard_function from scikit-survival will be used" + attr(predict_cumulative_hazard_function, "is.default") <- TRUE + } + } else { + attr(predict_cumulative_hazard_function, "verbose_info") <- 
deparse(substitute(predict_cumulative_hazard_function)) + } + + if (is.null(predict_function)) { + if (reticulate::py_has_attr(model, "predict")) { + predict_function <- function(model, newdata, times) model$predict(newdata) + attr(predict_function, "verbose_info") <- "predict from scikit-survival will be used" + } else { + predict_function <- function(model, newdata, times) { + rowSums(predict_cumulative_hazard_function(model, newdata, times)) + } + attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" + } + attr(predict_function, "is.default") <- TRUE + attr(predict_function, "use.times") <- TRUE + } else { + attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) + } + + if (!is.null(data) & any(colnames(data) != model$feature_names_in_)) { + colnames(data) <- sub("[.]", "=", colnames(data)) + attr(data, "verbose_info") <- "colnames_changed" + } + + class(model) <- c("sksurv", class(model)) + + explain_survival( + model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... = ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type, + times = times, + times_generation = times_generation, + predict_survival_function = predict_survival_function, + predict_cumulative_hazard_function = predict_cumulative_hazard_function + ) +} verbose_cat <- function(..., is.default = NULL, verbose = TRUE) { diff --git a/R/model_info.R b/R/model_info.R index 0bb71b28..0048421c 100644 --- a/R/model_info.R +++ b/R/model_info.R @@ -43,7 +43,6 @@ surv_model_info <- function(model, ...) { UseMethod("surv_model_info") } - #' @rdname surv_model_info #' @export surv_model_info.coxph <- function(model, ...) { @@ -103,7 +102,6 @@ surv_model_info.cph <- function(model, ...) 
{ model_info } - #' @rdname surv_model_info #' @export surv_model_info.LearnerSurv <- function(model, ...) { @@ -115,14 +113,23 @@ surv_model_info.LearnerSurv <- function(model, ...) { model_info } - +#' @rdname surv_model_info +#' @export +surv_model_info.sksurv <- function(model, ...) { + type <- "survival" + package <- "scikit-survival" + ver <- get_pkg_ver_safe(package) + model_info <- list(package = package, ver = ver, type = type) + class(model_info) <- "model_info" + model_info +} #' @rdname surv_model_info #' @export surv_model_info.default <- function(model, ...) { type <- "survival" - package <- paste("Model of class:", class(model), "package unrecognized") - ver <- "Unknown" + package <- paste("unrecognized ,", "model of class:", class(model)) + ver <- "unknown" model_info <- list(package = package, ver = ver, type = type) class(model_info) <- "model_info" model_info @@ -131,7 +138,7 @@ surv_model_info.default <- function(model, ...) { get_pkg_ver_safe <- function(package) { ver <- try(as.character(utils::packageVersion(package)), silent = TRUE) if (inherits(ver, "try-error")) { - ver <- "Unknown" + ver <- "unknown" } ver } diff --git a/man/surv_model_info.Rd b/man/surv_model_info.Rd index 6c0ed901..335011f6 100644 --- a/man/surv_model_info.Rd +++ b/man/surv_model_info.Rd @@ -8,6 +8,7 @@ \alias{surv_model_info.model_fit} \alias{surv_model_info.cph} \alias{surv_model_info.LearnerSurv} +\alias{surv_model_info.sksurv} \alias{surv_model_info.default} \title{Extract additional information from the model} \usage{ @@ -25,6 +26,8 @@ surv_model_info(model, ...) \method{surv_model_info}{LearnerSurv}(model, ...) +\method{surv_model_info}{sksurv}(model, ...) + \method{surv_model_info}{default}(model, ...) 
} \arguments{ From c4528f05c25b9dcf9b2c3a6c7eae9b15df38fce0 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 01:33:16 +0200 Subject: [PATCH 169/207] check if class was not added before --- R/explain.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/explain.R b/R/explain.R index 7cb44e8a..a92bcbdb 100644 --- a/R/explain.R +++ b/R/explain.R @@ -938,7 +938,9 @@ explain.sksurv <- function(model, attr(data, "verbose_info") <- "colnames_changed" } - class(model) <- c("sksurv", class(model)) + if (class(model)[1] != "sksurv"){ + class(model) <- c("sksurv", class(model)) + } explain_survival( model, From 1dbfa18de3411edd130e677c552f7af765c37583 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 01:34:09 +0200 Subject: [PATCH 170/207] add model_diagnostics (work in progress) --- NAMESPACE | 3 + R/model_diagnostics.R | 79 +++++++++++++++++++++++++ R/plot_model_diagnostics_survival.R | 21 +++++++ man/model_diagnostics.surv_explainer.Rd | 49 +++++++++++++++ man/plot.model_diagnostics_survival.Rd | 11 ++++ 5 files changed, 163 insertions(+) create mode 100644 R/model_diagnostics.R create mode 100644 R/plot_model_diagnostics_survival.R create mode 100644 man/model_diagnostics.surv_explainer.Rd create mode 100644 man/plot.model_diagnostics_survival.Rd diff --git a/NAMESPACE b/NAMESPACE index 9f35a13e..f34345d2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,6 +7,7 @@ S3method(explain,model_fit) S3method(explain,ranger) S3method(explain,rfsrc) S3method(explain,sksurv) +S3method(model_diagnostics,surv_explainer) S3method(model_parts,default) S3method(model_parts,surv_explainer) S3method(model_performance,default) @@ -16,6 +17,7 @@ S3method(model_profile,surv_explainer) S3method(model_profile_2d,surv_explainer) S3method(model_survshap,surv_explainer) S3method(plot,aggregated_surv_shap) +S3method(plot,model_diagnostics_survival) S3method(plot,model_parts_survival) S3method(plot,model_performance_survival) 
S3method(plot,model_profile_2d_survival) @@ -62,6 +64,7 @@ export(loss_integrated_brier_score) export(loss_one_minus_c_index) export(loss_one_minus_cd_auc) export(loss_one_minus_integrated_cd_auc) +export(model_diagnostics) export(model_parts) export(model_performance) export(model_profile) diff --git a/R/model_diagnostics.R b/R/model_diagnostics.R new file mode 100644 index 00000000..8bdc110f --- /dev/null +++ b/R/model_diagnostics.R @@ -0,0 +1,79 @@ +#' Dataset Level Model Diagnostics +#' +#' This function calculates martingale and deviance residuals. +#' +#' +#' @param explainer an explainer object - model preprocessed by the `explain()` function +#' @param ... other parameters passed to `DALEX::model_diagnostics` if `output_type == "risk"`, otherwise passed to internal functions. +#' @param output_type either `"chf"`, `"survival"` or `"risk"` the type of survival model output that should be used for explanations. If `"chf"` or `"survival"` the explanations are based on survival residuals. Otherwise the scalar risk predictions are used by the `DALEX::model_diagnostics` function. +#' +#' @return An object of class `c("model_diagnostics_survival")`. It's a list with the explanations in the `result` element. +#' +#' @examples +#' \donttest{ +#' library(survival) +#' library(survex) +#' +#' cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) +#' rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., +#' data = veteran, +#' respect.unordered.factors = TRUE, +#' num.trees = 100, +#' mtry = 3, +#' max.depth = 5 +#' ) +#' +#' cph_exp <- explain(cph) +#' +#' rsf_ranger_exp <- explain(rsf_ranger, +#' data = veteran[, -c(3, 4)], +#' y = Surv(veteran$time, veteran$status) +#' ) +#' +#' TODO kod +#' +#' } +#' @rdname model_diagnostics.surv_explainer +#' @export +model_diagnostics <- function(explainer, ...) 
UseMethod("model_diagnostics", explainer) + +#' @rdname model_diagnostics.surv_explainer +#' @export +model_diagnostics.surv_explainer <- function(explainer, + ..., + output_type = "chf") { + n <- nrow(explainer$data) + original_times <- explainer$y[, 1] + statuses <- explainer$y[, 2] + + unique_times <- sort(unique(original_times)) + which_el <- matrix(c(1:n, match(original_times, unique_times)), + nrow = n) + chf_preds <- + predict(explainer, times = unique_times, output_type = "chf") + cox_snell_residuals <- chf_preds[which_el] + martingale_residuals <- statuses - cox_snell_residuals + deviance_residuals <- sign(martingale_residuals) * + sqrt(-2 * ( + martingale_residuals + statuses * + log(statuses - martingale_residuals) + )) + cox_snell_residuals[statuses == 0] <- + cox_snell_residuals[statuses == 0] + 1 + + result <- cbind( + data.frame( + "time" = original_times, + "status" = factor(statuses), + "cox_snell_residuals" = cox_snell_residuals, + "martingale_residuals" = martingale_residuals, + "deviance_residuals" = deviance_residuals, + "label" = explainer$label + ), + explainer$data + ) + + res <- list(result = result) + class(res) <- c("model_diagnostics_survival", class(res)) + res +} diff --git a/R/plot_model_diagnostics_survival.R b/R/plot_model_diagnostics_survival.R new file mode 100644 index 00000000..25aea702 --- /dev/null +++ b/R/plot_model_diagnostics_survival.R @@ -0,0 +1,21 @@ +#' Plot Model Diagnostics for Survival Models +#' +#' @export +plot.model_diagnostics_survival <- function(x, + ..., + type = "deviance", + colors = "red"){ + lapply(list(x, ...), function(x) { + if (!inherits(x, "model_diagnostics_survival")) { + stop("All ... 
must be objects of class `model_diagnostics_survival`.") + } + }) + explanations_list <- c(list(x), list(...)) + result_list <- lapply(explanations_list, function(x) x$result) + df <- do.call(rbind, result_list) + print(df) + ggplot(df, aes(x = time, y = deviance_residuals, color = status)) + + geom_point() + + theme_default_survex() + + facet_wrap(~label) +} diff --git a/man/model_diagnostics.surv_explainer.Rd b/man/model_diagnostics.surv_explainer.Rd new file mode 100644 index 00000000..3710dd9f --- /dev/null +++ b/man/model_diagnostics.surv_explainer.Rd @@ -0,0 +1,49 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model_diagnostics.R +\name{model_diagnostics} +\alias{model_diagnostics} +\alias{model_diagnostics.surv_explainer} +\title{Dataset Level Model Diagnostics} +\usage{ +model_diagnostics(explainer, ...) + +\method{model_diagnostics}{surv_explainer}(explainer, ..., output_type = "chf") +} +\arguments{ +\item{explainer}{an explainer object - model preprocessed by the \code{explain()} function} + +\item{...}{other parameters passed to \code{DALEX::model_diagnostics} if \code{output_type == "risk"}, otherwise passed to internal functions.} + +\item{output_type}{either \code{"chf"}, \code{"survival"} or \code{"risk"} the type of survival model output that should be used for explanations. If \code{"chf"} or \code{"survival"} the explanations are based on survival residuals. Otherwise the scalar risk predictions are used by the \code{DALEX::model_diagnostics} function.} +} +\value{ +An object of class \code{c("model_diagnostics_survival")}. It's a list with the explanations in the \code{result} element. +} +\description{ +This function calculates martingale and deviance residuals. 
+} +\examples{ +\donttest{ +library(survival) +library(survex) + +cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) +rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., + data = veteran, + respect.unordered.factors = TRUE, + num.trees = 100, + mtry = 3, + max.depth = 5 +) + +cph_exp <- explain(cph) + +rsf_ranger_exp <- explain(rsf_ranger, + data = veteran[, -c(3, 4)], + y = Surv(veteran$time, veteran$status) +) + +TODO kod + +} +} diff --git a/man/plot.model_diagnostics_survival.Rd b/man/plot.model_diagnostics_survival.Rd new file mode 100644 index 00000000..db3cca1d --- /dev/null +++ b/man/plot.model_diagnostics_survival.Rd @@ -0,0 +1,11 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot_model_diagnostics_survival.R +\name{plot.model_diagnostics_survival} +\alias{plot.model_diagnostics_survival} +\title{Plot Model Diagnostics for Survival Models} +\usage{ +\method{plot}{model_diagnostics_survival}(x, ...) +} +\description{ +Plot Model Diagnostics for Survival Models +} From d7915b9daa33005f6930f805b5a2b4132b5907d5 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 17:12:41 +0200 Subject: [PATCH 171/207] remove Cox-Snell residuals modification --- R/model_diagnostics.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/model_diagnostics.R b/R/model_diagnostics.R index 8bdc110f..d7e95105 100644 --- a/R/model_diagnostics.R +++ b/R/model_diagnostics.R @@ -58,8 +58,9 @@ model_diagnostics.surv_explainer <- function(explainer, martingale_residuals + statuses * log(statuses - martingale_residuals) )) - cox_snell_residuals[statuses == 0] <- - cox_snell_residuals[statuses == 0] + 1 + # cox_snell_residuals[statuses == 0] <- + # cox_snell_residuals[statuses == 0] + 1 + # modification for censored observations result <- cbind( data.frame( From c36a422d53b3b532e3f8d1161c9a6ec2ebb7b077 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 17:13:17 +0200 
Subject: [PATCH 172/207] add final diagnostic plots --- R/plot_model_diagnostics_survival.R | 74 ++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/R/plot_model_diagnostics_survival.R b/R/plot_model_diagnostics_survival.R index 25aea702..c66a1a2e 100644 --- a/R/plot_model_diagnostics_survival.R +++ b/R/plot_model_diagnostics_survival.R @@ -4,7 +4,12 @@ plot.model_diagnostics_survival <- function(x, ..., type = "deviance", - colors = "red"){ + xvariable = "index", + smooth = as.logical(xvariable != "index"), + title = "Model diagnostics", + subtitle = "default", + facet_ncol =NULL, + colors = c("#160e3b", "#f05a71", "#ceced9")){ lapply(list(x, ...), function(x) { if (!inherits(x, "model_diagnostics_survival")) { stop("All ... must be objects of class `model_diagnostics_survival`.") @@ -12,10 +17,67 @@ plot.model_diagnostics_survival <- function(x, }) explanations_list <- c(list(x), list(...)) result_list <- lapply(explanations_list, function(x) x$result) + n_observations <- print(sapply(result_list, nrow)) df <- do.call(rbind, result_list) - print(df) - ggplot(df, aes(x = time, y = deviance_residuals, color = status)) + - geom_point() + - theme_default_survex() + - facet_wrap(~label) + + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), + ifelse(length(unique(df$label)) > 1, " models", " model")) + } + + if (type %in% c("deviance", "martingale")){ + if (!xvariable %in% c("index", colnames(df))){ + stop(paste("`xvariable`", xvariable, "not found")) + } + df$y <- switch(type, + "deviance" = df$deviance_residuals, + "martingale" = df$martingale_residuals) + df$x <- switch(xvariable, + "index" = unlist(lapply(n_observations, function(x) seq_len(x))), + df[[xvariable]]) + + pl <- ggplot(df, aes(x = x, y = y, color = status)) + + geom_hline(yintercept = 0, color = colors[3], lty = 2, linewidth = 1) + + geom_point() + + theme_default_survex() + + 
scale_color_manual(values = colors, labels = c("censored", "event")) + + facet_wrap(~label, ncol = facet_ncol) + + labs(title = title, + subtitle = subtitle, + x = xvariable, + y = paste(type, "residuals")) + if (smooth) + pl <- pl + geom_smooth(se = FALSE, color = colors[3], alpha = 0.5) + return(pl) + } else if (type == "Cox-Snell"){ + split_df <- split(df, df$label) + df_list <- lapply(split_df, function(df_tmp){ + fit_coxsnell <- survival::survfit(survival::Surv(cox_snell_residuals, as.numeric(status)) ~ 1, data=df_tmp) + confint <- survival:::survfit_confint(fit_coxsnell$cumhaz, fit_coxsnell$std.chaz, + logse=FALSE, "log", 0.95, ulimit = FALSE) + data.frame( + "time" = fit_coxsnell$time, + "cumhaz" = fit_coxsnell$cumhaz, + "lower" = confint$lower, + "upper" = confint$upper) + }) + + df <- do.call(rbind, df_list) + df$label <- sapply(strsplit(rownames(df), "[.]"), function(x) x[1]) + rownames(df) <- NULL + + ggplot(df, aes(x = time, y = cumhaz)) + + geom_step(color = colors[1], linewidth = 1) + + geom_step(aes(y = lower), linetype = "dashed", color = colors[1], alpha = 0.8) + + geom_step(aes(y = upper), linetype = "dashed", color = colors[1], alpha = 0.8) + + geom_abline(slope = 1, color = colors[2], linewidth = 1) + + labs(x = "Cox-Snell residuals (pseudo observed times)", + y = "Cumulative hazard at pseudo observed times") + + theme_default_survex() + + facet_wrap(~label, ncol = facet_ncol) + + labs(title = title, + subtitle = subtitle) + } else{ + stop('`type` should be one of `deviance`, `martingale` or `Cox-Snell`') + } } From dcedaca8686e8e872aef3c25ebd174ac79c76b4a Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 17:50:55 +0200 Subject: [PATCH 173/207] add print --- NAMESPACE | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index f34345d2..4b3c66b1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -34,6 +34,7 @@ S3method(predict_parts,default) S3method(predict_parts,surv_explainer) S3method(predict_profile,default) 
S3method(predict_profile,surv_explainer) +S3method(print,model_diagnostics_survival) S3method(print,model_profile_2d_survival) S3method(print,model_profile_survival) S3method(print,surv_ceteris_paribus) @@ -99,6 +100,7 @@ importFrom(stats,model.matrix) importFrom(stats,na.omit) importFrom(stats,optim) importFrom(stats,predict) +importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,reorder) importFrom(stats,reshape) From a95ba7ac98ae5bfba194fa175de370e61977f4ae Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 17:51:08 +0200 Subject: [PATCH 174/207] add example --- R/model_diagnostics.R | 7 ++++++- man/model_diagnostics.surv_explainer.Rd | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/R/model_diagnostics.R b/R/model_diagnostics.R index d7e95105..383e3ee6 100644 --- a/R/model_diagnostics.R +++ b/R/model_diagnostics.R @@ -30,7 +30,12 @@ #' y = Surv(veteran$time, veteran$status) #' ) #' -#' TODO kod +#' cph_residuals <- model_diagnostics(cph_exp) +#' rsf_residuals <- model_diagnostics(rsf_ranger_exp) +#' +#' head(cph_residuals$result) +#' plot(cph_residuals, rsf_residuals, xvariable = "age") +#' plot(cph_residuals, rsf_residuals, type = "Cox-Snell") #' #' } #' @rdname model_diagnostics.surv_explainer diff --git a/man/model_diagnostics.surv_explainer.Rd b/man/model_diagnostics.surv_explainer.Rd index 3710dd9f..8b338f14 100644 --- a/man/model_diagnostics.surv_explainer.Rd +++ b/man/model_diagnostics.surv_explainer.Rd @@ -43,7 +43,12 @@ rsf_ranger_exp <- explain(rsf_ranger, y = Surv(veteran$time, veteran$status) ) -TODO kod +cph_residuals <- model_diagnostics(cph_exp) +rsf_residuals <- model_diagnostics(rsf_ranger_exp) + +head(cph_residuals$result) +plot(cph_residuals, rsf_residuals, xvariable = "age") +plot(cph_residuals, rsf_residuals, type = "Cox-Snell") } } From 5875f66f49974531b4f9b7b79d566665008a2531 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 17:51:30 +0200 Subject: [PATCH 175/207] fix 
plot code --- R/plot_model_diagnostics_survival.R | 75 ++++++++++++++------------ man/plot.model_diagnostics_survival.Rd | 12 ++++- 2 files changed, 53 insertions(+), 34 deletions(-) diff --git a/R/plot_model_diagnostics_survival.R b/R/plot_model_diagnostics_survival.R index c66a1a2e..d353a67e 100644 --- a/R/plot_model_diagnostics_survival.R +++ b/R/plot_model_diagnostics_survival.R @@ -1,5 +1,7 @@ #' Plot Model Diagnostics for Survival Models #' +#' @importFrom stats qnorm +#' #' @export plot.model_diagnostics_survival <- function(x, ..., @@ -17,12 +19,13 @@ plot.model_diagnostics_survival <- function(x, }) explanations_list <- c(list(x), list(...)) result_list <- lapply(explanations_list, function(x) x$result) - n_observations <- print(sapply(result_list, nrow)) + n_observations <- sapply(result_list, nrow) df <- do.call(rbind, result_list) if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), - ifelse(length(unique(df$label)) > 1, " models", " model")) + labels <- unique(df$label) + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } if (type %in% c("deviance", "martingale")){ @@ -35,48 +38,54 @@ plot.model_diagnostics_survival <- function(x, df$x <- switch(xvariable, "index" = unlist(lapply(n_observations, function(x) seq_len(x))), df[[xvariable]]) - - pl <- ggplot(df, aes(x = x, y = y, color = status)) + - geom_hline(yintercept = 0, color = colors[3], lty = 2, linewidth = 1) + - geom_point() + - theme_default_survex() + - scale_color_manual(values = colors, labels = c("censored", "event")) + - facet_wrap(~label, ncol = facet_ncol) + - labs(title = title, - subtitle = subtitle, - x = xvariable, - y = paste(type, "residuals")) - if (smooth) - pl <- pl + geom_smooth(se = FALSE, color = colors[3], alpha = 0.5) + pl <- with(df, { + pl <- ggplot(df, aes(x = x, y = y, color = status)) + + 
geom_hline(yintercept = 0, color = colors[3], lty = 2, linewidth = 1) + + geom_point() + + theme_default_survex() + + scale_color_manual(values = colors, labels = c("censored", "event")) + + facet_wrap(~label, ncol = facet_ncol) + + labs(title = title, + subtitle = subtitle, + x = xvariable, + y = paste(type, "residuals")) + if (smooth) + pl <- pl + geom_smooth(se = FALSE, color = colors[3], alpha = 0.5) + pl + }) return(pl) } else if (type == "Cox-Snell"){ split_df <- split(df, df$label) df_list <- lapply(split_df, function(df_tmp){ - fit_coxsnell <- survival::survfit(survival::Surv(cox_snell_residuals, as.numeric(status)) ~ 1, data=df_tmp) - confint <- survival:::survfit_confint(fit_coxsnell$cumhaz, fit_coxsnell$std.chaz, - logse=FALSE, "log", 0.95, ulimit = FALSE) - data.frame( + fit_coxsnell <- survival::survfit(survival::Surv(cox_snell_residuals, as.numeric(status)) ~ 1, data=df_tmp) + x <- ifelse(fit_coxsnell$cumhaz == 0, NA, fit_coxsnell$cumhaz) + se2 <- qnorm(0.975, 0, 1) * fit_coxsnell$std.chaz / x + lower <- exp(log(x) - se2) + upper <- exp(log(x) + se2) + data.frame( "time" = fit_coxsnell$time, "cumhaz" = fit_coxsnell$cumhaz, - "lower" = confint$lower, - "upper" = confint$upper) + "lower" = lower, + "upper" = upper) }) df <- do.call(rbind, df_list) df$label <- sapply(strsplit(rownames(df), "[.]"), function(x) x[1]) rownames(df) <- NULL - ggplot(df, aes(x = time, y = cumhaz)) + - geom_step(color = colors[1], linewidth = 1) + - geom_step(aes(y = lower), linetype = "dashed", color = colors[1], alpha = 0.8) + - geom_step(aes(y = upper), linetype = "dashed", color = colors[1], alpha = 0.8) + - geom_abline(slope = 1, color = colors[2], linewidth = 1) + - labs(x = "Cox-Snell residuals (pseudo observed times)", - y = "Cumulative hazard at pseudo observed times") + - theme_default_survex() + - facet_wrap(~label, ncol = facet_ncol) + - labs(title = title, - subtitle = subtitle) + with(df, + {ggplot(df, aes(x = time, y = cumhaz)) + + geom_step(color = colors[1], 
linewidth = 1) + + geom_step(aes(y = lower), linetype = "dashed", color = colors[1], alpha = 0.8) + + geom_step(aes(y = upper), linetype = "dashed", color = colors[1], alpha = 0.8) + + geom_abline(slope = 1, color = colors[2], linewidth = 1) + + labs(x = "Cox-Snell residuals (pseudo observed times)", + y = "Cumulative hazard at pseudo observed times") + + theme_default_survex() + + facet_wrap(~label, ncol = facet_ncol) + + labs(title = title, + subtitle = subtitle) + }) } else{ stop('`type` should be one of `deviance`, `martingale` or `Cox-Snell`') } diff --git a/man/plot.model_diagnostics_survival.Rd b/man/plot.model_diagnostics_survival.Rd index db3cca1d..686ef036 100644 --- a/man/plot.model_diagnostics_survival.Rd +++ b/man/plot.model_diagnostics_survival.Rd @@ -4,7 +4,17 @@ \alias{plot.model_diagnostics_survival} \title{Plot Model Diagnostics for Survival Models} \usage{ -\method{plot}{model_diagnostics_survival}(x, ...) +\method{plot}{model_diagnostics_survival}( + x, + ..., + type = "deviance", + xvariable = "index", + smooth = as.logical(xvariable != "index"), + title = "Model diagnostics", + subtitle = "default", + facet_ncol = NULL, + colors = c("#160e3b", "#f05a71", "#ceced9") +) } \description{ Plot Model Diagnostics for Survival Models From 2fe5eb44b7f37faa5a0cc9e0fc83fbb3cce541a3 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 17:51:36 +0200 Subject: [PATCH 176/207] add print --- R/print.R | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/R/print.R b/R/print.R index c042cd61..8803aeef 100644 --- a/R/print.R +++ b/R/print.R @@ -59,3 +59,12 @@ print.surv_shap <- function(x, ...) { cat("\n") print(res, ...) } + + +#' @export +print.model_diagnostics_survival <- function(x, ...){ + res <- x$result + text <- paste0("Survival residuals for the ", unique(res$label), " model:\n") + cat(text) + print.data.frame(res, ...) 
+} From 33b21dac3defea583fecd8d0376eb055ab846baf Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 17:51:59 +0200 Subject: [PATCH 177/207] fix subtitles (singular/plurar) --- R/plot_model_profile_2d.R | 8 +++++--- R/plot_model_profile_survival.R | 6 +++--- R/plot_surv_ceteris_paribus.R | 5 +++-- R/plot_surv_feature_importance.R | 6 +++--- R/plot_surv_model_performance.R | 8 ++++++-- R/plot_surv_model_performance_rocs.R | 4 +++- R/plot_surv_shap.R | 5 +++-- 7 files changed, 26 insertions(+), 16 deletions(-) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index fd192a55..fd3337b5 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -204,9 +204,11 @@ prepare_model_profile_2d_plots <- function(x, return(p) }) if (!is.null(subtitle) && subtitle == "default") { - labels <- - paste0(unique(all_profiles$`_label_`), collapse = ", ") - subtitle <- paste0("created for the ", labels, " model") + labels <- unique(all_profiles$`_label_`) + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) + + if (!marginalize_over_time) { subtitle <- paste0(subtitle, " and t = ", times) } diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 5c56b8a2..f74425de 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -502,9 +502,9 @@ prepare_model_profile_plots <- function(x, if (!is.null(subtitle) && subtitle == "default") { - labels <- - paste0(unique(aggregated_profiles$`_label_`), collapse = ", ") - subtitle <- paste0("created for the ", labels, " model") + labels <- unique(aggregated_profiles$`_label_`) + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } if (is.null(variables)) { diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index f148a604..075feb97 100644 --- 
a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -20,8 +20,9 @@ prepare_ceteris_paribus_plots <- function(x, # extract labels to use in the default subtitle if (!is.null(subtitle) && subtitle == "default") { - labels <- paste0(unique(all_profiles$`_label_`), collapse = ", ") - subtitle <- paste0("created for the ", labels, " model") + labels <- unique(all_profiles$`_label_`) + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } # variables to use diff --git a/R/plot_surv_feature_importance.R b/R/plot_surv_feature_importance.R index 97365594..9dbb33f2 100644 --- a/R/plot_surv_feature_importance.R +++ b/R/plot_surv_feature_importance.R @@ -57,7 +57,7 @@ plot.surv_feature_importance <- function(x, ..., plotting_df <- do.call(rbind, transformed_dfs) rug_df <- do.call(rbind, transformed_rug_dfs) - label <- unique(plotting_df$label) + labels <- unique(plotting_df$label) subs <- aggregate(plotting_df$value, by = list(var = plotting_df$ind), function(x) sum(abs(x))) @@ -79,8 +79,8 @@ plot.surv_feature_importance <- function(x, ..., } if (!is.null(subtitle) && subtitle == "default") { - glm_labels <- paste0(label, collapse = ", ") - subtitle <- paste0("created for the ", glm_labels, " model") + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } base_plot <- with(plotting_df, { diff --git a/R/plot_surv_model_performance.R b/R/plot_surv_model_performance.R index 7b14df42..9953f5e1 100644 --- a/R/plot_surv_model_performance.R +++ b/R/plot_surv_model_performance.R @@ -79,7 +79,9 @@ plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, df <- concatenate_td_dfs(x, ...) 
if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), " model") + labels <- unique(df$label) + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } if (is.null(metrics)) metrics <- c("C/D AUC", "Brier score") @@ -106,7 +108,9 @@ plot_scalar_surv_model_performance <- function(x, ..., metrics = NULL, title = N df <- concatenate_dfs(x, ...) if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), " model") + labels <- unique(df$label) + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } if (!is.null(metrics)) df <- df[df$ind %in% metrics, ] diff --git a/R/plot_surv_model_performance_rocs.R b/R/plot_surv_model_performance_rocs.R index b1aa0a2b..8e5923a2 100644 --- a/R/plot_surv_model_performance_rocs.R +++ b/R/plot_surv_model_performance_rocs.R @@ -45,7 +45,9 @@ plot.surv_model_performance_rocs <- function(x, df <- do.call(rbind, alldfs) if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), " model") + labels <- unique(df$label) + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } num_colors <- length(unique(df$label)) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index a47351e0..5c0c6681 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -61,10 +61,11 @@ plot.surv_shap <- function(x, rug_df <- do.call(rbind, transformed_rug_dfs) long_df <- do.call(rbind, long_df) - label <- unique(long_df$label) + labels <- unique(long_df$label) if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0("created for the ", paste(label, collapse = ", 
"), " model") + endword <- ifelse(length(labels) > 1, " models", " model") + subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } n_colors <- length(unique(long_df$ind)) From 26c893bfc6ddd6efcf36a24f6a1d457fba4be138 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 18:25:33 +0200 Subject: [PATCH 178/207] add switch --- R/model_diagnostics.R | 82 ++++++++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/R/model_diagnostics.R b/R/model_diagnostics.R index 383e3ee6..d1741898 100644 --- a/R/model_diagnostics.R +++ b/R/model_diagnostics.R @@ -4,7 +4,7 @@ #' #' #' @param explainer an explainer object - model preprocessed by the `explain()` function -#' @param ... other parameters passed to `DALEX::model_diagnostics` if `output_type == "risk"`, otherwise passed to internal functions. +#' @param ... other parameters passed to `DALEX::model_diagnostics` if `output_type == "risk"`, otherwise passed to internal functions #' @param output_type either `"chf"`, `"survival"` or `"risk"` the type of survival model output that should be used for explanations. If `"chf"` or `"survival"` the explanations are based on survival residuals. Otherwise the scalar risk predictions are used by the `DALEX::model_diagnostics` function. #' #' @return An object of class `c("model_diagnostics_survival")`. It's a list with the explanations in the `result` element. @@ -35,7 +35,7 @@ #' #' head(cph_residuals$result) #' plot(cph_residuals, rsf_residuals, xvariable = "age") -#' plot(cph_residuals, rsf_residuals, type = "Cox-Snell") +#' plot(cph_residuals, rsf_residuals, plot_type = "Cox-Snell") #' #' } #' @rdname model_diagnostics.surv_explainer @@ -47,39 +47,55 @@ model_diagnostics <- function(explainer, ...) 
UseMethod("model_diagnostics", exp model_diagnostics.surv_explainer <- function(explainer, ..., output_type = "chf") { - n <- nrow(explainer$data) - original_times <- explainer$y[, 1] - statuses <- explainer$y[, 2] - unique_times <- sort(unique(original_times)) - which_el <- matrix(c(1:n, match(original_times, unique_times)), - nrow = n) - chf_preds <- - predict(explainer, times = unique_times, output_type = "chf") - cox_snell_residuals <- chf_preds[which_el] - martingale_residuals <- statuses - cox_snell_residuals - deviance_residuals <- sign(martingale_residuals) * - sqrt(-2 * ( - martingale_residuals + statuses * - log(statuses - martingale_residuals) - )) - # cox_snell_residuals[statuses == 0] <- - # cox_snell_residuals[statuses == 0] + 1 - # modification for censored observations + test_explainer(explainer, has_data = TRUE, has_y = TRUE, has_chf = TRUE, function_name = "model_diagnostics") - result <- cbind( - data.frame( - "time" = original_times, - "status" = factor(statuses), - "cox_snell_residuals" = cox_snell_residuals, - "martingale_residuals" = martingale_residuals, - "deviance_residuals" = deviance_residuals, - "label" = explainer$label - ), - explainer$data + if (output_type == "survival"){ + output_type <- "chf" + warning("Cumulative hazard function (not survival function) is used for calculating survival residuals.") + } + + switch("output_type", + "risk" = DALEX::model_diagnostics(explainer, ...), + "chf" = { + n <- nrow(explainer$data) + original_times <- explainer$y[, 1] + statuses <- explainer$y[, 2] + + unique_times <- sort(unique(original_times)) + which_el <- matrix(c(1:n, match(original_times, unique_times)), + nrow = n) + chf_preds <- + predict(explainer, times = unique_times, output_type = "chf") + cox_snell_residuals <- chf_preds[which_el] + martingale_residuals <- statuses - cox_snell_residuals + deviance_residuals <- sign(martingale_residuals) * + sqrt(-2 * ( + martingale_residuals + statuses * + log(statuses - martingale_residuals) + 
)) + # cox_snell_residuals[statuses == 0] <- + # cox_snell_residuals[statuses == 0] + 1 + # modification for censored observations + + result <- cbind( + data.frame( + "time" = original_times, + "status" = factor(statuses), + "cox_snell_residuals" = cox_snell_residuals, + "martingale_residuals" = martingale_residuals, + "deviance_residuals" = deviance_residuals, + "label" = explainer$label + ), + explainer$data + ) + + res <- list(result = result) + class(res) <- c("model_diagnostics_survival", class(res)) + res + }, + stop("Type should be either `chf` or `risk`") ) - res <- list(result = result) - class(res) <- c("model_diagnostics_survival", class(res)) - res + } From 24d9641def3af56e0efec5c9593f7808690c1f71 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 18:25:48 +0200 Subject: [PATCH 179/207] add docs --- R/plot_model_diagnostics_survival.R | 30 +++++++++++++++++++------ man/model_diagnostics.surv_explainer.Rd | 4 ++-- man/plot.model_diagnostics_survival.Rd | 29 +++++++++++++++++++++--- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/R/plot_model_diagnostics_survival.R b/R/plot_model_diagnostics_survival.R index d353a67e..6e9cfed8 100644 --- a/R/plot_model_diagnostics_survival.R +++ b/R/plot_model_diagnostics_survival.R @@ -1,16 +1,31 @@ #' Plot Model Diagnostics for Survival Models #' +#' This function plots objects of class `"model_diagnostics_survival"` created +#' using the `model_diagnostics()` function. +#' +#' @param x an object of class `model_diagnostics_survival` to be plotted +#' @param ... additional objects of class `model_diagnostics_survival` to be plotted together +#' @param plot_type character, either `"deviance"`, `"martingale"` or `"Cox-Snell"`. Selects the type of plot to be prepared. If `"deviance"` or `"martingale` then deviance/martingale residuals are plotted against `xvariable`. 
If `"Cox-Snell"` then diagnostic plot of Cox-Snell residuals is prepared, which is CHF estimated based on Cox-Snell residuals against theoretical cumulative hazard trajectory of the Exp(1) -- diagonal line. +#' @param xvariable character, name of the variable to be plotted on x-axis (can be name of the variable to be drawn on the x-axis (can be any column from the `x$result`: explanatory variable, time, other residuals). By default `"index"` which gives the order of observations. +#' @param smooth logical, shall the smooth line be added. Only used when `plot_type = "deviance"` or `plot_type = "martingale"`. +#' @param facet_ncol number of columns for arranging subplots +#' @param title character, title of the plot +#' @param subtitle character, subtitle of the plot, `"default"` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels +#' @param colors character vector containing the colors to be used for plotting (containing either hex codes "#FF69B4", or names "blue"). +#' +#' @return An object of the class `ggplot`. 
+#' #' @importFrom stats qnorm #' #' @export plot.model_diagnostics_survival <- function(x, ..., - type = "deviance", + plot_type = "deviance", xvariable = "index", smooth = as.logical(xvariable != "index"), + facet_ncol = NULL, title = "Model diagnostics", subtitle = "default", - facet_ncol =NULL, colors = c("#160e3b", "#f05a71", "#ceced9")){ lapply(list(x, ...), function(x) { if (!inherits(x, "model_diagnostics_survival")) { @@ -28,11 +43,11 @@ plot.model_diagnostics_survival <- function(x, subtitle <- paste0("created for the ", paste0(labels, collapse = ", "), endword) } - if (type %in% c("deviance", "martingale")){ + if (plot_type %in% c("deviance", "martingale")){ if (!xvariable %in% c("index", colnames(df))){ stop(paste("`xvariable`", xvariable, "not found")) } - df$y <- switch(type, + df$y <- switch(plot_type, "deviance" = df$deviance_residuals, "martingale" = df$martingale_residuals) df$x <- switch(xvariable, @@ -54,7 +69,7 @@ plot.model_diagnostics_survival <- function(x, pl }) return(pl) - } else if (type == "Cox-Snell"){ + } else if (plot_type == "Cox-Snell"){ split_df <- split(df, df$label) df_list <- lapply(split_df, function(df_tmp){ fit_coxsnell <- survival::survfit(survival::Surv(cox_snell_residuals, as.numeric(status)) ~ 1, data=df_tmp) @@ -73,7 +88,7 @@ plot.model_diagnostics_survival <- function(x, df$label <- sapply(strsplit(rownames(df), "[.]"), function(x) x[1]) rownames(df) <- NULL - with(df, + pl <- with(df, {ggplot(df, aes(x = time, y = cumhaz)) + geom_step(color = colors[1], linewidth = 1) + geom_step(aes(y = lower), linetype = "dashed", color = colors[1], alpha = 0.8) + @@ -86,7 +101,8 @@ plot.model_diagnostics_survival <- function(x, labs(title = title, subtitle = subtitle) }) + return(pl) } else{ - stop('`type` should be one of `deviance`, `martingale` or `Cox-Snell`') + stop('`plot_type` should be one of `deviance`, `martingale` or `Cox-Snell`') } } diff --git a/man/model_diagnostics.surv_explainer.Rd 
b/man/model_diagnostics.surv_explainer.Rd index 8b338f14..2e50e8a9 100644 --- a/man/model_diagnostics.surv_explainer.Rd +++ b/man/model_diagnostics.surv_explainer.Rd @@ -12,7 +12,7 @@ model_diagnostics(explainer, ...) \arguments{ \item{explainer}{an explainer object - model preprocessed by the \code{explain()} function} -\item{...}{other parameters passed to \code{DALEX::model_diagnostics} if \code{output_type == "risk"}, otherwise passed to internal functions.} +\item{...}{other parameters passed to \code{DALEX::model_diagnostics} if \code{output_type == "risk"}, otherwise passed to internal functions} \item{output_type}{either \code{"chf"}, \code{"survival"} or \code{"risk"} the type of survival model output that should be used for explanations. If \code{"chf"} or \code{"survival"} the explanations are based on survival residuals. Otherwise the scalar risk predictions are used by the \code{DALEX::model_diagnostics} function.} } @@ -48,7 +48,7 @@ rsf_residuals <- model_diagnostics(rsf_ranger_exp) head(cph_residuals$result) plot(cph_residuals, rsf_residuals, xvariable = "age") -plot(cph_residuals, rsf_residuals, type = "Cox-Snell") +plot(cph_residuals, rsf_residuals, plot_type = "Cox-Snell") } } diff --git a/man/plot.model_diagnostics_survival.Rd b/man/plot.model_diagnostics_survival.Rd index 686ef036..934fd1ef 100644 --- a/man/plot.model_diagnostics_survival.Rd +++ b/man/plot.model_diagnostics_survival.Rd @@ -7,15 +7,38 @@ \method{plot}{model_diagnostics_survival}( x, ..., - type = "deviance", + plot_type = "deviance", xvariable = "index", smooth = as.logical(xvariable != "index"), + facet_ncol = NULL, title = "Model diagnostics", subtitle = "default", - facet_ncol = NULL, colors = c("#160e3b", "#f05a71", "#ceced9") ) } +\arguments{ +\item{x}{an object of class \code{model_diagnostics_survival} to be plotted} + +\item{...}{additional objects of class \code{model_diagnostics_survival} to be plotted together} + +\item{plot_type}{character, either \code{"deviance"}, 
\code{"martingale"} or \code{"Cox-Snell"}. Selects the type of plot to be prepared. If \code{"deviance"} or \verb{"martingale} then deviance/martingale residuals are plotted against \code{xvariable}. If \code{"Cox-Snell"} then diagnostic plot of Cox-Snell residuals is prepared, which is CHF estimated based on Cox-Snell residuals against theoretical cumulative hazard trajectory of the Exp(1) -- diagonal line.} + +\item{xvariable}{character, name of the variable to be plotted on x-axis (can be name of the variable to be drawn on the x-axis (can be any column from the \code{x$result}: explanatory variable, time, other residuals). By default \code{"index"} which gives the order of observations.} + +\item{smooth}{logical, shall the smooth line be added. Only used when \code{plot_type = "deviance"} or \code{plot_type = "martingale"}.} + +\item{facet_ncol}{number of columns for arranging subplots} + +\item{title}{character, title of the plot} + +\item{subtitle}{character, subtitle of the plot, \code{"default"} automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels} + +\item{colors}{character vector containing the colors to be used for plotting (containing either hex codes "#FF69B4", or names "blue").} +} +\value{ +An object of the class \code{ggplot}. +} \description{ -Plot Model Diagnostics for Survival Models +This function plots objects of class \code{"model_diagnostics_survival"} created +using the \code{model_diagnostics()} function. 
} From f6f88939f6dd9dd66233cf2b34db44851855055a Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 18:42:17 +0200 Subject: [PATCH 180/207] remove switch for output_type - only one supported --- R/model_diagnostics.R | 78 ++++++++++--------------- man/model_diagnostics.surv_explainer.Rd | 4 -- 2 files changed, 31 insertions(+), 51 deletions(-) diff --git a/R/model_diagnostics.R b/R/model_diagnostics.R index d1741898..20970a77 100644 --- a/R/model_diagnostics.R +++ b/R/model_diagnostics.R @@ -2,10 +2,7 @@ #' #' This function calculates martingale and deviance residuals. #' -#' #' @param explainer an explainer object - model preprocessed by the `explain()` function -#' @param ... other parameters passed to `DALEX::model_diagnostics` if `output_type == "risk"`, otherwise passed to internal functions -#' @param output_type either `"chf"`, `"survival"` or `"risk"` the type of survival model output that should be used for explanations. If `"chf"` or `"survival"` the explanations are based on survival residuals. Otherwise the scalar risk predictions are used by the `DALEX::model_diagnostics` function. #' #' @return An object of class `c("model_diagnostics_survival")`. It's a list with the explanations in the `result` element. 
#' @@ -50,52 +47,39 @@ model_diagnostics.surv_explainer <- function(explainer, test_explainer(explainer, has_data = TRUE, has_y = TRUE, has_chf = TRUE, function_name = "model_diagnostics") - if (output_type == "survival"){ - output_type <- "chf" - warning("Cumulative hazard function (not survival function) is used for calculating survival residuals.") - } - - switch("output_type", - "risk" = DALEX::model_diagnostics(explainer, ...), - "chf" = { - n <- nrow(explainer$data) - original_times <- explainer$y[, 1] - statuses <- explainer$y[, 2] + n <- nrow(explainer$data) + original_times <- explainer$y[, 1] + statuses <- explainer$y[, 2] - unique_times <- sort(unique(original_times)) - which_el <- matrix(c(1:n, match(original_times, unique_times)), - nrow = n) - chf_preds <- - predict(explainer, times = unique_times, output_type = "chf") - cox_snell_residuals <- chf_preds[which_el] - martingale_residuals <- statuses - cox_snell_residuals - deviance_residuals <- sign(martingale_residuals) * - sqrt(-2 * ( - martingale_residuals + statuses * - log(statuses - martingale_residuals) - )) - # cox_snell_residuals[statuses == 0] <- - # cox_snell_residuals[statuses == 0] + 1 - # modification for censored observations + unique_times <- sort(unique(original_times)) + which_el <- matrix(c(1:n, match(original_times, unique_times)), + nrow = n) + chf_preds <- + predict(explainer, times = unique_times, output_type = "chf") + cox_snell_residuals <- chf_preds[which_el] + martingale_residuals <- statuses - cox_snell_residuals + deviance_residuals <- sign(martingale_residuals) * + sqrt(-2 * ( + martingale_residuals + statuses * + log(statuses - martingale_residuals) + )) + # cox_snell_residuals[statuses == 0] <- + # cox_snell_residuals[statuses == 0] + 1 + # modification for censored observations - result <- cbind( - data.frame( - "time" = original_times, - "status" = factor(statuses), - "cox_snell_residuals" = cox_snell_residuals, - "martingale_residuals" = martingale_residuals, - 
"deviance_residuals" = deviance_residuals, - "label" = explainer$label - ), - explainer$data - ) - - res <- list(result = result) - class(res) <- c("model_diagnostics_survival", class(res)) - res - }, - stop("Type should be either `chf` or `risk`") + result <- cbind( + data.frame( + "time" = original_times, + "status" = factor(statuses), + "cox_snell_residuals" = cox_snell_residuals, + "martingale_residuals" = martingale_residuals, + "deviance_residuals" = deviance_residuals, + "label" = explainer$label + ), + explainer$data ) - + res <- list(result = result) + class(res) <- c("model_diagnostics_survival", class(res)) + res } diff --git a/man/model_diagnostics.surv_explainer.Rd b/man/model_diagnostics.surv_explainer.Rd index 2e50e8a9..0d0fcc96 100644 --- a/man/model_diagnostics.surv_explainer.Rd +++ b/man/model_diagnostics.surv_explainer.Rd @@ -11,10 +11,6 @@ model_diagnostics(explainer, ...) } \arguments{ \item{explainer}{an explainer object - model preprocessed by the \code{explain()} function} - -\item{...}{other parameters passed to \code{DALEX::model_diagnostics} if \code{output_type == "risk"}, otherwise passed to internal functions} - -\item{output_type}{either \code{"chf"}, \code{"survival"} or \code{"risk"} the type of survival model output that should be used for explanations. If \code{"chf"} or \code{"survival"} the explanations are based on survival residuals. Otherwise the scalar risk predictions are used by the \code{DALEX::model_diagnostics} function.} } \value{ An object of class \code{c("model_diagnostics_survival")}. It's a list with the explanations in the \code{result} element. 
From b89958f9d524106aa965e1db549f416f8225f8cc Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 18:42:31 +0200 Subject: [PATCH 181/207] fix param name --- R/plot_model_diagnostics_survival.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/plot_model_diagnostics_survival.R b/R/plot_model_diagnostics_survival.R index 6e9cfed8..b6516be9 100644 --- a/R/plot_model_diagnostics_survival.R +++ b/R/plot_model_diagnostics_survival.R @@ -63,7 +63,7 @@ plot.model_diagnostics_survival <- function(x, labs(title = title, subtitle = subtitle, x = xvariable, - y = paste(type, "residuals")) + y = paste(plot_type, "residuals")) if (smooth) pl <- pl + geom_smooth(se = FALSE, color = colors[3], alpha = 0.5) pl From 4e7ae04643b6b762bfbc3a498ccc774c3c65f57c Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 18:42:38 +0200 Subject: [PATCH 182/207] add tests for model_diagnostics --- tests/testthat/test-model_diagnostics.R | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/testthat/test-model_diagnostics.R diff --git a/tests/testthat/test-model_diagnostics.R b/tests/testthat/test-model_diagnostics.R new file mode 100644 index 00000000..66330f45 --- /dev/null +++ b/tests/testthat/test-model_diagnostics.R @@ -0,0 +1,24 @@ +test_that("model_diagnostics for survival residuals works", { + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + md_cph <- model_diagnostics(cph_exp) + expect_s3_class(md_cph, "model_diagnostics_survival") + 
expect_true(all(md_cph$result$time == cph_exp$y[,1])) + expect_equal(ncol(md_cph$result) - ncol(cph_exp$data), 6) + + plot(md_rsf) + plot(md_rsf, plot_type = "martingale") + plot(md_rsf, plot_type = "Cox-Snell") + plot(md_cph, md_rsf, xvariable = "age") + plot(md_cph, md_rsf, smooth = FALSE) + + expect_error(plot(md_cph, md_rsf, xvariable = "nonexistent")) + expect_error(plot(md_cph, md_rsf, plot_type = "nonexistent")) + expect_error(plot(md_cph, exp)) +}) From b963d81fb962b1071e0a659fdfc25fff5e7592f1 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 19:11:42 +0200 Subject: [PATCH 183/207] remove unused arguments --- R/model_diagnostics.R | 7 ++----- man/model_diagnostics.surv_explainer.Rd | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/R/model_diagnostics.R b/R/model_diagnostics.R index 20970a77..32f04dda 100644 --- a/R/model_diagnostics.R +++ b/R/model_diagnostics.R @@ -37,14 +37,11 @@ #' } #' @rdname model_diagnostics.surv_explainer #' @export -model_diagnostics <- function(explainer, ...) UseMethod("model_diagnostics", explainer) +model_diagnostics <- function(explainer) UseMethod("model_diagnostics", explainer) #' @rdname model_diagnostics.surv_explainer #' @export -model_diagnostics.surv_explainer <- function(explainer, - ..., - output_type = "chf") { - +model_diagnostics.surv_explainer <- function(explainer) { test_explainer(explainer, has_data = TRUE, has_y = TRUE, has_chf = TRUE, function_name = "model_diagnostics") n <- nrow(explainer$data) diff --git a/man/model_diagnostics.surv_explainer.Rd b/man/model_diagnostics.surv_explainer.Rd index 0d0fcc96..f8caf68b 100644 --- a/man/model_diagnostics.surv_explainer.Rd +++ b/man/model_diagnostics.surv_explainer.Rd @@ -5,9 +5,9 @@ \alias{model_diagnostics.surv_explainer} \title{Dataset Level Model Diagnostics} \usage{ -model_diagnostics(explainer, ...) 
+model_diagnostics(explainer) -\method{model_diagnostics}{surv_explainer}(explainer, ..., output_type = "chf") +\method{model_diagnostics}{surv_explainer}(explainer) } \arguments{ \item{explainer}{an explainer object - model preprocessed by the \code{explain()} function} From fe431044abb2ac76fa873eb140214c881ef8fad3 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 19:11:58 +0200 Subject: [PATCH 184/207] fix tests --- tests/testthat/test-model_diagnostics.R | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/testthat/test-model_diagnostics.R b/tests/testthat/test-model_diagnostics.R index 66330f45..5b27591d 100644 --- a/tests/testthat/test-model_diagnostics.R +++ b/tests/testthat/test-model_diagnostics.R @@ -8,6 +8,7 @@ test_that("model_diagnostics for survival residuals works", { rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) md_cph <- model_diagnostics(cph_exp) + md_rsf <- model_diagnostics(rsf_ranger_exp) expect_s3_class(md_cph, "model_diagnostics_survival") expect_true(all(md_cph$result$time == cph_exp$y[,1])) expect_equal(ncol(md_cph$result) - ncol(cph_exp$data), 6) From 91ad9d45390b3c2f9d937b5c50ef3d840c3895a4 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 19:12:15 +0200 Subject: [PATCH 185/207] add explainer for flexsurv --- NAMESPACE | 2 + R/explain.R | 103 +++++++++++++++++++++++++++++++++++++---- R/model_info.R | 11 +++++ man/surv_model_info.Rd | 3 ++ 4 files changed, 110 insertions(+), 9 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 4b3c66b1..a9ca8851 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,6 +3,7 @@ S3method(explain,LearnerSurv) S3method(explain,coxph) S3method(explain,default) +S3method(explain,flexsurvreg) S3method(explain,model_fit) S3method(explain,ranger) S3method(explain,rfsrc) @@ -45,6 +46,7 @@ S3method(surv_model_info,LearnerSurv) S3method(surv_model_info,coxph) S3method(surv_model_info,cph) 
S3method(surv_model_info,default) +S3method(surv_model_info,flexsurvreg) S3method(surv_model_info,model_fit) S3method(surv_model_info,ranger) S3method(surv_model_info,rfsrc) diff --git a/R/explain.R b/R/explain.R index a92bcbdb..207b0e03 100644 --- a/R/explain.R +++ b/R/explain.R @@ -919,18 +919,12 @@ explain.sksurv <- function(model, if (is.null(predict_function)) { if (reticulate::py_has_attr(model, "predict")) { - predict_function <- function(model, newdata, times) model$predict(newdata) + predict_function <- function(model, newdata) model$predict(newdata) attr(predict_function, "verbose_info") <- "predict from scikit-survival will be used" + attr(predict_function, "is.default") <- TRUE } else { - predict_function <- function(model, newdata, times) { - rowSums(predict_cumulative_hazard_function(model, newdata, times)) - } - attr(predict_function, "verbose_info") <- "sum over the predict_cumulative_hazard_function will be used" - } - attr(predict_function, "is.default") <- TRUE - attr(predict_function, "use.times") <- TRUE - } else { attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) + } } if (!is.null(data) & any(colnames(data) != model$feature_names_in_)) { @@ -964,6 +958,97 @@ explain.sksurv <- function(model, } +#' @export +explain.flexsurvreg <- function(model, + data = NULL, + y = NULL, + predict_function = NULL, + predict_function_target_column = NULL, + residual_function = NULL, + weights = NULL, + ..., + label = NULL, + verbose = TRUE, + colorize = !isTRUE(getOption("knitr.in.progress")), + model_info = NULL, + type = NULL, + times = NULL, + times_generation = "quantiles", + predict_survival_function = NULL, + predict_cumulative_hazard_function = NULL) { + if (is.null(label)) { + label <- class(model)[1] + attr(label, "verbose_info") <- "default" + } + + if (is.null(predict_survival_function)) { + predict_survival_function <- function(model, newdata, times){ + raw_preds <- predict(model, newdata = newdata, times = times, type 
= "survival") + preds <- do.call(rbind, lapply(raw_preds[[1]], function(x) t(x[".pred_survival"]))) + rownames(preds) <- NULL + preds + } + attr(predict_survival_function, "verbose_info") <- "predict.flexsurvreg with type = 'survival' will be used" + attr(predict_survival_function, "is.default") <- TRUE + } else { + attr(predict_survival_function, "verbose_info") <- deparse(substitute(predict_survival_function)) + } + + if (is.null(predict_cumulative_hazard_function)) { + predict_cumulative_hazard_function <- function(model, newdata, times){ + raw_preds <- predict(model, newdata = newdata, times = times, type = "cumhaz") + preds <- do.call(rbind, lapply(raw_preds[[1]], function(x) t(x[".pred_cumhaz"]))) + rownames(preds) <- NULL + preds + } + attr(predict_cumulative_hazard_function, "verbose_info") <- "predict.flexsurvreg with type = 'cumhaz' will be used" + attr(predict_cumulative_hazard_function, "is.default") <- TRUE + } else { + attr(predict_cumulative_hazard_function, "verbose_info") <- deparse(substitute(predict_cumulative_hazard_function)) + } + + if (is.null(predict_function)) { + predict_function <- function(model, newdata){ + predict(model, newdata = newdata, type = "link")[[".pred_link"]] + } + attr(predict_function, "verbose_info") <- "predict.flexsurvreg with type = 'link' will be used" + attr(predict_function, "is.default") <- TRUE + } else { + attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) + } + + possible_data <- model.frame(parmodel) + if (is.null(data)) { + data <- possible_data[,-c(1, ncol(possible_data))] + attr(data, "verbose_info") <- "extracted" + } + + if (is.null(y)) { + y <- possible_data[,1] + attr(y, "verbose_info") <- "extracted" + } + + explain_survival( + model, + data = data, + y = y, + predict_function = predict_function, + predict_function_target_column = predict_function_target_column, + residual_function = residual_function, + weights = weights, + ... 
= ..., + label = label, + verbose = verbose, + colorize = colorize, + model_info = model_info, + type = type, + times = times, + times_generation = times_generation, + predict_survival_function = predict_survival_function, + predict_cumulative_hazard_function = predict_cumulative_hazard_function + ) +} + verbose_cat <- function(..., is.default = NULL, verbose = TRUE) { if (verbose) { if (!is.null(is.default)) { diff --git a/R/model_info.R b/R/model_info.R index 0048421c..0af0c237 100644 --- a/R/model_info.R +++ b/R/model_info.R @@ -124,6 +124,17 @@ surv_model_info.sksurv <- function(model, ...) { model_info } +#' @rdname surv_model_info +#' @export +surv_model_info.flexsurvreg <- function(model, ...) { + type <- "survival" + package <- "flexsurv" + ver <- get_pkg_ver_safe(package) + model_info <- list(package = package, ver = ver, type = type) + class(model_info) <- "model_info" + model_info +} + #' @rdname surv_model_info #' @export surv_model_info.default <- function(model, ...) { diff --git a/man/surv_model_info.Rd b/man/surv_model_info.Rd index 335011f6..8b672a3f 100644 --- a/man/surv_model_info.Rd +++ b/man/surv_model_info.Rd @@ -9,6 +9,7 @@ \alias{surv_model_info.cph} \alias{surv_model_info.LearnerSurv} \alias{surv_model_info.sksurv} +\alias{surv_model_info.flexsurvreg} \alias{surv_model_info.default} \title{Extract additional information from the model} \usage{ @@ -28,6 +29,8 @@ surv_model_info(model, ...) \method{surv_model_info}{sksurv}(model, ...) +\method{surv_model_info}{flexsurvreg}(model, ...) + \method{surv_model_info}{default}(model, ...) 
} \arguments{ From eab221f90639ce81273f8560a14dbc6e10a07af9 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 19:17:48 +0200 Subject: [PATCH 186/207] add flexsruv to suggests --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index 36289b7f..3bccf60a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,6 +31,7 @@ Imports: Suggests: censored (>= 0.2.0), covr, + flexsurv, gbm, generics, glmnet, From 529624ea732f6211d911c8a5a7587f9c746acbb8 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 19:17:55 +0200 Subject: [PATCH 187/207] add simple test for explainer --- tests/testthat/test-explain.R | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index a702c2c0..f4a9f21d 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -9,7 +9,7 @@ test_that("coxph prediction functions work correctly", { x = TRUE, y = TRUE ) - + cox_wrong <- survival::coxph( survival::Surv(rtime, recur) ~ ., data = rotterdam[, !colnames(rotterdam) %in% c("year", "dtime", "death")] @@ -66,11 +66,11 @@ test_that("coxph prediction functions work correctly", { explain(cox_rotterdam_rec, predict_survival_function = pec::predictSurvProb, verbose = FALSE) - + explain(cox_rotterdam_rec, predict_cumulative_hazard_function = pec::predictSurvProb, verbose = FALSE) - + explain(cox_rotterdam_rec, predict_function = predict, verbose = FALSE) @@ -196,7 +196,7 @@ test_that("ranger prediction functions work correctly", { y = survival::Surv(rotterdam$rtime, rotterdam$recur), predict_function = predict, verbose = FALSE) - + }) @@ -307,7 +307,7 @@ test_that("rsfrc prediction functions work correctly", { y = survival::Surv(colon$time, colon$status), predict_function = predict, verbose = FALSE) - + }) @@ -407,6 +407,15 @@ test_that("default methods for creating explainers work correctly", { expect_equal(cph_rms_exp$label, 
"coxph", ignore_attr = TRUE) + ### flexsurv::flexsurvreg ### + fsr <- flexsurv::flexsurvreg(survival::Surv(time, status) ~ + trt + celltype + karno + diagtime + age + prior, + data = veteran, dist = "exp") + fsr_exp <- explain(fsr, verbose = FALSE) + expect_s3_class(fsr_exp, c("surv_explainer", "explainer")) + expect_equal(fsr_exp$label, "flexsurvreg", ignore_attr = TRUE) + + ### parsnip::boost_tree ### library(censored, quietly = TRUE) @@ -528,7 +537,7 @@ test_that("warnings in explain_survival work correctly", { verbose = FALSE, type = "weird type", model_info = custom_info)) - + expect_error(explain_survival(cph, data = veteran, survival::Surv(veteran$time, veteran$status), From d0d5a61f225892369abfb65932778d6530c62f4a Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 19:19:01 +0200 Subject: [PATCH 188/207] fix typo --- R/explain.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/explain.R b/R/explain.R index 207b0e03..879c91f6 100644 --- a/R/explain.R +++ b/R/explain.R @@ -1017,7 +1017,7 @@ explain.flexsurvreg <- function(model, attr(predict_function, "verbose_info") <- deparse(substitute(predict_function)) } - possible_data <- model.frame(parmodel) + possible_data <- model.frame(model) if (is.null(data)) { data <- possible_data[,-c(1, ncol(possible_data))] attr(data, "verbose_info") <- "extracted" From e042a2ea29ff9e3f86e17161ffa265987f1441d4 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Thu, 24 Aug 2023 20:47:53 +0200 Subject: [PATCH 189/207] add tests --- tests/testthat/test-explain.R | 46 +++++++++++++++++++++++++ tests/testthat/test-model_diagnostics.R | 1 + 2 files changed, 47 insertions(+) diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index f4a9f21d..dd7d8548 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -312,6 +312,52 @@ test_that("rsfrc prediction functions work correctly", { }) +test_that("flexsurvreg prediction functions work correctly", { + 
rotterdam <- survival::rotterdam + rotterdam$pid <- NULL + fsr_rotterdam <- + flexsurv::flexsurvreg( + survival::Surv(rtime, recur) ~ age + meno + size + grade, + data = rotterdam, + dist = "exp" + ) + + fsr_explainer <- explain(fsr_rotterdam, verbose = FALSE) + + times <- fsr_explainer$times + + fsr_explainer$predict_survival_function(fsr_rotterdam, rotterdam[c(1, 2, 3), ], times) + sf_preds <- predict(fsr_explainer, rotterdam[c(1, 2, 3), ], times, output_type = "survival") + expect_true(inherits(sf_preds, "matrix")) + expect_equal(dim(sf_preds), c(3, length(times))) + + fsr_explainer$predict_cumulative_hazard_function(fsr_rotterdam, rotterdam[c(1, 2, 3), ], times) + chf_preds <- predict(fsr_explainer, rotterdam[c(1, 2, 3), ], times, output_type = "chf") + expect_true(inherits(chf_preds, "matrix")) + expect_equal(dim(chf_preds), c(3, length(times))) + + fsr_explainer$predict_function(fsr_rotterdam, rotterdam[c(1, 2, 3), ]) + risk_preds <- predict(fsr_explainer, rotterdam[c(1, 2, 3), ], output_type = "risk") + expect_true(is.numeric(risk_preds)) + expect_equal(length(risk_preds), 3) + + + # test manually setting predict survival function / chf / predict function + # the functions DO NOT WORK this is just a test if everything is set properly + explain(fsr_rotterdam, + predict_survival_function = pec::predictSurvProb, + verbose = FALSE) + + explain(fsr_rotterdam, + predict_cumulative_hazard_function = pec::predictSurvProb, + verbose = FALSE) + + explain(fsr_rotterdam, + predict_function = predict, + verbose = FALSE) +}) + + test_that("automated `y` and `data` sourcing works", { rotterdam <- survival::rotterdam diff --git a/tests/testthat/test-model_diagnostics.R b/tests/testthat/test-model_diagnostics.R index 5b27591d..9cd50aa4 100644 --- a/tests/testthat/test-model_diagnostics.R +++ b/tests/testthat/test-model_diagnostics.R @@ -12,6 +12,7 @@ test_that("model_diagnostics for survival residuals works", { expect_s3_class(md_cph, "model_diagnostics_survival") 
expect_true(all(md_cph$result$time == cph_exp$y[,1])) expect_equal(ncol(md_cph$result) - ncol(cph_exp$data), 6) + expect_output(print(md_cph)) plot(md_rsf) plot(md_rsf, plot_type = "martingale") From de8a3348b6ed7fe28ae1d2e42e82c2a60402f78a Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 25 Aug 2023 00:02:17 +0200 Subject: [PATCH 190/207] add new times generation method with the median survial time --- R/explain.R | 51 ++++++++++++++++++++++++++--------------- man/explain_survival.Rd | 4 ++-- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/R/explain.R b/R/explain.R index 879c91f6..71e57883 100644 --- a/R/explain.R +++ b/R/explain.R @@ -24,7 +24,7 @@ #' @param type type of a model, by default `"survival"` #' #' @param times numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations -#' @param times_generation either `"uniform"` or `"quantiles"`. Sets the way of generating the vector of times based on times provided in the `y` parameter. If `"uniform"` the vector contains 50 equally spaced points between the minimum and maximum observed times; if `"quantiles"` the vector contains 50 points between 0th and 98th percentiles of observed times. Ignored if `times` is not `NULL`. +#' @param times_generation either `"survival_quantiles"`, `"uniform"` or `"quantiles"`. Sets the way of generating the vector of times based on times provided in the `y` parameter. If `"survival_quantiles"` the vector contains unique time points out of 50 uniformly distributed survival quantiles of order (0, ..., 1-min(KM_SF)) based on the Kaplan-Meier estimator, and additional time point being the median survival time (if possible); if `"uniform"` the vector contains 50 equally spaced time points between the minimum and maximum observed times; if `"quantiles"` the vector contains unique time points out of 50 time points between 0th and 98th percentiles of observed times. Ignored if `times` is not `NULL`. 
#' @param predict_survival_function function taking 3 arguments `model`, `newdata` and `times`, and returning a matrix whose each row is a survival function evaluated at `times` for one observation from `newdata` #' @param predict_cumulative_hazard_function function taking 3 arguments `model`, `newdata` and `times`, and returning a matrix whose each row is a cumulative hazard function evaluated at `times` for one observation from `newdata` #' @@ -120,7 +120,7 @@ explain_survival <- model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL) { if (!colorize) { @@ -228,22 +228,35 @@ explain_survival <- } # verbose times + median_survival_time <- NULL if (is.null(times)) { if (!is.null(y)) { switch(times_generation, + "survival_quantiles" = { + survobj <- Surv(y[,1], y[,2]) + sfit <- survival::survfit(survobj ~ 1, type="kaplan-meier") + + min_sf <- min(sfit$surv) + quantiles <- 1 - seq(1, min_sf, length.out=50) + + if(min_sf <= 0.5) median_survival_time <- as.numeric(quantile(sfit, 0.5)$quantile) + raw_times <- quantile(sfit, quantiles)$quantile + + times <- sort(na.omit(unique(c(raw_times, median_survival_time)))) + method_description <- "uniformly distributed Kaplan-Meier survival quantiles (0, ..., 1-min(KM_SF))" + }, "uniform" = { times <- seq(min(y[, 1]), max(y[, 1]), length.out = 50) - method_description <- "50 uniformly distributed time points from min to max" + method_description <- "uniformly distributed time points from min to max observed time" }, "quantiles" = { - times <- quantile(y[, 1], seq(0, 0.99, 0.02)) - method_description <- "50 time points being consecutive quantiles (0.00, 0.02, ..., 0.98)" + times <- quantile(y[, 1], seq(0, 0.98, 0.02)) + method_description <- "time points being consecutive quantiles (0.00, 0.02, ..., 0.98) of observed times" }, - stop("times_generation needs to be 'uniform' or 'quantiles'") 
+ stop("times_generation needs to be 'survival_quantiles', 'uniform' or 'quantiles'") ) times <- sort(unique(times)) - times_stats <- get_times_stats(times) - verbose_cat(" -> times : ", times_stats[1], "unique time points", ", min =", times_stats[2], ", mean =", times_stats[3], ", median =", times_stats[4], ", max =", times_stats[5], verbose = verbose) + verbose_cat(" -> times : ", length(times), "unique time points", get_times_stats(times, median_survival_time), verbose = verbose) verbose_cat(" -> times : ", "(", color_codes$yellow_start, paste("generated from y as", method_description), color_codes$yellow_end, ")", verbose = verbose) } else { verbose_cat(" -> times : not specified and automatic generation is impossible ('y' is NULL)! (", color_codes$red_start, "WARNING", color_codes$red_end, ")", verbose = verbose) @@ -252,7 +265,7 @@ explain_survival <- } else { times <- sort(unique(times)) times_stats <- get_times_stats(times) - verbose_cat(" -> times : ", times_stats[1], "unique time points", ", min =", times_stats[2], ", mean =", times_stats[3], ", median =", times_stats[4], ", max =", times_stats[5], verbose = verbose) + verbose_cat(" -> times : ", length(times), "unique time points", get_times_stats(times, median_survival_time), verbose = verbose) } # verbose predict function @@ -350,6 +363,7 @@ explain_survival <- model_info = model_info, type = type, times = times, + median_survival_time = median_survival_time, predict_survival_function = predict_survival_function, predict_cumulative_hazard_function = predict_cumulative_hazard_function, ... = ... 
@@ -467,7 +481,7 @@ explain.coxph <- function(model, model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL) { if (is.null(data)) { @@ -556,7 +570,7 @@ explain.ranger <- function(model, model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL) { if (is.null(predict_survival_function)) { @@ -631,7 +645,7 @@ explain.rfsrc <- function(model, model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL) { if (is.null(label)) { @@ -722,7 +736,7 @@ explain.model_fit <- function(model, model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL) { if (is.null(label)) { @@ -803,7 +817,7 @@ explain.LearnerSurv <- function(model, model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL) { if (is.null(label)) { @@ -883,7 +897,7 @@ explain.sksurv <- function(model, model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL){ if (is.null(label)) { @@ -973,7 +987,7 @@ explain.flexsurvreg <- function(model, model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL) { if (is.null(label)) { @@ -1060,8 +1074,9 @@ 
verbose_cat <- function(..., is.default = NULL, verbose = TRUE) { } } -get_times_stats <- function(times) { - c(length(times), min(times), mean(times), median(times), max(times)) +get_times_stats <- function(times, median_survival_time=NULL) { + median_survival_time_str <- ifelse(is.null(median_survival_time), "", paste0(" , median survival time = ", median_survival_time)) + paste0(", min = ", min(times), median_survival_time_str, " , max = ", max(times)) } # diff --git a/man/explain_survival.Rd b/man/explain_survival.Rd index afdd90e1..c8500af6 100644 --- a/man/explain_survival.Rd +++ b/man/explain_survival.Rd @@ -21,7 +21,7 @@ explain_survival( model_info = NULL, type = NULL, times = NULL, - times_generation = "quantiles", + times_generation = "survival_quantiles", predict_survival_function = NULL, predict_cumulative_hazard_function = NULL ) @@ -87,7 +87,7 @@ explain( \item{times}{numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations} -\item{times_generation}{either \code{"uniform"} or \code{"quantiles"}. Sets the way of generating the vector of times based on times provided in the \code{y} parameter. If \code{"uniform"} the vector contains 50 equally spaced points between the minimum and maximum observed times; if \code{"quantiles"} the vector contains 50 points between 0th and 98th percentiles of observed times. Ignored if \code{times} is not \code{NULL}.} +\item{times_generation}{either \code{"survival_quantiles"}, \code{"uniform"} or \code{"quantiles"}. Sets the way of generating the vector of times based on times provided in the \code{y} parameter. 
If \code{"survival_quantiles"} the vector contains unique time points out of 50 uniformly distributed survival quantiles of order (0, ..., 1-min(KM_SF)) based on the Kaplan-Meier estimator, and additional time point being the median survival time (if possible); if \code{"uniform"} the vector contains 50 equally spaced time points between the minimum and maximum observed times; if \code{"quantiles"} the vector contains unique time points out of 50 time points between 0th and 98th percentiles of observed times. Ignored if \code{times} is not \code{NULL}.} \item{predict_survival_function}{function taking 3 arguments \code{model}, \code{newdata} and \code{times}, and returning a matrix whose each row is a survival function evaluated at \code{times} for one observation from \code{newdata}} From 1eaab03aec5b2b26079c9e4b6911b5b6427bb616 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 25 Aug 2023 00:03:59 +0200 Subject: [PATCH 191/207] fixes for new times generation method --- R/model_profile.R | 1 + R/model_profile_2d.R | 3 ++- R/plot_model_profile_2d.R | 18 +++++++++++++----- R/plot_model_profile_survival.R | 15 ++++++++++----- R/plot_predict_profile_survival.R | 9 ++++++--- R/predict_profile.R | 1 + man/plot.model_profile_2d_survival.Rd | 6 +++--- man/plot.model_profile_survival.Rd | 6 +++--- man/plot.predict_profile_survival.Rd | 2 +- tests/testthat/test-explain.R | 2 +- tests/testthat/test-model_profile.R | 18 +++++++++--------- vignettes/pdp.Rmd | 20 ++++++++++++-------- 12 files changed, 62 insertions(+), 39 deletions(-) diff --git a/R/model_profile.R b/R/model_profile.R index f0a536c9..0363c605 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -137,6 +137,7 @@ model_profile.surv_explainer <- function(explainer, center = center ) class(ret) <- c("model_profile_survival", "list") + ret$median_survival_time <- explainer$median_survival_time ret$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] ret$event_statuses <- 
explainer$y[explainer$y[, 1] <= max(explainer$times), 2] ret diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index b99a45bd..42f4e7fb 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -114,7 +114,8 @@ model_profile_2d.surv_explainer <- function(explainer, result = result, eval_times = unique(result$`_times_`), variables = variables, - type = type + type = type, + median_survival_time = explainer$median_survival_time ) class(ret) <- c("model_profile_2d_survival", "list") return(ret) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index fd3337b5..d7384cab 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -6,7 +6,7 @@ #' @param x an object of class `model_profile_2d_survival` to be plotted #' @param ... additional objects of class `model_profile_2d_survival` to be plotted together #' @param variables list of character vectors of length 2, names of pairs of variables to be plotted -#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median time from the explainer object is used. +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median survival time (if available) or the median time from the explainer object is used. #' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately #' @param facet_ncol number of columns for arranging subplots #' @param title character, title of the plot. `'default'` automatically generates either "2D partial dependence survival profiles" or "2D accumulated local effects survival profiles" depending on the explanation type. 
@@ -31,14 +31,14 @@ #' ) #' ) #' head(cph_model_profile_2d$result) -#' plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 103) +#' plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = cph_exp$times[20]) #' #' cph_model_profile_2d_ale <- model_profile_2d(cph_exp, #' variables = list(c("age", "karno")), #' type = "accumulated" #' ) #' head(cph_model_profile_2d_ale$result) -#' plot(cph_model_profile_2d_ale, times = c(8, 103), marginalize_over_time = TRUE) +#' plot(cph_model_profile_2d_ale, times = cph_exp$times[c(10, 20)], marginalize_over_time = TRUE) #' } #' #' @export @@ -122,8 +122,16 @@ prepare_model_profile_2d_plots <- function(x, subtitle, colors) { if (is.null(times)) { - times <- quantile(x$eval_times, p = 0.5, type = 1) - warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") + if (marginalize_over_time){ + times <- x$eval_times + warning("Plot will be prepared with marginalization over all time points from the explainer's `times` vector. \nFor subset of time points, set the value of `times`.") + } else if (!is.null(x$median_survival_time)){ + times <- x$median_survival_time + warning("Plot will be prepared for the median survial time. For another time point, set the value of `times`.") + } else { + times <- quantile(x$eval_times, p = 0.5, type = 1) + warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") + } } if (!marginalize_over_time && length(times) > 1) { diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index f74425de..7e01b664 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -10,7 +10,7 @@ #' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all. 
Only used when `geom = "time"`. #' @param facet_ncol number of columns for arranging subplots. Only used when `geom = "time"`. #' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots. Only used when `geom = "time"`. -#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If `NULL` (default) then the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE`. +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median survival time (if available) or the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE`. #' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately. Only used when `geom = "variable"`. #' @param plot_type character, one of `"pdp"`, `"ice"`, `"pdp+ice"`, or `NULL` (default). If `NULL` then the type of plot is chosen automatically based on the number of variables to be plotted. Only used when `geom = "variable"`. 
#' @param title character, title of the plot @@ -39,9 +39,11 @@ #' #' plot(m_prof, geom = "variable", variables = "karno", plot_type = "pdp+ice") #' -#' plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "karno", plot_type = "pdp+ice") +#' plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], +#' variables = "karno", plot_type = "pdp+ice") #' -#' plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "trt", plot_type = "pdp+ice") +#' plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], +#' variables = "trt", plot_type = "pdp+ice") #' } #' #' @export @@ -207,9 +209,12 @@ plot2_mp <- function(x, if (marginalize_over_time){ times <- x$eval_times warning("Plot will be prepared with marginalization over all time points from the explainer's `times` vector. \nFor subset of time points, set the value of `times`.") - } else{ + } else if (!is.null(x$median_survival_time)){ + times <- x$median_survival_time + warning("Plot will be prepared for the median survial time. For another time point, set the value of `times`.") + } else { times <- quantile(x$eval_times, p = 0.5, type = 1) - warning("Plot will be prepared for the median time point from the explainer's `times` vector. \nFor another time point, set the value of `times`.") + warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") } } diff --git a/R/plot_predict_profile_survival.R b/R/plot_predict_profile_survival.R index 8bbf930b..1f0d0d7d 100644 --- a/R/plot_predict_profile_survival.R +++ b/R/plot_predict_profile_survival.R @@ -10,7 +10,7 @@ #' @param variable_type character, either `"numerical"`, `"categorical"` or `NULL` (default), select only one type of variable for plotting, or leave `NULL` for all. Only used when `geom = "time"`. #' @param facet_ncol number of columns for arranging subplots. Only used when `geom = "time"`. 
#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots. Only used when `geom = "time"`. -#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If `NULL` (default) then the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE`. +#' @param times numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If `NULL` (default) then the median survival time (if available) or the median time from the explainer object is used. Only used when `geom = "variable"` and `marginalize_over_time = FALSE`. #' @param marginalize_over_time logical, if `TRUE` then the profile is calculated for all times and then averaged over time, if `FALSE` (default) then the profile is calculated for each time separately. Only used when `geom = "variable"`. #' @param title character, title of the plot #' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels @@ -171,9 +171,12 @@ plot2_cp <- function(x, if (marginalize_over_time){ times <- x$eval_times warning("Plot will be prepared with marginalization over all time points from the explainer's `times` vector. \nFor subset of time points, set the value of `times`.") - } else{ + } else if (!is.null(x$median_survival_time)){ + times <- x$median_survival_time + warning("Plot will be prepared for the median survial time. For another time point, set the value of `times`.") + } else { times <- quantile(x$eval_times, p = 0.5, type = 1) - warning("Plot will be prepared for the median time point from the explainer's `times` vector. 
\nFor another time point, set the value of `times`.") + warning("Plot will be prepared for the median time point from the explainer's `times` vector. For another time point, set the value of `times`.") } } diff --git a/R/predict_profile.R b/R/predict_profile.R index 98094455..235276c8 100644 --- a/R/predict_profile.R +++ b/R/predict_profile.R @@ -90,6 +90,7 @@ predict_profile.surv_explainer <- function(explainer, ... ) class(res) <- c("predict_profile_survival", class(res)) + res$median_survival_time <- explainer$median_survival_time res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] res$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] return(res) diff --git a/man/plot.model_profile_2d_survival.Rd b/man/plot.model_profile_2d_survival.Rd index a6b7d30f..915ca572 100644 --- a/man/plot.model_profile_2d_survival.Rd +++ b/man/plot.model_profile_2d_survival.Rd @@ -23,7 +23,7 @@ \item{variables}{list of character vectors of length 2, names of pairs of variables to be plotted} -\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used.} +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. 
If \code{NULL} (default) then the median survival time (if available) or the median time from the explainer object is used.} \item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately} @@ -57,14 +57,14 @@ cph_model_profile_2d <- model_profile_2d(cph_exp, ) ) head(cph_model_profile_2d$result) -plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = 103) +plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = cph_exp$times[20]) cph_model_profile_2d_ale <- model_profile_2d(cph_exp, variables = list(c("age", "karno")), type = "accumulated" ) head(cph_model_profile_2d_ale$result) -plot(cph_model_profile_2d_ale, times = c(8, 103), marginalize_over_time = TRUE) +plot(cph_model_profile_2d_ale, times = cph_exp$times[c(10, 20)], marginalize_over_time = TRUE) } } diff --git a/man/plot.model_profile_survival.Rd b/man/plot.model_profile_survival.Rd index db8d8865..3de1e0ae 100644 --- a/man/plot.model_profile_survival.Rd +++ b/man/plot.model_profile_survival.Rd @@ -37,7 +37,7 @@ \item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots. Only used when \code{geom = "time"}.} -\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used. Only used when \code{geom = "variable"} and \code{marginalize_over_time = FALSE}.} +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If \code{NULL} (default) then the median survival time (if available) or the median time from the explainer object is used. 
Only used when \code{geom = "variable"} and \code{marginalize_over_time = FALSE}.} \item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately. Only used when \code{geom = "variable"}.} @@ -78,9 +78,9 @@ plot(m_prof, variables = c("trt", "age"), facet_ncol = 1) plot(m_prof, geom = "variable", variables = "karno", plot_type = "pdp+ice") -plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "karno", plot_type = "pdp+ice") +plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], variables = "karno", plot_type = "pdp+ice") -plot(m_prof, geom = "variable", times = c(1, 2.72), variables = "trt", plot_type = "pdp+ice") +plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], variables = "trt", plot_type = "pdp+ice") } } diff --git a/man/plot.predict_profile_survival.Rd b/man/plot.predict_profile_survival.Rd index 2720321b..b759d2b4 100644 --- a/man/plot.predict_profile_survival.Rd +++ b/man/plot.predict_profile_survival.Rd @@ -36,7 +36,7 @@ \item{numerical_plot_type}{character, either \code{"lines"}, or \code{"contours"} selects the type of numerical variable plots. Only used when \code{geom = "time"}.} -\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the "times" field of the explainer. If \code{NULL} (default) then the median time from the explainer object is used. Only used when \code{geom = "variable"} and \code{marginalize_over_time = FALSE}.} +\item{times}{numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If \code{NULL} (default) then the median survival time (if available) or the median time from the explainer object is used. 
Only used when \code{geom = "variable"} and \code{marginalize_over_time = FALSE}.} \item{marginalize_over_time}{logical, if \code{TRUE} then the profile is calculated for all times and then averaged over time, if \code{FALSE} (default) then the profile is calculated for each time separately. Only used when \code{geom = "variable"}.} diff --git a/tests/testthat/test-explain.R b/tests/testthat/test-explain.R index dd7d8548..219f32e6 100644 --- a/tests/testthat/test-explain.R +++ b/tests/testthat/test-explain.R @@ -457,7 +457,7 @@ test_that("default methods for creating explainers work correctly", { fsr <- flexsurv::flexsurvreg(survival::Surv(time, status) ~ trt + celltype + karno + diagtime + age + prior, data = veteran, dist = "exp") - fsr_exp <- explain(fsr, verbose = FALSE) + fsr_exp <- explain(fsr, times_generation = "quantiles", verbose = FALSE) expect_s3_class(fsr_exp, c("surv_explainer", "explainer")) expect_equal(fsr_exp$label, "flexsurvreg", ignore_attr = TRUE) diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 4bc7a998..56827ce8 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -33,9 +33,9 @@ test_that("model_profile with type = 'partial' works", { plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "pdp", times = cph_exp$times[1]) plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "ice", times = cph_exp$times[1]) # multiple time points - plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "pdp+ice") - plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "pdp") - plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "ice") + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", plot_type = "pdp+ice") + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", 
plot_type = "pdp") + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", plot_type = "ice") expect_error(plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "nonexistent", times = cph_exp$times[1])) expect_error(plot(mp_cph_num, geom = "variable", variables = 1, plot_type = "pdp+ice", times = cph_exp$times[1])) @@ -60,10 +60,10 @@ test_that("model_profile with type = 'partial' works", { plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp", times = rsf_ranger_exp$times[1]) plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "ice", times = rsf_ranger_exp$times[1]) # multiple time points - plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "pdp+ice") - plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), marginalize_over_time = T, variables = "celltype", plot_type = "pdp+ice") - plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "pdp") - plot(mp_rsf_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "ice") + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "pdp+ice") + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], marginalize_over_time = T, variables = "celltype", plot_type = "pdp+ice") + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "pdp") + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "ice") expect_s3_class(mp_rsf_cat, "model_profile_survival") @@ -118,7 +118,7 @@ test_that("model_profile with type = 'accumulated' works", { # single time point plot(mp_cph_cat, geom = "variable", variables = "celltype", times=cph_exp$times[1]) # multiple time points - plot(mp_cph_cat, geom = "variable", times = c(4, 80.7), variables = "celltype", plot_type = "ale") + 
plot(mp_cph_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "ale") expect_s3_class(mp_cph_cat, "model_profile_survival") expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) @@ -142,7 +142,7 @@ test_that("model_profile with type = 'accumulated' works", { # single time point plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "ale", times=cph_exp$times[1]) # multiple time points - plot(mp_cph_num, geom = "variable", times = c(4, 80.7), variables = "karno", plot_type = "ale") + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", plot_type = "ale") expect_s3_class(mp_cph_num, "model_profile_survival") expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) diff --git a/vignettes/pdp.Rmd b/vignettes/pdp.Rmd index 4484b8eb..f42d0382 100644 --- a/vignettes/pdp.Rmd +++ b/vignettes/pdp.Rmd @@ -59,28 +59,32 @@ plot(pdp, variables = c("karno"), numerical_plot_type = "contours") The plots above make use of the time dependent output of survival models, by placing the time dimension on the x-axis. However, for people familiar with Partial Dependence explanations in classification and regression, it might be more intuitive to place the variable values on the x-axis. For this reason, we provide the `geom = "variable"` argument, which can display the explanations without the aspect of time. -To use this function a specific time of interest has to be chosen. This time needs to be one of the values in the `times` field of the explainer. If the automatically generated times do not contain the time of interest, one needs to manually specify the `times` argument when creating the explainer. +To use this function a specific time of interest has to be chosen. This time needs to be one of the values in the `times` field of the explainer. 
For `times_generation = "survival_quantiles"` (which is the default when creating the explainer) the median survival time point is also available. If the automatically generated times do not contain the time of interest, one needs to manually specify the `times` argument when creating the explainer. -The example below shows the PD explanations for the `karno` variable at time 80. The y-axis represents the mean prediction (survival function), x-axis represents the values of the studied variable. Thin background lines are individual ceteris paribus profiles (otherwise known as ICE profiles). +The example below shows the PD explanations for the `karno` variable at the median survival time. The y-axis represents the mean prediction (survival function), x-axis represents the values of the studied variable. Thin background lines are individual ceteris paribus profiles (otherwise known as ICE profiles). ```{r} -plot(pdp, geom = "variable", variables = "karno", times = 80) +plot(pdp, geom = "variable", variables = "karno", times = exp$median_survival_time) ``` The same plot can be generated for the categorical `celltype` variable. In this case the x-axis represents the different values of the studied variable, boxplots present the distribution of individual ceteris paribus profiles, and the line represents the mean prediction (survival function), which is the PD explanation. ```{r} -plot(pdp, geom = "variable", variables = "celltype", times = 80) +plot(pdp, geom = "variable", variables = "celltype", times = exp$median_survival_time) ``` Of course, the plots can be prepared for multiple time points, at the same time and presented on one plot. 
+```{r} +selected_times <- c(exp$times[1], exp$median_survival_time, exp$times[length(exp$times)]) +selected_times +``` ```{r} -plot(pdp, geom = "variable", variables = "karno", times = c(1, 80, 151.72)) +plot(pdp, geom = "variable", variables = "karno", times = selected_times) ``` ```{r} -plot(pdp, geom = "variable", variables = "celltype", times = c(1, 80, 151.72)) +plot(pdp, geom = "variable", variables = "celltype", times = selected_times) ``` @@ -94,10 +98,10 @@ pdp_2d_num_cat <- model_profile_2d(exp, variables = list(c("karno", "celltype")) These explanations can be plotted using the plot function. ```{r} -plot(pdp_2d, times = 80) +plot(pdp_2d, times = exp$median_survival_time) ``` ```{r} -plot(pdp_2d_num_cat, times = 80) +plot(pdp_2d_num_cat, times = exp$median_survival_time) ``` From 1e8121556afff8c769023a905becd65a4b68e802 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 25 Aug 2023 00:18:37 +0200 Subject: [PATCH 192/207] fix line length in examples --- man/plot.model_profile_survival.Rd | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/man/plot.model_profile_survival.Rd b/man/plot.model_profile_survival.Rd index 3de1e0ae..45815e11 100644 --- a/man/plot.model_profile_survival.Rd +++ b/man/plot.model_profile_survival.Rd @@ -78,9 +78,11 @@ plot(m_prof, variables = c("trt", "age"), facet_ncol = 1) plot(m_prof, geom = "variable", variables = "karno", plot_type = "pdp+ice") -plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], variables = "karno", plot_type = "pdp+ice") +plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], + variables = "karno", plot_type = "pdp+ice") -plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], variables = "trt", plot_type = "pdp+ice") +plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], + variables = "trt", plot_type = "pdp+ice") } } From 87cff5d75bcf4e1d17fb17730dcb1ebbe6ea6107 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 25 Aug 2023 00:18:53 +0200 Subject: [PATCH 
193/207] fix new times generation method --- R/explain.R | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/R/explain.R b/R/explain.R index 71e57883..0fa13e29 100644 --- a/R/explain.R +++ b/R/explain.R @@ -24,7 +24,7 @@ #' @param type type of a model, by default `"survival"` #' #' @param times numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations -#' @param times_generation either `"survival_quantiles"`, `"uniform"` or `"quantiles"`. Sets the way of generating the vector of times based on times provided in the `y` parameter. If `"survival_quantiles"` the vector contains unique time points out of 50 uniformly distributed survival quantiles of order (0, ..., 1-min(KM_SF)) based on the Kaplan-Meier estimator, and additional time point being the median survival time (if possible); if `"uniform"` the vector contains 50 equally spaced time points between the minimum and maximum observed times; if `"quantiles"` the vector contains unique time points out of 50 time points between 0th and 98th percentiles of observed times. Ignored if `times` is not `NULL`. +#' @param times_generation either `"survival_quantiles"`, `"uniform"` or `"quantiles"`. Sets the way of generating the vector of times based on times provided in the `y` parameter. If `"survival_quantiles"` the vector contains unique time points out of 50 uniformly distributed survival quantiles based on the Kaplan-Meier estimator, and additional time point being the median survival time (if possible); if `"uniform"` the vector contains 50 equally spaced time points between the minimum and maximum observed times; if `"quantiles"` the vector contains unique time points out of 50 time points between 0th and 98th percentiles of observed times. Ignored if `times` is not `NULL`. 
#' @param predict_survival_function function taking 3 arguments `model`, `newdata` and `times`, and returning a matrix whose each row is a survival function evaluated at `times` for one observation from `newdata` #' @param predict_cumulative_hazard_function function taking 3 arguments `model`, `newdata` and `times`, and returning a matrix whose each row is a cumulative hazard function evaluated at `times` for one observation from `newdata` #' @@ -236,14 +236,16 @@ explain_survival <- survobj <- Surv(y[,1], y[,2]) sfit <- survival::survfit(survobj ~ 1, type="kaplan-meier") + max_sf <- max(sfit$surv[sfit$surv!=1]) # without 1 (for time = 0) min_sf <- min(sfit$surv) - quantiles <- 1 - seq(1, min_sf, length.out=50) + quantiles <- 1 - seq(max_sf, min_sf, length.out=50) if(min_sf <= 0.5) median_survival_time <- as.numeric(quantile(sfit, 0.5)$quantile) raw_times <- quantile(sfit, quantiles)$quantile + raw_times[1] <- min() times <- sort(na.omit(unique(c(raw_times, median_survival_time)))) - method_description <- "uniformly distributed Kaplan-Meier survival quantiles (0, ..., 1-min(KM_SF))" + method_description <- "uniformly distributed survival quantiles based on Kaplan-Meier estimator" }, "uniform" = { times <- seq(min(y[, 1]), max(y[, 1]), length.out = 50) From a8d3ab59700b15864c7e2bb176350f2d1ebfbe94 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 25 Aug 2023 00:27:06 +0200 Subject: [PATCH 194/207] remove typo --- R/explain.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/explain.R b/R/explain.R index 0fa13e29..4231f5e9 100644 --- a/R/explain.R +++ b/R/explain.R @@ -242,7 +242,6 @@ explain_survival <- if(min_sf <= 0.5) median_survival_time <- as.numeric(quantile(sfit, 0.5)$quantile) raw_times <- quantile(sfit, quantiles)$quantile - raw_times[1] <- min() times <- sort(na.omit(unique(c(raw_times, median_survival_time)))) method_description <- "uniformly distributed survival quantiles based on Kaplan-Meier estimator" From a732489a5b3da48db052be752be64954ebfae2b5 
Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Fri, 25 Aug 2023 00:27:41 +0200 Subject: [PATCH 195/207] improve description of new method --- man/explain_survival.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/explain_survival.Rd b/man/explain_survival.Rd index c8500af6..8d61c636 100644 --- a/man/explain_survival.Rd +++ b/man/explain_survival.Rd @@ -87,7 +87,7 @@ explain( \item{times}{numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations} -\item{times_generation}{either \code{"survival_quantiles"}, \code{"uniform"} or \code{"quantiles"}. Sets the way of generating the vector of times based on times provided in the \code{y} parameter. If \code{"survival_quantiles"} the vector contains unique time points out of 50 uniformly distributed survival quantiles of order (0, ..., 1-min(KM_SF)) based on the Kaplan-Meier estimator, and additional time point being the median survival time (if possible); if \code{"uniform"} the vector contains 50 equally spaced time points between the minimum and maximum observed times; if \code{"quantiles"} the vector contains unique time points out of 50 time points between 0th and 98th percentiles of observed times. Ignored if \code{times} is not \code{NULL}.} +\item{times_generation}{either \code{"survival_quantiles"}, \code{"uniform"} or \code{"quantiles"}. Sets the way of generating the vector of times based on times provided in the \code{y} parameter. If \code{"survival_quantiles"} the vector contains unique time points out of 50 uniformly distributed survival quantiles based on the Kaplan-Meier estimator, and additional time point being the median survival time (if possible); if \code{"uniform"} the vector contains 50 equally spaced time points between the minimum and maximum observed times; if \code{"quantiles"} the vector contains unique time points out of 50 time points between 0th and 98th percentiles of observed times. 
Ignored if \code{times} is not \code{NULL}.} \item{predict_survival_function}{function taking 3 arguments \code{model}, \code{newdata} and \code{times}, and returning a matrix whose each row is a survival function evaluated at \code{times} for one observation from \code{newdata}} From 26a05063437dffd338ba102c9141476066c89ab1 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 28 Aug 2023 14:04:57 +0200 Subject: [PATCH 196/207] add model_diagnostics to `NEWS.md` --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 3a649c47..f272c573 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,7 @@ * added Accumulated Local Effects (ALE) explanations (see `model_profile(..., type = "accumulated")`) * added 2-dimensional PDP and ALE plots (see `model_profile_2d()` function) * added `plot2()` function for plotting PDP and ALE explanations without the time dimension +* added diagnostic explanations - residual analysis (see `model_diagnostics()` function) * made improvements on the vignettes for the package (see `vignette("pdp")` and `vignette("global-survshap")`) * increased the test coverage of the pacakge * reduced the number of expensive `requireNamespace()` calls ([#83](https://github.com/ModelOriented/survex/issues/83)) From 2728dbb992be95c6e59825d5328b4aea81abb214 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 28 Aug 2023 14:05:11 +0200 Subject: [PATCH 197/207] add new model_survshap plots --- R/plot_surv_shap.R | 185 +++++++++++++++++++++++++-- man/plot.aggregated_surv_shap.Rd | 15 ++- tests/testthat/test-model_survshap.R | 6 + 3 files changed, 191 insertions(+), 15 deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 5c0c6681..7d0841d6 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -92,7 +92,7 @@ plot.surv_shap <- function(x, #' explanations of survival models created using the `model_survshap()` function. 
#' #' @param x an object of class `aggregated_surv_shap` to be plotted -#' @param geom character, one of `"importance"`, `"beeswarm"`, or `"profile"`. Type of chart to be plotted; `"importance"` shows the importance of variables over time and aggregated, `"beeswarm"` shows the distribution of SurvSHAP(t) values for variables and observations, `"profile"` shows the dependence of SurvSHAP(t) values on variable values. +#' @param geom character, one of `"importance"`, `"beeswarm"`, `"profile"` or `"curves"`. Type of chart to be plotted; `"importance"` shows the importance of variables over time and aggregated, `"beeswarm"` shows the distribution of SurvSHAP(t) values for variables and observations, `"profile"` shows the dependence of SurvSHAP(t) values on variable values, `"curves"` shows all SurvSHAP(t) curves for selected variable colored by its value or with functional boxplot if `boxplot = TRUE`. #' @param ... additional parameters passed to internal functions #' @param title character, title of the plot #' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for the XXX model (n = YYY)", where XXX is the explainer label and YYY is the number of observations used for calculations @@ -121,6 +121,12 @@ plot.surv_shap <- function(x, #' * `color_variable` - variable used to denote the color, by default equal to `variable` #' #' +#'#' ## `plot.aggregated_surv_shap(geom = "curves")` +#' +#' * `variable` - variable for which SurvSHAP(t) curves are to be plotted, by default first from result data +#' * `boxplot` - whether to plot functional boxplot with marked outliers or all curves colored by variable value +#' +#' #' @examples #' \donttest{ #' veteran <- survival::veteran @@ -154,7 +160,12 @@ plot.surv_shap <- function(x, #' ) #' plot(ranger_global_survshap) #' plot(ranger_global_survshap, geom = "beeswarm") -#' plot(ranger_global_survshap, geom = "profile", color_variable = "karno") +#' plot(ranger_global_survshap, geom = 
"profile", +#' variable = "age", color_variable = "karno") +#' plot(ranger_global_survshap, geom = "curves", +#' variable = "age") +#' plot(ranger_global_survshap, geom = "curves", +#' variable = "age", boxplot = TRUE) #' } #' #' @export @@ -201,7 +212,14 @@ plot.aggregated_surv_shap <- function(x, subtitle = subtitle, colors = colors ), - stop("`geom` must be one of 'importance', 'beeswarm' or 'profile'") + "curves" = plot_shap_global_curves( + x = x, + ... = ..., + title = title, + subtitle = subtitle, + colors = colors + ), + stop("`geom` must be one of 'importance', 'beeswarm', 'profile' or 'curves'") ) } @@ -256,6 +274,7 @@ plot_shap_global_importance <- function(x, right_plot + patchwork::plot_layout(widths = c(3, 5), guides = "collect") + patchwork::plot_annotation(title = title, subtitle = subtitle) & + theme_default_survex() & theme( legend.position = "top", plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), @@ -322,7 +341,6 @@ plot_shap_global_profile <- function(x, color_variable = NULL, title = "default", subtitle = "default", - max_vars = 7, colors = NULL) { df <- as.data.frame(do.call(rbind, x$aggregate)) @@ -348,10 +366,7 @@ plot_shap_global_profile <- function(x, title <- "Aggregated SurvSHAP(t) profile" } if (!is.null(subtitle) && subtitle == "default") { - subtitle <- paste0( - "created for the ", label, " model ", - "(n=", x$n_observations, ")" - ) + subtitle <- paste0("created for the ", unique(variable), " variable") } p <- with(df, { @@ -369,7 +384,12 @@ plot_shap_global_profile <- function(x, theme(legend.position = "bottom") }) - if (!is.factor(df$color_variable_val)) { + if (is.factor(df$color_variable_val) || is.character(df$color_variable_val)) { + p + scale_color_manual( + name = paste(color_variable, "value"), + values = generate_discrete_color_scale(length(unique(df$color_variable_val)), colors) + ) + } else { p + scale_color_gradient2( name = paste(color_variable, "value"), low = colors[1], @@ -377,12 +397,151 @@ 
plot_shap_global_profile <- function(x, high = colors[3], midpoint = median(df$color_variable_val) ) + } +} + + +plot_shap_global_curves <- function(x, + ..., + variable = NULL, + boxplot = FALSE, + coef = 1.5, + title = "default", + subtitle = "default", + colors = NULL, + rug = "all", + rug_colors = c("#dd0000", "#222222")) { + + if (is.null(variable)) { + variable <- colnames(df)[1] + warning("`variable` was not specified, the first from the result will be used.") + } + + label <- attr(x, "label") + if (!is.null(title) && title == "default") { + title <- "SurvSHAP(t) curves" + } + if (!is.null(subtitle) && subtitle == "default") { + subtitle <- paste0("created for the ", unique(variable), " variable") + } + + if (!boxplot){ + df <- as.data.frame(do.call(rbind, x$result)) + df$obs <- rep(1:x$n_observations, + each = length(x$eval_times)) + df$time <- rep(x$eval_times, + times = x$n_observations) + df$varval <- rep(x$variable_values[[variable]], + each = length(x$eval_times)) + + if (is.null(colors) || length(colors) < 3) { + colors <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) + } + base_plot <- with(df, + {ggplot(df, aes(x = time, y = !!sym(variable), color = varval)) + + geom_hline(yintercept = 0, alpha = 0.5, color = "black") + + geom_line(aes(group = obs), alpha = 0.5) + + theme_default_survex() + + theme(legend.position = "bottom") + + labs(y = "SurvSHAP(t) value", + title = title, + subtitle = subtitle) + }) + + if (is.factor(x$variable_values[[variable]]) || is.character(x$variable_values[[variable]])) { + base_plot <- base_plot + scale_color_manual( + name = paste(variable, "value"), + values = generate_discrete_color_scale(length(unique(df$varval)), colors) + ) + } else { + base_plot <- base_plot + scale_color_gradient2( + name = paste(variable, "value"), + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(df$varval))) + ) + } } else { - p + scale_color_manual( - name = 
paste(color_variable, "value"), - values = generate_discrete_color_scale(length(unique(df$color_variable_val)), colors) - ) + if (is.null(colors) || length(colors) < 3) { + colors <- c( + background = "#000000", + boxplot = "#9fe5bd", + outliers = "#371ea3" + ) + } + + df <- t(sapply( x$result, function(x) x[[variable]])) + n <- x$n_observations + p <- dim(df)[2] + + rmat <- apply(df, 2, rank) + down <- rmat-1 + up <- n-rmat + depth <- (rowSums(up*down)/p+n-1)/choose(n,2) + + index <- order(depth,decreasing=TRUE) + median <- colMeans(df[which(depth==max(depth)), ,drop=FALSE]) + + m <- ceiling(n * 0.5) + central_region <- df[index[1:m], ] + outer_region <- df[index[(m+1):n], ] + lower_bound <- apply(central_region, 2, min) + upper_bound <- apply(central_region, 2, max) + + iqr <- upper_bound - lower_bound + outlier_indices <- which(colSums((t(df) <= median - coef * iqr) + + (t(df) >= median + coef * iqr)) > 0) + outliers <- df[outlier_indices, ] + nonoutliers <- df[-outlier_indices,] + whisker_lower_bound <- apply(nonoutliers, 2, min) + whisker_upper_bound <- apply(nonoutliers, 2, max) + + df_all <- data.frame(x = x$eval_times, + y = as.vector(t(df)), + obs = rep(1:n, each = length(x$eval_times))) + + df_outliers <- data.frame(x = x$eval_times, + y = as.vector(t(outliers)), + obs = rep(1:length(outliers), each = length(x$eval_times))) + + df_res <- data.frame(x = x$eval_times, + median = median, + lower_bound = lower_bound, + upper_bound = upper_bound, + whisker_lower_bound = whisker_lower_bound, + whisker_upper_bound = whisker_upper_bound) + + base_plot <- with(list(df_all, df_outliers, df_res), + {ggplot() + + geom_hline(yintercept = 0, alpha = 0.5, color = colors[1]) + + geom_line(data = df_all, aes(x = x, y = y, group = obs), col = colors[1], alpha = 0.2, linewidth = 0.2) + + geom_ribbon(data = df_res, aes(x = x, ymin = lower_bound, ymax = upper_bound), + fill = colors[2], alpha = 0.2) + + geom_line(data = df_res, aes(x = x, y = median), col = colors[2], alpha = 
0.8) + + geom_line(data = df_res, aes(x = x, y = lower_bound), col = colors[2], alpha = 0.8) + + geom_line(data = df_res, aes(x = x, y = upper_bound), col = colors[2], alpha = 0.8) + + geom_line(data = df_res, aes(x = x, y = whisker_lower_bound), col = colors[2], lty = 2, alpha = 0.8) + + geom_line(data = df_res, aes(x = x, y = whisker_upper_bound), col = colors[2], lty = 2, alpha = 0.8) + + geom_line(data = df_outliers, aes(x = x, y = y, group = obs), + col = colors[3], linewidth = 0.5, alpha = 0.1) + + theme_default_survex() + + labs(x = "time", + y = "SurvSHAP(t) value", + title = title, + subtitle = subtitle) + }) + cat("Observations with outlying SurvSHAP(t) values:\n") + print(x$variable_values[outlier_indices,]) } + + rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses)) + return_plot <- add_rug_to_plot(base_plot, rug_df, rug, rug_colors) + return(return_plot) } preprocess_values_to_common_scale <- function(data) { diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd index 3e9a836e..22d7b8ff 100644 --- a/man/plot.aggregated_surv_shap.Rd +++ b/man/plot.aggregated_surv_shap.Rd @@ -17,7 +17,7 @@ \arguments{ \item{x}{an object of class \code{aggregated_surv_shap} to be plotted} -\item{geom}{character, one of \code{"importance"}, \code{"beeswarm"}, or \code{"profile"}. Type of chart to be plotted; \code{"importance"} shows the importance of variables over time and aggregated, \code{"beeswarm"} shows the distribution of SurvSHAP(t) values for variables and observations, \code{"profile"} shows the dependence of SurvSHAP(t) values on variable values.} +\item{geom}{character, one of \code{"importance"}, \code{"beeswarm"}, \code{"profile"} or \code{"curves"}. 
Type of chart to be plotted; \code{"importance"} shows the importance of variables over time and aggregated, \code{"beeswarm"} shows the distribution of SurvSHAP(t) values for variables and observations, \code{"profile"} shows the dependence of SurvSHAP(t) values on variable values, \code{"curves"} shows all SurvSHAP(t) curves for selected variable colored by its value or with functional boxplot if \code{boxplot = TRUE}.} \item{...}{additional parameters passed to internal functions} @@ -57,6 +57,12 @@ explanations of survival models created using the \code{model_survshap()} functi \item \code{variable} - variable for which the profile is to be plotted, by default first from result data \item \code{color_variable} - variable used to denote the color, by default equal to \code{variable} } + +#' ## \code{plot.aggregated_surv_shap(geom = "curves")} +\itemize{ +\item \code{variable} - variable for which SurvSHAP(t) curves are to be plotted, by default first from result data +\item \code{boxplot} - whether to plot functional boxplot with marked outliers or all curves colored by variable value +} } } @@ -93,7 +99,12 @@ ranger_global_survshap <- model_survshap( ) plot(ranger_global_survshap) plot(ranger_global_survshap, geom = "beeswarm") -plot(ranger_global_survshap, geom = "profile", color_variable = "karno") +plot(ranger_global_survshap, geom = "profile", + variable = "age", color_variable = "karno") +plot(ranger_global_survshap, geom = "curves", + variable = "age") +plot(ranger_global_survshap, geom = "curves", + variable = "age", boxplot = TRUE) } } diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index 1fe1f73f..51dc9799 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -25,6 +25,9 @@ test_that("global survshap explanations with kernelshap work for ranger, using n plot(ranger_global_survshap, geom = "beeswarm") plot(ranger_global_survshap, geom = "profile", variable = "karno", 
color_variable = "celltype") plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "age") + plot(ranger_global_survshap, geom = "curves", variable = "karno") + plot(ranger_global_survshap, geom = "curves", variable = "celltype") + plot(ranger_global_survshap, geom = "curves", variable = "karno", boxplot = TRUE) expect_error(plot(ranger_global_survshap, geom = "nonexistent")) single_survshap <- extract_predict_survshap(ranger_global_survshap, 5) @@ -49,6 +52,9 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex plot(cph_global_survshap, geom = "beeswarm") plot(cph_global_survshap, geom = "profile", variable = "karno", color_variable = "celltype") plot(cph_global_survshap, geom = "profile", variable = "karno", color_variable = "age") + plot(cph_global_survshap, geom = "curves", variable = "karno") + plot(cph_global_survshap, geom = "curves", variable = "celltype") + plot(cph_global_survshap, geom = "curves", variable = "karno", boxplot = TRUE) expect_s3_class(cph_global_survshap, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times)) From ff54a7da6eca062278260d9fbc3d510e5c098cba Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Mon, 28 Aug 2023 14:13:28 +0200 Subject: [PATCH 198/207] add new times generation method to `NEWS.md` --- NEWS.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index f272c573..dee41ee5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,8 +7,9 @@ * added 2-dimensional PDP and ALE plots (see `model_profile_2d()` function) * added `plot2()` function for plotting PDP and ALE explanations without the time dimension * added diagnostic explanations - residual analysis (see `model_diagnostics()` function) +* added a new times generation method `"survival_quantiles"` and set it as the default (see `explain()`) * made improvements on the vignettes for the package (see `vignette("pdp")` and 
`vignette("global-survshap")`) -* increased the test coverage of the pacakge +* increased the test coverage of the package * reduced the number of expensive `requireNamespace()` calls ([#83](https://github.com/ModelOriented/survex/issues/83)) # survex 1.0.0 From 1cda2963383b4c2505af2f7efb6fc77829411a24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 29 Aug 2023 12:50:03 +0200 Subject: [PATCH 199/207] Add explanations for `output_type='chf'` --- R/model_profile.R | 111 +++++++++-------- R/model_profile_2d.R | 31 +++-- R/model_survshap.R | 3 + R/plot_model_profile_survival.R | 2 +- R/plot_surv_ceteris_paribus.R | 21 +++- R/predict_parts.R | 3 +- R/predict_profile.R | 11 +- R/surv_ceteris_paribus.R | 8 +- R/surv_model_profiles.R | 10 +- R/surv_shap.R | 39 ++++-- man/model_profile.surv_explainer.Rd | 2 +- man/model_profile_2d.surv_explainer.Rd | 2 +- man/model_survshap.surv_explainer.Rd | 3 + man/predict_profile.surv_explainer.Rd | 5 +- man/surv_ceteris_paribus.Rd | 1 + man/surv_shap.Rd | 3 + tests/testthat/test-model_profile.R | 158 +++++++++++++++++++++++++ tests/testthat/test-model_profile_2d.R | 95 +++++++++++++++ tests/testthat/test-model_survshap.R | 12 +- tests/testthat/test-predict_parts.R | 52 ++++++++ tests/testthat/test-predict_profile.R | 72 ++++++++++- 21 files changed, 548 insertions(+), 96 deletions(-) diff --git a/R/model_profile.R b/R/model_profile.R index 0363c605..5841a1b9 100644 --- a/R/model_profile.R +++ b/R/model_profile.R @@ -15,7 +15,7 @@ #' @param k passed to `DALEX::model_profile` if `output_type == "risk"`, otherwise ignored #' @param center logical, should profiles be centered around the average prediction #' @param type the type of variable profile, `"partial"` for Partial Dependence, `"accumulated"` for Accumulated Local Effects, or `"conditional"` (available only for `output_type == "risk"`) -#' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for 
explanations. If `"survival"` the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the `DALEX::model_profile` function. +#' @param output_type either `"survival"`, `"chf"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. If `"chf"` the explanations are based on the cumulative hazard function. Otherwise the scalar risk predictions are used by the `DALEX::model_profile` function. #' #' @return An object of class `model_profile_survival`. It is a list with the element `result` containing the results of the calculation. #' @@ -81,8 +81,9 @@ model_profile.surv_explainer <- function(explainer, type = "partial", output_type = "survival") { variables <- unique(variables, categorical_variables) - switch(output_type, - "risk" = DALEX::model_profile( + + if (output_type == "risk"){ + DALEX::model_profile( explainer = explainer, variables = variables, N = N, @@ -91,59 +92,65 @@ model_profile.surv_explainer <- function(explainer, k = k, center = center, type = type - ), - "survival" = { - test_explainer(explainer, "model_profile", has_data = TRUE, has_survival = TRUE) - data <- explainer$data - if (!is.null(N) && N < nrow(data)) { - ndata <- data[sample(1:nrow(data), N), , drop = FALSE] - } else { - ndata <- data[1:nrow(data), , drop = FALSE] - } + ) + } else if (output_type %in% c("survival", "chf")) { - if (type == "partial") { - cp_profiles <- surv_ceteris_paribus(explainer, - new_observation = ndata, - variables = variables, - categorical_variables = categorical_variables, - grid_points = grid_points, - variable_splits_type = variable_splits_type, - center = center, - ... 
- ) + test_explainer(explainer, "model_profile", has_data = TRUE, has_survival = TRUE) + data <- explainer$data + if (!is.null(N) && N < nrow(data)) { + ndata <- data[sample(1:nrow(data), N), , drop = FALSE] + } else { + ndata <- data[1:nrow(data), , drop = FALSE] + } - result <- surv_aggregate_profiles(cp_profiles, ..., - variables = variables - ) - } else if (type == "accumulated") { - cp_profiles <- list(variable_values = data.frame(ndata)) - result <- surv_ale(explainer, - data = ndata, - variables = variables, - categorical_variables = categorical_variables, - grid_points = grid_points, - center = center, - ... - ) - } else { - stop("Currently only `partial` and `accumulated` types are implemented") - } + if (type == "partial") { + cp_profiles <- surv_ceteris_paribus(explainer, + new_observation = ndata, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + variable_splits_type = variable_splits_type, + center = center, + output_type = output_type, + ... + ) - ret <- list( - eval_times = unique(result$`_times_`), - cp_profiles = cp_profiles, - result = result, - type = type, - center = center + result <- surv_aggregate_profiles(cp_profiles, ..., + variables = variables ) - class(ret) <- c("model_profile_survival", "list") - ret$median_survival_time <- explainer$median_survival_time - ret$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] - ret$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] - ret - }, - stop("Currently only `risk` and `survival` output types are implemented") - ) + } else if (type == "accumulated") { + cp_profiles <- list(variable_values = data.frame(ndata)) + result <- surv_ale( + explainer, + data = ndata, + variables = variables, + categorical_variables = categorical_variables, + grid_points = grid_points, + center = center, + output_type = output_type, + ...) 
+ } else { + stop("Currently only `partial` and `accumulated` types are implemented") + } + ret <- list( + eval_times = unique(result$`_times_`), + cp_profiles = cp_profiles, + result = result, + type = type, + center = center, + output_type = output_type + ) + + class(ret) <- c("model_profile_survival", "list") + ret$median_survival_time <- explainer$median_survival_time + ret$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] + ret$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] + ret + } else { + stop("The `output_type` argument has to be one of 'survival', 'chf' or 'risk'") + } + + } #' @export diff --git a/R/model_profile_2d.R b/R/model_profile_2d.R index 42f4e7fb..f288cd66 100644 --- a/R/model_profile_2d.R +++ b/R/model_profile_2d.R @@ -12,7 +12,7 @@ #' @param center logical, should profiles be centered around the average prediction #' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for quantiles or `"uniform"` (default) to get uniform grid of points. Used only if `type = "partial"`. #' @param type the type of variable profile, `"partial"` for Partial Dependence or `"accumulated"` for Accumulated Local Effects -#' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. Currently only `"survival"` is available. +#' @param output_type either `"survival"` or `"chf"`, the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. If `"chf"` the explanations are based on the cumulative hazard function. The `"risk"` output type is not implemented for this function. #' #' @return An object of class `model_profile_2d_survival`. It is a list with the element `result` containing the results of the calculation. 
#' @@ -69,7 +69,7 @@ model_profile_2d.surv_explainer <- function(explainer, stop("'variables' must be specified as a list of pairs (two-element vectors)") } - if (output_type != "survival") { + if (!output_type %in% c("survival", "chf")) { stop("Currently only `survival` output type is implemented") } test_explainer(explainer, "model_profile", has_data = TRUE, has_survival = TRUE) @@ -95,7 +95,8 @@ model_profile_2d.surv_explainer <- function(explainer, categorical_variables = categorical_variables, grid_points = grid_points, variable_splits_type = variable_splits_type, - center = center + center = center, + output_type = output_type ) } else if (type == "accumulated") { result <- surv_ale_2d( @@ -104,7 +105,8 @@ model_profile_2d.surv_explainer <- function(explainer, variables = variables, categorical_variables = categorical_variables, grid_points = grid_points, - center = center + center = center, + output_type = output_type ) } else { stop("Currently only `partial` and `accumulated` types are implemented") @@ -115,7 +117,8 @@ model_profile_2d.surv_explainer <- function(explainer, eval_times = unique(result$`_times_`), variables = variables, type = type, - median_survival_time = explainer$median_survival_time + median_survival_time = explainer$median_survival_time, + output_type = output_type ) class(ret) <- c("model_profile_2d_survival", "list") return(ret) @@ -127,10 +130,15 @@ surv_pdp_2d <- function(x, categorical_variables, grid_points, variable_splits_type, - center) { + center, + output_type) { model <- x$model label <- x$label - predict_survival_function <- x$predict_survival_function + if (output_type == "survival"){ + predict_survival_function <- x$predict_survival_function + } else { + predict_survival_function <- x$predict_cumulative_hazard_function + } times <- x$times unique_variables <- unlist(variables) @@ -193,10 +201,15 @@ surv_ale_2d <- function(x, variables, categorical_variables, grid_points, - center) { + center, + output_type) { model <- 
x$model label <- x$label - predict_survival_function <- x$predict_survival_function + if (output_type == "survival"){ + predict_survival_function <- x$predict_survival_function + } else { + predict_survival_function <- x$predict_cumulative_hazard_function + } times <- x$times predictions_original <- predict_survival_function( diff --git a/R/model_survshap.R b/R/model_survshap.R index 9b0581e5..a3b9b0e1 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -61,6 +61,7 @@ model_survshap.surv_explainer <- function(explainer, y_true = NULL, calculation_method = "kernelshap", aggregation_method = "integral", + output_type = "survival", ...) { stopifnot( "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( @@ -96,6 +97,7 @@ model_survshap.surv_explainer <- function(explainer, shap_values <- surv_shap( explainer = explainer, new_observation = observations, + output_type = output_type, y_true = y_true, calculation_method = calculation_method, aggregation_method = aggregation_method @@ -104,5 +106,6 @@ model_survshap.surv_explainer <- function(explainer, attr(shap_values, "label") <- explainer$label shap_values$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] shap_values$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] + shap_values$output_type <- output_type return(shap_values) } diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 7e01b664..c2e9e924 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -529,7 +529,7 @@ prepare_model_profile_plots <- function(x, aggregated_profiles$`_real_point_` <- FALSE - pl <- plot_individual_ceteris_paribus_survival(aggregated_profiles, variables, colors, numerical_plot_type, rug_df, rug, rug_colors, x$center) + pl <- plot_individual_ceteris_paribus_survival(aggregated_profiles, variables, colors, numerical_plot_type, rug_df, rug, rug_colors, x$center, 
x$output_type) patchwork::wrap_plots(pl, ncol = facet_ncol) + patchwork::plot_annotation( diff --git a/R/plot_surv_ceteris_paribus.R b/R/plot_surv_ceteris_paribus.R index 075feb97..79cf149c 100644 --- a/R/plot_surv_ceteris_paribus.R +++ b/R/plot_surv_ceteris_paribus.R @@ -11,8 +11,10 @@ prepare_ceteris_paribus_plots <- function(x, rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses), label = unique(x$result$`_label_`)) obs <- as.data.frame(x$variable_values) center <- x$center + output_type <- x$output_type x <- x$result + all_profiles <- x class(all_profiles) <- "data.frame" @@ -81,7 +83,8 @@ prepare_ceteris_paribus_plots <- function(x, rug_df = rug_df, rug = rug, rug_colors = rug_colors, - center = center + center = center, + output_type = output_type ) patchwork::wrap_plots(pl, ncol = facet_ncol) + @@ -100,7 +103,8 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, rug_df, rug, rug_colors, - center) { + center, + output_type) { pl <- lapply(variables, function(var) { df <- all_profiles[all_profiles$`_vname_` == var, ] @@ -143,7 +147,12 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, facet_wrap(~`_vname_`) }) if (!center) { - base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") + if (output_type == "survival"){ + base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") + } else { + base_plot <- base_plot + ylab("CHF value") + } + } } else { base_plot <- with(df, { @@ -233,7 +242,11 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles, facet_wrap(~`_vname_`) }) if (!center) { - base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") + if (output_type == "survival"){ + base_plot <- base_plot + ylim(c(0, 1)) + ylab("survival function value") + } else { + base_plot <- base_plot + ylab("CHF value") + } } } diff --git a/R/predict_parts.R b/R/predict_parts.R index 4ace7d83..d57fd459 100644 --- a/R/predict_parts.R +++ 
b/R/predict_parts.R @@ -75,13 +75,14 @@ predict_parts.surv_explainer <- function(explainer, new_observation, ..., N = NU )) } else { res <- switch(type, - "survshap" = surv_shap(explainer, new_observation, ...), + "survshap" = surv_shap(explainer, new_observation, output_type, ...), "survlime" = surv_lime(explainer, new_observation, ...), stop("Only `survshap` and `survlime` methods are implemented for now") ) } attr(res, "label") <- ifelse(is.null(explanation_label), explainer$label, explanation_label) + res$output_type <- output_type res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] res$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] class(res) <- c("predict_parts_survival", class(res)) diff --git a/R/predict_profile.R b/R/predict_profile.R index 235276c8..934ac236 100644 --- a/R/predict_profile.R +++ b/R/predict_profile.R @@ -8,7 +8,7 @@ #' @param categorical_variables a character vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the `variables` argument, they will be added at the end. #' @param ... additional parameters passed to `DALEX::predict_profile` if `output_type =="risk"` #' @param type character, only `"ceteris_paribus"` is implemented -#' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the `DALEX::predict_profile` function. +#' @param output_type either `"survival"`, `"chf"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. If `"chf"` the explanations are based on the cumulative hazard function. 
Otherwise the scalar risk predictions are used by the `DALEX::predict_profile` function. #' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for percentiles or `"uniform"` (default) to get uniform grid of points. #' @param center logical, should profiles be centered around the average prediction #' @@ -45,6 +45,7 @@ predict_profile <- function(explainer, categorical_variables = NULL, ..., type = "ceteris_paribus", + output_type = "survival", variable_splits_type = "uniform", center = FALSE) { UseMethod("predict_profile", explainer) @@ -63,7 +64,7 @@ predict_profile.surv_explainer <- function(explainer, center = FALSE) { variables <- unique(variables, categorical_variables) if (!type %in% "ceteris_paribus") stop("Type not supported") - if (!output_type %in% c("risk", "survival")) stop("output_type not supported") + if (!output_type %in% c("risk", "survival", "chf")) stop("output_type not supported") if (length(dim(new_observation)) != 2 || nrow(new_observation) != 1) { stop("new_observation should be a single row data.frame") } @@ -78,7 +79,7 @@ predict_profile.surv_explainer <- function(explainer, variable_splits_type = variable_splits_type )) } - if (output_type == "survival") { + if (output_type %in% c("survival", "chf")) { if (type == "ceteris_paribus") { res <- surv_ceteris_paribus( explainer, @@ -87,15 +88,17 @@ predict_profile.surv_explainer <- function(explainer, categorical_variables = categorical_variables, variable_splits_type = variable_splits_type, center = center, + output_type = output_type, ... 
) class(res) <- c("predict_profile_survival", class(res)) + res$output_type <- output_type res$median_survival_time <- explainer$median_survival_time res$event_times <- explainer$y[explainer$y[, 1] <= max(explainer$times), 1] res$event_statuses <- explainer$y[explainer$y[, 1] <= max(explainer$times), 2] return(res) } else { - stop("For survival output only type=`ceteris_paribus` is implemented") + stop("For 'survival' and 'chf' output only type=`ceteris_paribus` is implemented") } } } diff --git a/R/surv_ceteris_paribus.R b/R/surv_ceteris_paribus.R index 031a9b99..994778de 100644 --- a/R/surv_ceteris_paribus.R +++ b/R/surv_ceteris_paribus.R @@ -27,12 +27,18 @@ surv_ceteris_paribus.surv_explainer <- function(x, grid_points = 101, variable_splits_type = "uniform", center = FALSE, + output_type = "survival", ...) { test_explainer(x, has_data = TRUE, has_survival = TRUE, has_y = TRUE, function_name = "ceteris_paribus_survival") data <- x$data model <- x$model label <- x$label - predict_survival_function <- x$predict_survival_function + if (output_type == "survival"){ + predict_survival_function <- x$predict_survival_function + } else { + predict_survival_function <- x$predict_cumulative_hazard_function + } + times <- x$times surv_ceteris_paribus.default( diff --git a/R/surv_model_profiles.R b/R/surv_model_profiles.R index 6a958bc2..e460a346 100644 --- a/R/surv_model_profiles.R +++ b/R/surv_model_profiles.R @@ -97,7 +97,8 @@ surv_ale <- function(x, variables, categorical_variables, grid_points, - center = FALSE) { + center = FALSE, + output_type = "survival") { if (is.null(variables)) { variables <- colnames(data) } @@ -114,7 +115,12 @@ surv_ale <- function(x, model <- x$model label <- x$label - predict_survival_function <- x$predict_survival_function + if (output_type == "survival"){ + predict_survival_function <- x$predict_survival_function + } else { + predict_survival_function <- x$predict_cumulative_hazard_function + } + times <- x$times # Make predictions for 
original levels diff --git a/R/surv_shap.R b/R/surv_shap.R index 7970e751..e0e1bf8a 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -2,6 +2,7 @@ #' #' @param explainer an explainer object - model preprocessed by the `explain()` function #' @param new_observation new observations for which predictions need to be explained +#' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations. #' @param ... additional parameters, passed to internal functions #' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting #' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation @@ -16,6 +17,7 @@ #' @keywords internal surv_shap <- function(explainer, new_observation, + output_type, ..., y_true = NULL, calculation_method = "kernelshap", @@ -63,8 +65,8 @@ surv_shap <- function(explainer, # to display final object correctly, when is.matrix(new_observation) == TRUE res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, - "exact_kernel" = use_exact_shap(explainer, new_observation, ...), - "kernelshap" = use_kernelshap(explainer, new_observation, ...), + "exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...), + "kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...), stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented") ) @@ -85,11 +87,11 @@ surv_shap <- function(explainer, return(res) } -use_exact_shap <- function(explainer, new_observation, observation_aggregation_method, ...) { +use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) 
{ shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { - as.data.frame(shap_kernel(explainer, new_observation[as.integer(i), ], ...)) + as.data.frame(shap_kernel(explainer, new_observation[as.integer(i), ], output_type, ...)) }, USE.NAMES = TRUE, simplify = FALSE @@ -99,12 +101,13 @@ use_exact_shap <- function(explainer, new_observation, observation_aggregation_m } -shap_kernel <- function(explainer, new_observation, ...) { +shap_kernel <- function(explainer, new_observation, output_type, ...) { timestamps <- explainer$times p <- ncol(explainer$data) - target_sf <- explainer$predict_survival_function(explainer$model, new_observation, timestamps) - sfs <- explainer$predict_survival_function(explainer$model, explainer$data, timestamps) + + target_sf <- predict(explainer, new_observation, times = timestamps, output_type = output_type) + sfs <- predict(explainer, explainer$data, times = timestamps, output_type = output_type) baseline_sf <- apply(sfs, 2, mean) @@ -186,13 +189,23 @@ aggregate_surv_shap <- function(survshap, times, method, ...) { } -use_kernelshap <- function(explainer, new_observation, observation_aggregation_method, ...) { +use_kernelshap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) 
{ predfun <- function(model, newdata) { - explainer$predict_survival_function( - model, - newdata, - times = explainer$times - ) + + if (output_type == "survival"){ + explainer$predict_survival_function( + model, + newdata, + times = explainer$times + ) + } else { + explainer$predict_cumulative_hazard_function( + model, + newdata, + times = explainer$times + ) + } + } shap_values <- sapply( diff --git a/man/model_profile.surv_explainer.Rd b/man/model_profile.surv_explainer.Rd index bd2bea6c..b2881c06 100644 --- a/man/model_profile.surv_explainer.Rd +++ b/man/model_profile.surv_explainer.Rd @@ -49,7 +49,7 @@ model_profile( \item{center}{logical, should profiles be centered around the average prediction} -\item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the \code{DALEX::model_profile} function.} +\item{output_type}{either \code{"survival"}, \code{"chf"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. If \code{"chf"} the explanations are based on the cumulative hazard function. Otherwise the scalar risk predictions are used by the \code{DALEX::model_profile} function.} \item{categorical_variables}{character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). 
If it contains variable names not present in the \code{variables} argument, they will be added at the end.} diff --git a/man/model_profile_2d.surv_explainer.Rd b/man/model_profile_2d.surv_explainer.Rd index 2204f706..59f23461 100644 --- a/man/model_profile_2d.surv_explainer.Rd +++ b/man/model_profile_2d.surv_explainer.Rd @@ -46,7 +46,7 @@ model_profile_2d( \item{type}{the type of variable profile, \code{"partial"} for Partial Dependence or \code{"accumulated"} for Accumulated Local Effects} -\item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. Currently only \code{"survival"} is available.} +\item{output_type}{either \code{"survival"} or \code{"chf"}, the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. If \code{"chf"} the explanations are based on the cumulative hazard function. The \code{"risk"} output type is not implemented for this function.} } \value{ An object of class \code{model_profile_2d_survival}. It is a list with the element \code{result} containing the results of the calculation. diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 9a51ced0..840f6460 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -13,6 +13,7 @@ model_survshap(explainer, ...) y_true = NULL, calculation_method = "kernelshap", aggregation_method = "integral", + output_type = "survival", ... ) } @@ -28,6 +29,8 @@ model_survshap(explainer, ...) 
\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} \item{aggregation_method}{a character, either \code{"integral"}, \code{"integral_absolute"}, \code{"mean_absolute"}, \code{"max_absolute"}, or \code{"sum_of_squares"}} + +\item{output_type}{a character, either \code{"survival"} or \code{"chf"}. Determines which type of prediction should be used for explanations.} } \value{ An object of class \code{aggregated_surv_shap} containing the computed global SHAP values. diff --git a/man/predict_profile.surv_explainer.Rd b/man/predict_profile.surv_explainer.Rd index b9905d10..4a2cea96 100644 --- a/man/predict_profile.surv_explainer.Rd +++ b/man/predict_profile.surv_explainer.Rd @@ -12,6 +12,7 @@ predict_profile( categorical_variables = NULL, ..., type = "ceteris_paribus", + output_type = "survival", variable_splits_type = "uniform", center = FALSE ) @@ -41,11 +42,11 @@ predict_profile( \item{type}{character, only \code{"ceteris_paribus"} is implemented} +\item{output_type}{either \code{"survival"}, \code{"chf"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. If \code{"chf"} the explanations are based on the cumulative hazard function. Otherwise the scalar risk predictions are used by the \code{DALEX::predict_profile} function.} + \item{variable_splits_type}{character, decides how variable grids should be calculated. Use \code{"quantiles"} for percentiles or \code{"uniform"} (default) to get uniform grid of points.} \item{center}{logical, should profiles be centered around the average prediction} - -\item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. 
If \code{"survival"} the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the \code{DALEX::predict_profile} function.} } \value{ An object of class \code{c("predict_profile_survival", "surv_ceteris_paribus")}. It is a list with the final result in the \code{result} element. diff --git a/man/surv_ceteris_paribus.Rd b/man/surv_ceteris_paribus.Rd index da5d1e70..c2698cdf 100644 --- a/man/surv_ceteris_paribus.Rd +++ b/man/surv_ceteris_paribus.Rd @@ -16,6 +16,7 @@ surv_ceteris_paribus(x, ...) grid_points = 101, variable_splits_type = "uniform", center = FALSE, + output_type = "survival", ... ) } diff --git a/man/surv_shap.Rd b/man/surv_shap.Rd index 4cc684b9..b1fc485f 100644 --- a/man/surv_shap.Rd +++ b/man/surv_shap.Rd @@ -7,6 +7,7 @@ surv_shap( explainer, new_observation, + output_type, ..., y_true = NULL, calculation_method = "kernelshap", @@ -18,6 +19,8 @@ surv_shap( \item{new_observation}{new observations for which predictions need to be explained} +\item{output_type}{a character, either \code{"survival"} or \code{"chf"}. 
Determines which type of prediction should be used for explanations.} + \item{...}{additional parameters, passed to internal functions} \item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} diff --git a/tests/testthat/test-model_profile.R b/tests/testthat/test-model_profile.R index 56827ce8..6eb980e7 100644 --- a/tests/testthat/test-model_profile.R +++ b/tests/testthat/test-model_profile.R @@ -164,3 +164,161 @@ test_that("default DALEX::model_profile is ok", { expect_error(model_profile(cph_exp, output_type = "something_else")) }) + + + + +test_that("model_profile with type = 'partial' and output_type = 'chf' works", { + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_src_exp <- explain(rsf_src, verbose = FALSE) + + + mp_cph_cat <- model_profile(cph_exp, output_type = 'chf', variable_splits_type = "quantiles", grid_points = 6, N = 4) + plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") + + expect_s3_class(mp_cph_cat, "model_profile_survival") + expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) + expect_equal(ncol(mp_cph_cat$result), 7) + expect_true(all(unique(mp_cph_cat$result$`_vname_`) %in% colnames(cph_exp$data))) + + mp_chosen_var <- model_profile(cph_exp, output_type = 'chf', variable_splits_type = "quantiles", grid_points = 6, variables = 
"karno") + expect_s3_class(mp_chosen_var, "model_profile_survival") + expect_true(all(mp_chosen_var$eval_times == cph_exp$times)) + expect_equal(ncol(mp_chosen_var$result), 7) + + mp_cph_num <- model_profile(cph_exp, output_type = 'chf', variable_splits_type = "quantiles", grid_points = 6) + plot(mp_cph_num, variable_type = "numerical") + plot(mp_cph_num, numerical_plot_type = "contours") + + ### Add tests for plot for numerical PDP + # single time point + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "pdp+ice", times = cph_exp$times[1]) + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "pdp", times = cph_exp$times[1]) + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "ice", times = cph_exp$times[1]) + # multiple time points + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", plot_type = "pdp+ice") + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", plot_type = "pdp") + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", plot_type = "ice") + + expect_error(plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "nonexistent", times = cph_exp$times[1])) + expect_error(plot(mp_cph_num, geom = "variable", variables = 1, plot_type = "pdp+ice", times = cph_exp$times[1])) + expect_error(plot(mp_cph_num, geom = "variable", variables = c("karno", "diagtime"), plot_type = "pdp+ice", times = cph_exp$times[1])) + expect_error(plot(mp_cph_num, geom = "variable", variables = "nonexistent", plot_type = "pdp+ice", times = cph_exp$times[1])) + + + expect_s3_class(mp_cph_num, "model_profile_survival") + expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) + expect_equal(ncol(mp_cph_num$result), 7) + expect_true(all(unique(mp_cph_num$result$`_vname_`) %in% colnames(cph_exp$data))) + + + mp_rsf_cat <- model_profile(rsf_ranger_exp, output_type = 'chf', variable_splits_type = 
"uniform", grid_points = 6) + plot(mp_rsf_cat, variable_type = "categorical") + + + plot(mp_cph_cat, mp_rsf_cat) + ### Add tests for plot for categorical PDP + # single time point + plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp+ice", times = rsf_ranger_exp$times[1]) + plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp", times = rsf_ranger_exp$times[1]) + plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "ice", times = rsf_ranger_exp$times[1]) + # multiple time points + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "pdp+ice") + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], marginalize_over_time = T, variables = "celltype", plot_type = "pdp+ice") + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "pdp") + plot(mp_rsf_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "ice") + + + expect_s3_class(mp_rsf_cat, "model_profile_survival") + expect_true(all(mp_rsf_cat$eval_times == cph_exp$times)) + expect_equal(ncol(mp_rsf_cat$result), 7) + expect_true(all(unique(mp_rsf_cat$result$`_vname_`) %in% colnames(rsf_ranger_exp$data))) + + + mp_rsf_num <- model_profile(rsf_ranger_exp, output_type = 'chf', variable_splits_type = "uniform", grid_points = 6) + plot(mp_rsf_num, variable_type = "numerical") + plot(mp_rsf_num, variable_type = "numerical", numerical_plot_type = "contours") + + expect_s3_class(mp_rsf_num, "model_profile_survival") + expect_true(all(mp_rsf_num$eval_times == cph_exp$times)) + expect_equal(ncol(mp_rsf_num$result), 7) + expect_true(all(unique(mp_rsf_num$result$`_vname_`) %in% colnames(rsf_ranger_exp$data))) + + expect_output(print(mp_cph_num)) + expect_warning(plot(mp_rsf_cat, geom = "variable", variables = "celltype", plot_type = "pdp+ice")) + expect_error(plot(mp_rsf_num, variables = "nonexistent", 
grid_points = 6)) + expect_error(model_profile(rsf_ranger_exp, type = "conditional")) + expect_error(plot(mp_rsf_num, geom = "variable", variables = "nonexistent")) + expect_error(plot(mp_rsf_num, geom = "variable", variables = "age", times = -1)) + expect_error(plot(mp_rsf_num, geom = "nonexistent")) + expect_error(plot(mp_rsf_num, nonsense_argument = "character")) + + centered_mp <- model_profile(rsf_src_exp, output_type = 'chf', variables = "karno", center = TRUE) + plot(centered_mp, geom = "variable", variables = "karno", times = rsf_src_exp$times[1]) + +}) + +test_that("model_profile with type = 'accumulated' and output_type = 'chf' works", { + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_src_exp <- explain(rsf_src, verbose = FALSE) + + + mp_cph_cat <- model_profile(cph_exp, + output_type = 'chf', + grid_points = 6, + type = 'accumulated', + categorical_variables = "trt") + plot(mp_cph_cat, variables = "celltype", variable_type = "categorical") + + ### Add tests for plot for categorical ALE + # single time point + plot(mp_cph_cat, geom = "variable", variables = "celltype", times=cph_exp$times[1]) + # multiple time points + plot(mp_cph_cat, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "celltype", plot_type = "ale") + + expect_s3_class(mp_cph_cat, "model_profile_survival") + expect_true(all(mp_cph_cat$eval_times == cph_exp$times)) + expect_equal(ncol(mp_cph_cat$result), 7) + 
expect_true(all(unique(mp_cph_cat$result$`_vname_`) %in% colnames(cph_exp$data))) + expect_error(plot(mp_cph_cat, geom = "variable", variables = "celltype", plot_type = "pdp")) + expect_error(plot(mp_cph_cat, geom = "variable", variables = "celltype", plot_type = "nonexistent")) + expect_error(plot(mp_cph_cat, geom = "variable", variables = 1, plot_type = "nonexistent")) + expect_error(plot(mp_cph_cat, geom = "variable", variables = c("celltype", "trt"), plot_type = "nonexistent")) + + + mp_cph_num <- model_profile(cph_exp, + output_type = 'chf', + grid_points = 6, + type = 'accumulated', + categorical_variables = "trt") + plot(mp_cph_num, variable_type = "numerical") + plot(mp_cph_num, numerical_plot_type = "contours") + + ### Add tests for plot for numerical ALE + # single time point + plot(mp_cph_num, geom = "variable", variables = "karno", plot_type = "ale", times=cph_exp$times[1]) + # multiple time points + plot(mp_cph_num, geom = "variable", times = cph_exp$times[c(1, 20)], variables = "karno", plot_type = "ale") + + expect_s3_class(mp_cph_num, "model_profile_survival") + expect_true(all(unique(mp_cph_num$eval_times) == cph_exp$times)) + expect_equal(ncol(mp_cph_num$result), 7) + expect_true(all(unique(mp_cph_num$result$`_vname_`) %in% colnames(cph_exp$data))) + + expect_output(print(mp_cph_num)) + expect_error(plot(mp_rsf_num, output_type = 'chf', variables = "nonexistent", grid_points = 6)) +}) diff --git a/tests/testthat/test-model_profile_2d.R b/tests/testthat/test-model_profile_2d.R index 56a98838..1ee45dee 100644 --- a/tests/testthat/test-model_profile_2d.R +++ b/tests/testthat/test-model_profile_2d.R @@ -88,3 +88,98 @@ test_that("model_profile_2d with type = 'accumulated' works", { type = "accumulated")) expect_error(plot(mp_rsf_ale, times = -1)) }) + + +test_that("model_profile_2d with type = 'partial' and output_type = 'chf' works", { + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + + cph <- 
survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_exp <- explain(rsf, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + mp_cph_pdp <- model_profile_2d(cph_exp, + variable_splits_type = "uniform", + variables = list(c("trt", "age"), + c("karno", "trt"), + c("karno", "age")), + categorical_variables = "trt", + grid_points = 6, + output_type = 'chf') + mp_small <- model_profile_2d(cph_exp, + variables = list(c("trt", "age")), + categorical_variables = 1, + grid_points = 2, + N = 2, + output_type = 'chf') + plot(mp_cph_pdp, times=cph_exp$times[1]) + + mp_rsf_pdp <- model_profile_2d(rsf_exp, + variables = list(c("karno", "age")), + grid_points = 6, + output_type = 'chf', + N = 25) + plot(mp_cph_pdp, mp_rsf_pdp, variables = list(c("karno", "age")), times=cph_exp$times[1]) + + expect_output(print(mp_cph_pdp)) + expect_s3_class(mp_cph_pdp, "model_profile_2d_survival") + expect_true(all(mp_cph_pdp$eval_times == cph_exp$times)) + expect_equal(ncol(mp_cph_pdp$result), 9) + expect_true(all(unique(c(mp_cph_pdp$result$`_v1name_`, mp_cph_pdp$result$`_v2name_`)) + %in% colnames(cph_exp$data))) + + expect_warning(plot(mp_cph_pdp)) + expect_error(model_profile_2d(rsf_exp)) + expect_error(model_profile_2d(rsf_exp, type = "conditional", + variables = list(c("karno", "age")))) + expect_error(model_profile_2d(rsf_exp, output_type = "risk", + variables = list(c("karno", "age")))) +} +) + +test_that("model_profile_2d with type = 'accumulated' and output_type = 'chf' works", { + veteran <- survival::veteran[c(1:3, 16:18, 46:48, 56:58, 71:73, 91:93, 111:113, 126:128), ] + + rsf <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, 
max.depth = 5) + rsf_exp <- explain(rsf, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + mp_rsf_ale <- model_profile_2d(rsf_exp, + variable_splits_type = "quantiles", + variables = list(c("karno", "age")), + grid_points = 6, + output_type = 'chf', + type = "accumulated") + + mp_rsf_ale_noncentered <- model_profile_2d(rsf_exp, + variable_splits_type = "quantiles", + variables = list(c("karno", "age")), + grid_points = 6, + output_type = 'chf', + type = "accumulated", + center = FALSE) + + expect_error(model_profile_2d(rsf_exp, + type = "accumulated", + output_type = 'chf', + variables = list(c("karno", "celltype")))) + + plot(mp_rsf_ale, times=rsf_exp$times[1]) + + expect_output(print(mp_rsf_ale)) + expect_s3_class(mp_rsf_ale, "model_profile_2d_survival") + expect_true(all(unique(mp_rsf_ale$eval_times) == rsf_exp$times)) + expect_equal(ncol(mp_rsf_ale$result), 14) + expect_true(all(unique(c(mp_rsf_ale$result$`_v1name_`, mp_rsf_ale$result$`_v2name_`)) + %in% colnames(rsf_exp$data))) + + expect_error(plot(mp_rsf_ale, variables = "nonexistent")) + expect_error(plot(mp_rsf_ale, + variables = list(c("karno", "trt")), + categorical_variables="trt")) + expect_error(model_profile_2d(mp_rsf_ale, + variables = list(c("karno", "trt")), + categorical_variables="trt", + type = "accumulated")) + expect_error(plot(mp_rsf_ale, times = -1)) +}) diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index 51dc9799..d56b75cc 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -21,13 +21,21 @@ test_that("global survshap explanations with kernelshap work for ranger, using n aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) + ranger_global_survshap <- model_survshap( + explainer = rsf_ranger_exp, + new_observation = veteran[c(1:4, 17:20, 110:113, 126:129), !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[c(1:4, 17:20, 
110:113, 126:129)], veteran$status[c(1:4, 17:20, 110:113, 126:129)]), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap", + output_type = "chf" + ) plot(ranger_global_survshap) plot(ranger_global_survshap, geom = "beeswarm") plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "celltype") plot(ranger_global_survshap, geom = "profile", variable = "karno", color_variable = "age") plot(ranger_global_survshap, geom = "curves", variable = "karno") plot(ranger_global_survshap, geom = "curves", variable = "celltype") - plot(ranger_global_survshap, geom = "curves", variable = "karno", boxplot = TRUE) + expect_output(plot(ranger_global_survshap, geom = "curves", variable = "karno", boxplot = TRUE)) expect_error(plot(ranger_global_survshap, geom = "nonexistent")) single_survshap <- extract_predict_survshap(ranger_global_survshap, 5) @@ -54,7 +62,7 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex plot(cph_global_survshap, geom = "profile", variable = "karno", color_variable = "age") plot(cph_global_survshap, geom = "curves", variable = "karno") plot(cph_global_survshap, geom = "curves", variable = "celltype") - plot(cph_global_survshap, geom = "curves", variable = "karno", boxplot = TRUE) + expect_output(plot(cph_global_survshap, geom = "curves", variable = "karno", boxplot = TRUE)) expect_s3_class(cph_global_survshap, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times)) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index d70dfbfc..450fb1cf 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -49,6 +49,58 @@ test_that("survshap explanations work", { }) +test_that("survshap explanations with output_type = 'chf' work", { + veteran <- survival::veteran + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = 
TRUE, y = TRUE) + rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_src_exp <- explain(rsf_src, verbose = FALSE) + + parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares", output_type = "chf") + parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf") + plot(parts_cph) + plot(parts_cph, rug = "events") + plot(parts_cph, rug = "censors") + plot(parts_cph, rug = "none") + + parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf") + plot(parts_ranger) + + parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")], output_type = "chf") + plot(parts_src) + + plot(parts_cph, parts_ranger, parts_src) + + parts_cph2 <- predict_parts(cph_exp, veteran[4,], explanation_label = "second_explanation", output_type = "chf") + plot(parts_cph, parts_cph2) + + expect_s3_class(parts_cph, c("predict_parts_survival", "surv_shap")) + expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) + expect_s3_class(parts_src, c("predict_parts_survival", "surv_shap")) + + expect_equal(nrow(parts_cph$result), length(cph_exp$times)) + expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times)) + expect_equal(nrow(parts_src$result), length(rsf_src_exp$times)) + + 
expect_true(all(colnames(parts_cph$result) == colnames(cph_exp$data))) + expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp$data))) + expect_true(all(colnames(parts_src$result) == colnames(rsf_src_exp$data))) + + expect_output(print(parts_cph)) + + expect_error(predict_parts(cph_exp, veteran[1, ], aggregation_method = "nonexistent")) + expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "sampling")) + expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent")) + expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent")) + + +}) + + test_that("survlime explanations work", { veteran <- survival::veteran diff --git a/tests/testthat/test-predict_profile.R b/tests/testthat/test-predict_profile.R index 308f58a7..f4fda6fa 100644 --- a/tests/testthat/test-predict_profile.R +++ b/tests/testthat/test-predict_profile.R @@ -20,10 +20,10 @@ test_that("ceteris_paribus works", { plot(cph_pp, rug = "censors", variable_type = "numerical") plot(cph_pp, rug = "none") plot(cph_pp, geom = "variable", variables = "karno", times=cph_exp$times[1]) - plot(cph_pp, geom = "variable", times=cph_exp$times[1]) + expect_warning(plot(cph_pp, geom = "variable", times=cph_exp$times[1])) plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "karno") plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "celltype") - plot(cph_pp, geom = "variable", variables = "karno", marginalize_over_time = TRUE) + expect_warning(plot(cph_pp, geom = "variable", variables = "karno", marginalize_over_time = TRUE)) plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "celltype", marginalize_over_time = TRUE) expect_error(plot(cph_pp, variables = "aaa")) @@ -41,7 +41,7 @@ test_that("ceteris_paribus works", { cph_pp_centered <- predict_profile(cph_exp, veteran[2, -c(3, 4)], center = TRUE) plot(cph_pp_centered) plot(cph_pp_centered, 
numerical_plot_type = "contours") - plot(cph_pp_centered, geom = "variable", variables = "karno") + expect_warning(plot(cph_pp_centered, geom = "variable", variables = "karno")) cph_pp_cat <- predict_profile(cph_exp, veteran[2, -c(3, 4)], variables = c("celltype")) plot(predict_profile(cph_exp, veteran[2, -c(3, 4)], categorical_variables = 1)) @@ -66,6 +66,72 @@ test_that("ceteris_paribus works", { expect_output(print(cph_pp)) }) +test_that("ceteris_paribus with output_type = 'chf' works", { + + veteran <- survival::veteran + + cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE) + rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) + + cph_exp <- explain(cph, verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_src_exp <- explain(rsf_src, verbose = FALSE) + + cph_pp <- predict_profile(cph_exp, veteran[2, -c(3, 4)], output_type = 'chf') + ranger_pp <- predict_profile(rsf_ranger_exp, veteran[2, -c(3, 4)], output_type = 'chf') + + plot(cph_pp, colors = c("#ff0000", "#00ff00", "#0000ff")) + plot(cph_pp) + plot(cph_pp, numerical_plot_type = "contours") + plot(cph_pp, ranger_pp, rug = "events", variables = c("karno", "age")) + plot(cph_pp, rug = "censors", variable_type = "numerical") + plot(cph_pp, rug = "none") + plot(cph_pp, geom = "variable", variables = "karno", times=cph_exp$times[1]) + expect_warning(plot(cph_pp, geom = "variable", times=cph_exp$times[1])) + plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "karno") + plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "celltype") + expect_warning(plot(cph_pp, geom = "variable", variables = "karno", marginalize_over_time = TRUE)) + 
plot(cph_pp, geom = "variable", times = cph_pp$eval_times[1:2], variables = "celltype", marginalize_over_time = TRUE) + + expect_error(plot(cph_pp, variables = "aaa")) + expect_error(plot(cph_pp, variable_type = "nonexistent")) + expect_error(plot(cph_pp, numerical_plot_type = "nonexistent")) + expect_error(plot(cph_pp, geom = "variable", variables = "nonexistent")) + expect_error(plot(cph_pp, geom = "variable", variables = "age", times = -1)) + expect_error(plot(cph_pp, geom = "nonexistent")) + expect_error(plot(cph_pp, geom = "variable", variables = 1, plot_type = "pdp+ice", times = cph_exp$times[1])) + expect_error(plot(cph_pp, geom = "variable", variables = c("karno", "diagtime"), times = cph_exp$times[1])) + expect_error(predict_profile(cph_exp, veteran[2, -c(3, 4)], output_type = "nonexistent")) + expect_error(predict_profile(cph_exp, veteran[2:3, -c(3, 4)])) + expect_error(predict_profile(cph_exp, veteran[2, -c(3, 4)], type = "nonexistent")) + + cph_pp_centered <- predict_profile(cph_exp, veteran[2, -c(3, 4)], center = TRUE, output_type = 'chf') + plot(cph_pp_centered) + plot(cph_pp_centered, numerical_plot_type = "contours") + expect_warning(plot(cph_pp_centered, geom = "variable", variables = "karno")) + + cph_pp_cat <- predict_profile(cph_exp, veteran[2, -c(3, 4)], variables = c("celltype"), output_type = 'chf') + plot(predict_profile(cph_exp, veteran[2, -c(3, 4)], categorical_variables = 1)) + plot(cph_pp_cat, variable_type = "categorical", colors = c("#ff0000", "#00ff00", "#0000ff")) + plot(cph_pp_cat, variable_type = "categorical") + + expect_s3_class(cph_pp, c("predict_profile_survival", "surv_ceteris_paribus")) + expect_s3_class(cph_pp_cat, c("predict_profile_survival", "surv_ceteris_paribus")) + + expect_true(all(unique(cph_pp$result$`_vname_`) %in% colnames(cph_exp$data))) + expect_true(all(unique(cph_pp_cat$result$`_vname_`) %in% colnames(cph_exp$data))) + + + expect_true(all(unique(cph_pp$result$`_yhat_`) >= 0)) + 
expect_true(all(unique(cph_pp_cat$result$`_yhat_`) >= 0)) + + expect_setequal(cph_pp$eval_times, cph_exp$times) + expect_setequal(cph_pp_cat$eval_times, cph_exp$times) + + expect_output(print(cph_pp)) +}) + test_that("default DALEX::ceteris_paribus works", { From 38f0b790b9d15f6916f578bfb8453091de6a57aa Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 29 Aug 2023 12:51:56 +0200 Subject: [PATCH 200/207] fix titles when NULL --- R/plot_model_profile_2d.R | 2 +- R/plot_model_profile_survival.R | 2 +- R/plot_predict_profile_survival.R | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/plot_model_profile_2d.R b/R/plot_model_profile_2d.R index d7384cab..0ea50b55 100644 --- a/R/plot_model_profile_2d.R +++ b/R/plot_model_profile_2d.R @@ -53,7 +53,7 @@ plot.model_profile_2d_survival <- function(x, colors = NULL) { explanations_list <- c(list(x), list(...)) num_models <- length(explanations_list) - if (title == "default") { + if (!is.null(title) && title == "default") { if (x$type == "partial") { title <- "2D partial dependence survival profiles" } diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index c2e9e924..147235d0 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -75,7 +75,7 @@ plot.model_profile_survival <- function(x, stop("`variable_type` needs to be 'numerical' or 'categorical'") } - if (title == "default") { + if (!is.null(title) && title == "default") { if (x$type == "partial") { title <- "Partial dependence survival profiles" if (geom == "variable") { diff --git a/R/plot_predict_profile_survival.R b/R/plot_predict_profile_survival.R index 1f0d0d7d..fa40d33c 100644 --- a/R/plot_predict_profile_survival.R +++ b/R/plot_predict_profile_survival.R @@ -70,7 +70,7 @@ plot.predict_profile_survival <- function(x, stop("`variable_type` needs to be 'numerical' or 'categorical'") } - if (title == "default"){ + if (!is.null(title) && title == "default"){ title <- "Ceteris paribus 
survival profile" } From 352094ef4e709abfc773ceee0e79454be9230c9f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 29 Aug 2023 13:13:34 +0200 Subject: [PATCH 201/207] fix colors for diagnostic plots --- R/plot_model_diagnostics_survival.R | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/R/plot_model_diagnostics_survival.R b/R/plot_model_diagnostics_survival.R index b6516be9..8646edce 100644 --- a/R/plot_model_diagnostics_survival.R +++ b/R/plot_model_diagnostics_survival.R @@ -26,7 +26,7 @@ plot.model_diagnostics_survival <- function(x, facet_ncol = NULL, title = "Model diagnostics", subtitle = "default", - colors = c("#160e3b", "#f05a71", "#ceced9")){ + colors = NULL){ lapply(list(x, ...), function(x) { if (!inherits(x, "model_diagnostics_survival")) { stop("All ... must be objects of class `model_diagnostics_survival`.") @@ -53,6 +53,10 @@ plot.model_diagnostics_survival <- function(x, df$x <- switch(xvariable, "index" = unlist(lapply(n_observations, function(x) seq_len(x))), df[[xvariable]]) + + if (is.null(colors) || length(colors) < 3) + colors <- c("#160e3b", "#f05a71", "#ceced9") + pl <- with(df, { pl <- ggplot(df, aes(x = x, y = y, color = status)) + geom_hline(yintercept = 0, color = colors[3], lty = 2, linewidth = 1) + @@ -70,6 +74,9 @@ plot.model_diagnostics_survival <- function(x, }) return(pl) } else if (plot_type == "Cox-Snell"){ + if (is.null(colors) || length(colors) < 2) + colors <- c("#9fe5bd", "#000000") + split_df <- split(df, df$label) df_list <- lapply(split_df, function(df_tmp){ fit_coxsnell <- survival::survfit(survival::Surv(cox_snell_residuals, as.numeric(status)) ~ 1, data=df_tmp) @@ -90,10 +97,10 @@ plot.model_diagnostics_survival <- function(x, pl <- with(df, {ggplot(df, aes(x = time, y = cumhaz)) + + geom_abline(slope = 1, color = colors[2], linewidth = 1) + geom_step(color = colors[1], linewidth = 1) + geom_step(aes(y = lower), linetype = "dashed", color = colors[1], alpha = 0.8) + geom_step(aes(y = 
upper), linetype = "dashed", color = colors[1], alpha = 0.8) + - geom_abline(slope = 1, color = colors[2], linewidth = 1) + labs(x = "Cox-Snell residuals (pseudo observed times)", y = "Cumulative hazard at pseudo observed times") + theme_default_survex() + From 6f6f573a1b1a95a6d29e14b101518391dda54559 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 29 Aug 2023 19:50:34 +0200 Subject: [PATCH 202/207] Fix pkgdown action --- .github/workflows/test-coverage.yaml | 2 +- _pkgdown.yml | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 75ad46bb..b86748c2 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -27,5 +27,5 @@ jobs: needs: coverage - name: Test coverage - run: covr::codecov(quiet = FALSE, function_exclusions=c("surv_model_info\\.", "explain.LearnerSurv", "loss_adapt_mlr3proba")) + run: covr::codecov(quiet = FALSE, function_exclusions=c("surv_model_info\\.", "explain.LearnerSurv", "loss_adapt_mlr3proba", "explain.sksurv")) shell: Rscript {0} diff --git a/_pkgdown.yml b/_pkgdown.yml index 70ffca02..69a06f5a 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -34,6 +34,8 @@ reference: - plot.model_profile_2d_survival - subtitle: Model SurvSHAP(t) - contents: plot.aggregated_surv_shap +- subtitle: Model Diagnostics +- contents: plot.model_diagnostics_survival - subtitle: Predict Parts - contents: - plot.predict_parts_survival From 482c3a119f6bd29d637559d23958cf3ac83c9ce6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Spytek?= Date: Tue, 29 Aug 2023 20:01:56 +0200 Subject: [PATCH 203/207] Fix documentation mismatches --- man/plot.model_diagnostics_survival.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/plot.model_diagnostics_survival.Rd b/man/plot.model_diagnostics_survival.Rd index 934fd1ef..2f4bd27a 100644 --- a/man/plot.model_diagnostics_survival.Rd +++ 
b/man/plot.model_diagnostics_survival.Rd @@ -13,7 +13,7 @@ facet_ncol = NULL, title = "Model diagnostics", subtitle = "default", - colors = c("#160e3b", "#f05a71", "#ceced9") + colors = NULL ) } \arguments{ From 78bcb933444137d4d4c9bf0199ea2e3659139fbd Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 29 Aug 2023 22:47:45 +0200 Subject: [PATCH 204/207] fix plot when there are no outliers --- R/plot_surv_shap.R | 79 +++++++++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 26 deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 7d0841d6..31766b97 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -176,13 +176,6 @@ plot.aggregated_surv_shap <- function(x, subtitle = "default", max_vars = 7, colors = NULL) { - if (is.null(colors)) { - colors <- c( - low = "#9fe5bd", - mid = "#46bac2", - high = "#371ea3" - ) - } if (geom == "swarm") { geom <- "beeswarm" @@ -245,7 +238,8 @@ plot_shap_global_importance <- function(x, rug = rug, rug_colors = rug_colors ) + - labs(y = ylab_right) + labs(y = ylab_right) + + theme_default_survex() label <- attr(x, "label") long_df <- stack(x$aggregate) @@ -261,11 +255,19 @@ plot_shap_global_importance <- function(x, ) } + if (is.null(colors)) { + colors <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) + } + left_plot <- with(long_df, { ggplot(long_df, aes(x = values, y = reorder(ind, values))) + geom_col(fill = colors[2]) + theme_default_survex() + - labs(x = xlab_left) + + labs(x = xlab_left, y = "variable") + theme(axis.title.y = element_blank()) }) @@ -310,6 +312,15 @@ plot_shap_global_beeswarm <- function(x, "(n=", x$n_observations, ")" ) } + + if (is.null(colors)) { + colors <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) + } + with(df, { ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + geom_vline(xintercept = 0, color = "#ceced9", linetype = "solid") + @@ -373,7 +384,7 @@ plot_shap_global_profile <- function(x, 
ggplot(df, aes(x = variable_val, y = shap_val, color = color_variable_val)) + geom_hline(yintercept = 0, color = "#ceced9", linetype = "solid") + geom_point() + - geom_rug(aes(x = df$variable_val), inherit.aes = F, color = "#ceced9") + + geom_rug(aes(x = variable_val), inherit.aes = F, color = "#ceced9") + labs( x = paste(variable, "value"), y = "Aggregated SurvSHAP(t) value", @@ -390,6 +401,13 @@ plot_shap_global_profile <- function(x, values = generate_discrete_color_scale(length(unique(df$color_variable_val)), colors) ) } else { + if (is.null(colors)) { + colors <- c( + low = "#9fe5bd", + mid = "#46bac2", + high = "#371ea3" + ) + } p + scale_color_gradient2( name = paste(color_variable, "value"), low = colors[1], @@ -496,8 +514,18 @@ plot_shap_global_curves <- function(x, iqr <- upper_bound - lower_bound outlier_indices <- which(colSums((t(df) <= median - coef * iqr) + (t(df) >= median + coef * iqr)) > 0) - outliers <- df[outlier_indices, ] - nonoutliers <- df[-outlier_indices,] + + if (length(outlier_indices) > 0){ + nonoutliers <- df[-outlier_indices,] + outliers <- df[outlier_indices, ] + df_outliers <- data.frame(x = x$eval_times, + y = as.vector(t(outliers)), + obs = rep(1:length(outliers), each = length(x$eval_times))) + } else { + nonoutliers <- df + df_outliers <- NULL + } + whisker_lower_bound <- apply(nonoutliers, 2, min) whisker_upper_bound <- apply(nonoutliers, 2, max) @@ -505,10 +533,6 @@ plot_shap_global_curves <- function(x, y = as.vector(t(df)), obs = rep(1:n, each = length(x$eval_times))) - df_outliers <- data.frame(x = x$eval_times, - y = as.vector(t(outliers)), - obs = rep(1:length(outliers), each = length(x$eval_times))) - df_res <- data.frame(x = x$eval_times, median = median, lower_bound = lower_bound, @@ -521,22 +545,25 @@ plot_shap_global_curves <- function(x, geom_hline(yintercept = 0, alpha = 0.5, color = colors[1]) + geom_line(data = df_all, aes(x = x, y = y, group = obs), col = colors[1], alpha = 0.2, linewidth = 0.2) + 
geom_ribbon(data = df_res, aes(x = x, ymin = lower_bound, ymax = upper_bound), - fill = colors[2], alpha = 0.2) + - geom_line(data = df_res, aes(x = x, y = median), col = colors[2], alpha = 0.8) + - geom_line(data = df_res, aes(x = x, y = lower_bound), col = colors[2], alpha = 0.8) + - geom_line(data = df_res, aes(x = x, y = upper_bound), col = colors[2], alpha = 0.8) + - geom_line(data = df_res, aes(x = x, y = whisker_lower_bound), col = colors[2], lty = 2, alpha = 0.8) + - geom_line(data = df_res, aes(x = x, y = whisker_upper_bound), col = colors[2], lty = 2, alpha = 0.8) + - geom_line(data = df_outliers, aes(x = x, y = y, group = obs), - col = colors[3], linewidth = 0.5, alpha = 0.1) + + fill = colors[1], alpha = 0.4) + + geom_line(data = df_res, aes(x = x, y = median), col = colors[2], linewidth = 1) + + geom_line(data = df_res, aes(x = x, y = lower_bound), col = colors[2], linewidth = 1) + + geom_line(data = df_res, aes(x = x, y = upper_bound), col = colors[2], linewidth = 1) + + geom_line(data = df_res, aes(x = x, y = whisker_lower_bound), col = colors[2], lty = 2, linewidth = 1) + + geom_line(data = df_res, aes(x = x, y = whisker_upper_bound), col = colors[2], lty = 2, linewidth = 1) + theme_default_survex() + labs(x = "time", y = "SurvSHAP(t) value", title = title, subtitle = subtitle) }) - cat("Observations with outlying SurvSHAP(t) values:\n") - print(x$variable_values[outlier_indices,]) + if (length(outlier_indices) > 0){ + base_plot <- base_plot + + geom_line(data = df_outliers, aes(x = x, y = y, group = obs), + col = colors[3], linewidth = 0.5, alpha = 0.1) + cat("Observations with outlying SurvSHAP(t) values:\n") + print(x$variable_values[outlier_indices,]) + } } rug_df <- data.frame(times = x$event_times, statuses = as.character(x$event_statuses)) From eef35833a467916973a7edf8f9d39c6a6de7261f Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 29 Aug 2023 23:37:56 +0200 Subject: [PATCH 205/207] fix profile plots with `geom="variable"` for 
categorical variables --- R/plot_model_profile_survival.R | 31 +++++++++++++++++++++++++------ R/plot_predict_profile_survival.R | 14 ++++++++++---- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index 147235d0..d36638c7 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -456,20 +456,39 @@ plot_pdp_cat <- function(pdp_dt, scale_fill_manual(name = "time", values = colors) } } else { + pdp_dt$time <- as.numeric(as.character(pdp_dt$time)) + if (!is.null(ice_dt)) { + ice_dt$time <- as.numeric(as.character(ice_dt$time)) + } if (plot_type == "ice") { ggplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions)) + - geom_boxplot(alpha = 0.2, mapping = aes(color = time)) + - scale_color_manual(name = "time", values = colors) + geom_boxplot(mapping = aes(color = time, group = interaction(time, !!feature_name_count_sym)), alpha = 0.2) + + scale_colour_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(ice_dt$time))) + ) } else if (plot_type == "pdp+ice") { ggplot(mapping = aes(color = time)) + - geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions), alpha = 0.2, width = 0.7) + + geom_boxplot(data = ice_dt, aes(x = !!feature_name_count_sym, y = predictions, group = interaction(time, !!feature_name_count_sym)), alpha = 0.2, width = 0.7) + geom_line(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, group = time), linewidth = 0.6, position = position_dodge(0.7)) + - scale_color_manual(name = "time", values = colors) + scale_colour_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(pdp_dt$time))) + ) } else if (plot_type == "pdp" || plot_type == "ale") { - ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, fill = time)) + + ggplot(data = pdp_dt, aes(x = !!feature_name_count_sym, y = pd, 
fill = time, group = time)) + geom_bar(stat = "identity", width = 0.5, position = "dodge") + scale_y_continuous(expand = c(0, NA)) + - scale_fill_manual(name = "time", values = colors) + scale_fill_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(pdp_dt$time))) + ) } } }) diff --git a/R/plot_predict_profile_survival.R b/R/plot_predict_profile_survival.R index fa40d33c..d5eb64e0 100644 --- a/R/plot_predict_profile_survival.R +++ b/R/plot_predict_profile_survival.R @@ -272,7 +272,7 @@ plot_ice_num <-function(ice_df, low = colors[1], mid = colors[2], high = colors[3], - midpoint = median(as.numeric(as.character(pdp_dt$time))) + midpoint = median(as.numeric(as.character(ice_df$time))) ) } }) @@ -291,10 +291,16 @@ plot_ice_cat <- function(ice_df, scale_fill_manual(name = "time", values = colors) + geom_hline(yintercept = 0, linetype="dashed") } else { - ggplot(data = ice_df, aes(x = !!feature_name_sym, y = ice, fill = time)) + + ice_df$time <- as.numeric(as.character(ice_df$time)) + ggplot(data = ice_df, aes(x = !!feature_name_sym, y = ice, fill = time, group = time)) + geom_bar(stat = "identity", width = 0.5, position = "dodge") + - scale_y_continuous() + - scale_fill_manual(name = "time", values = colors) + scale_y_continuous(expand = c(0, NA)) + + scale_fill_gradient2( + low = colors[1], + mid = colors[2], + high = colors[3], + midpoint = median(as.numeric(as.character(ice_df$time))) + ) } }) } From e28fb272a2a9b8665eec448905cc2df3f33a35cd Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Tue, 29 Aug 2023 23:40:28 +0200 Subject: [PATCH 206/207] fix check warnings (context in plots) --- R/plot_surv_shap.R | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 31766b97..3cde5f3f 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -540,7 +540,7 @@ plot_shap_global_curves <- function(x, whisker_lower_bound = whisker_lower_bound, 
whisker_upper_bound = whisker_upper_bound) - base_plot <- with(list(df_all, df_outliers, df_res), + base_plot <- with(list(df_all, df_res), {ggplot() + geom_hline(yintercept = 0, alpha = 0.5, color = colors[1]) + geom_line(data = df_all, aes(x = x, y = y, group = obs), col = colors[1], alpha = 0.2, linewidth = 0.2) + @@ -558,9 +558,12 @@ plot_shap_global_curves <- function(x, subtitle = subtitle) }) if (length(outlier_indices) > 0){ - base_plot <- base_plot + - geom_line(data = df_outliers, aes(x = x, y = y, group = obs), - col = colors[3], linewidth = 0.5, alpha = 0.1) + base_plot <- with(df_outliers, + {base_plot + + geom_line(data = df_outliers, aes(x = x, y = y, group = obs), + col = colors[3], linewidth = 0.5, alpha = 0.1) + } + ) cat("Observations with outlying SurvSHAP(t) values:\n") print(x$variable_values[outlier_indices,]) } From 2bc62152538d5200389f076399be661f29248089 Mon Sep 17 00:00:00 2001 From: krzyzinskim Date: Wed, 30 Aug 2023 13:32:05 +0200 Subject: [PATCH 207/207] bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 3bccf60a..999a5b70 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9102 +Version: 1.1.0 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),