diff --git a/R/surv_shap.R b/R/surv_shap.R index 15b2ff7..75737ba 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -104,7 +104,7 @@ surv_shap <- function(explainer, res$result <- switch(calculation_method, "exact_kernel" = use_exact_shap(explainer, new_observation, output_type, N, ...), "kernelshap" = use_kernelshap(explainer, new_observation, output_type, N, ...), - "treeshap" = use_treeshap(explainer, new_observation, ...), + "treeshap" = use_treeshap(explainer, new_observation, output_type, ...), stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented")) # quality-check here stopifnot( @@ -282,7 +282,7 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) { return(shap_values) } -use_treeshap <- function(explainer, new_observation, ...){ +use_treeshap <- function(explainer, new_observation, output_type, ...){ stopifnot( "new_observation must be a data.frame" = inherits( @@ -292,45 +292,26 @@ use_treeshap <- function(explainer, new_observation, ...){ # init unify_append_args unify_append_args <- list() - if (inherits(explainer$model, "ranger")) { - # UNIFY_FUN to prepare code for easy Integration of other ml algorithms - # that are supported by treeshap - UNIFY_FUN <- treeshap::ranger_surv.unify - unify_append_args <- list(type = "survival", times = explainer$times) - } else { + if (!inherits(explainer$model, "ranger")) { stop("Support for `treeshap` is currently only implemented for `ranger`.") } - unify_args <- list( - rf_model = explainer$model, - data = explainer$data - ) - - if (length(unify_append_args) > 0) { - unify_args <- c(unify_args, unify_append_args) - } - - tmp_unified <- do.call(UNIFY_FUN, unify_args) + tmp_unified <- treeshap::unify(explainer$model, + explainer$data, + type = output_type, + times = explainer$times) shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { + # ensure that matrix has expected dimensions; as.integer is + # necessary for valid comparison with "identical" + new_obs_mat <- new_observation[as.integer(i), ] + stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation))))) + tmp_res <- do.call( rbind, - lapply( - tmp_unified, - function(m) { - new_obs_mat <- new_observation[as.integer(i), ] - # ensure that matrix has expected dimensions; as.integer is - # necessary for valid comparison with "identical" - stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation))))) - treeshap::treeshap( - unified_model = m, - x = new_obs_mat, - ... - )$shaps - } - ) + lapply(treeshap::treeshap(tmp_unified, x = new_obs_mat, ...), function(x) x$shaps) ) tmp_shap_values <- data.frame(tmp_res)