Skip to content

Commit

Permalink
update survshap for new version of treeshap
Browse files Browse the repository at this point in the history
  • Loading branch information
krzyzinskim committed Oct 24, 2023
1 parent 19365d1 commit 9fbc227
Showing 1 changed file with 13 additions and 32 deletions.
45 changes: 13 additions & 32 deletions R/surv_shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ surv_shap <- function(explainer,
res$result <- switch(calculation_method,
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, N, ...),
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, N, ...),
"treeshap" = use_treeshap(explainer, new_observation, ...),
"treeshap" = use_treeshap(explainer, new_observation, output_type, ...),
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
# quality-check here
stopifnot(
Expand Down Expand Up @@ -282,7 +282,7 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
return(shap_values)
}

use_treeshap <- function(explainer, new_observation, ...){
use_treeshap <- function(explainer, new_observation, output_type, ...){

stopifnot(
"new_observation must be a data.frame" = inherits(
Expand All @@ -292,45 +292,26 @@ use_treeshap <- function(explainer, new_observation, ...){
# init unify_append_args
unify_append_args <- list()

if (inherits(explainer$model, "ranger")) {
# UNIFY_FUN to prepare code for easy Integration of other ml algorithms
# that are supported by treeshap
UNIFY_FUN <- treeshap::ranger_surv.unify
unify_append_args <- list(type = "survival", times = explainer$times)
} else {
if (!inherits(explainer$model, "ranger")) {
stop("Support for `treeshap` is currently only implemented for `ranger`.")
}

unify_args <- list(
rf_model = explainer$model,
data = explainer$data
)

if (length(unify_append_args) > 0) {
unify_args <- c(unify_args, unify_append_args)
}

tmp_unified <- do.call(UNIFY_FUN, unify_args)
tmp_unified <- treeshap::unify(explainer$model,
explainer$data,
type = output_type,
times = explainer$times)

shap_values <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
# ensure that matrix has expected dimensions; as.integer is
# necessary for valid comparison with "identical"
new_obs_mat <- new_observation[as.integer(i), ]
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))

tmp_res <- do.call(
rbind,
lapply(
tmp_unified,
function(m) {
new_obs_mat <- new_observation[as.integer(i), ]
# ensure that matrix has expected dimensions; as.integer is
# necessary for valid comparison with "identical"
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
treeshap::treeshap(
unified_model = m,
x = new_obs_mat,
...
)$shaps
}
)
lapply(treeshap::treeshap(tmp_unified, x = new_obs_mat, ...), function(x) x$shaps)
)

tmp_shap_values <- data.frame(tmp_res)
Expand Down

0 comments on commit 9fbc227

Please sign in to comment.