Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow user to modify model prediction results #12

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Description: A general test for conditional independence in supervised learning
License: GPL (>= 3)
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.1.2
RoxygenNote: 7.2.2
URL: https://github.com/bips-hb/cpi,
https://bips-hb.github.io/cpi/
BugReports: https://github.com/bips-hb/cpi/issues
Expand Down
9 changes: 8 additions & 1 deletion R/compute_loss.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

# Internal function to compute sample loss
compute_loss <- function(pred, measure) {
compute_loss <- function(pred, measure, modify_trp, ...) {
if (inherits(pred, "Prediction")) {
truth <- pred$truth
response <- pred$response
Expand All @@ -11,6 +11,13 @@ compute_loss <- function(pred, measure) {
prob <- do.call(rbind, lapply(pred, function(x) x$prob))
}

if (is.function(modify_trp)) {
out <- modify_trp(truth, response, prob, ...)
truth <- out$truth
response <- out$response
prob <- out$prob
}

if (measure$id == "regr.mse") {
# Squared errors
loss <- (truth - response)^2
Expand Down
116 changes: 112 additions & 4 deletions R/cpi.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
#' @param groups (Named) list with groups. Set to \code{NULL} (default) for no
#' groups, i.e. compute CPI for each feature. See examples.
#' @param verbose Verbose output of resampling procedure.
#' @param modify_trp An optional function to modify the \code{truth}, \code{response},
#' and \code{prob} values of the model output. Function must take as arguments the
#' original \code{truth}, \code{response}, and \code{prob} output from model
#' predictions and return a named list of \code{truth}, \code{response}, and
#' \code{prob} values, some of which may be modified. For modified values to work
#' nicely with the rest of the package, they must match the expected format of
#' unmodified values (e.g., \code{prob} must be a data.frame with the number of columns
#' being the number of classes and column names being the class labels; the first
#' level of the \code{truth} factor must be the positive class). See examples. Set to
#' \code{FALSE} (default) to use unmodified truth, response, and probability values
#' from model output.
#' @param ... Optional arguments to be passed to the \code{modify_trp} function
#' (e.g., a classification threshold).
#'
#' @return
#' For \code{test = "bayes"} a list of \code{BEST} objects. In any other
Expand Down Expand Up @@ -133,8 +146,101 @@
#' cpi(task = mytask, learner = lrn("regr.ranger"),
#' resampling = rsmp("holdout"),
#' knockoff_fun = seqknockoff::knockoffs_seq)
#' }
#'
#' # Use a function to modify the truth, response, and probability model outputs
#' # Add a probability data.frame to predictions when regression was run on a
#' # 0/1 binary outcome (i.e., treat the response as a probability, make the
#' # response a "classified" outcome)
#'
#' # Data prep:
#' data <- palmerpenguins::penguins
#' data$species <- ifelse(data$species == "Adelie", 1, 0)
#' keep_cols <- c("species", "bill_length_mm", "bill_depth_mm",
#' "flipper_length_mm", "body_mass_g")
#' data <- data[, keep_cols]
#' data <- data[complete.cases(data), ]
#'
#' response_is_prob <- function(truth, response, prob) {
#' prob <- setNames(data.frame(response, 1 - response), nm = c("1", "0"))
#' response <- ifelse(response >= 0.5, yes = 1, no = 0)
#'
#' return(list(truth = truth, response = response, prob = prob))
#' }
#'
#' cpi(task = TaskRegr$new(id = "penguins.regr", backend = data, target = "species"),
#' learner = lrn("regr.ranger"),
#' resampling = rsmp("holdout"),
#' measure = "classif.logloss", test = "t",
#' modify_trp = response_is_prob)
#'
#' # Same as above, but also passing an additional argument to the
#' # modify_trp function that makes a new classification threshold
#' # Same data can be used.
#'
#' response_is_prob_new_thresh <-
#' function(truth, response, prob, classification_thresh) {
#' prob <- setNames(data.frame(response, 1 - response), nm = c("1", "0"))
#' response <- ifelse(response >= classification_thresh, yes = 1, no = 0)
#'
#' return(list(truth = truth, response = response, prob = prob))
#' }
#'
#' cpi(task = TaskRegr$new(id = "penguins.regr", backend = data, target = "species"),
#' learner = lrn("regr.ranger"),
#' resampling = rsmp("holdout"),
#' measure = "classif.logloss", test = "t",
#' modify_trp = response_is_prob_new_thresh,
#' classification_thresh = 0.3)
#'
#' # Setting a new classification threshold on a "classif" learner and rescaling
#' # probability outputs such that predictions lower than the new threshold
#' # are on [0,0.5] and predictions above the new threshold are on [0.5, 1]
#'
#' rescale_prob <- function(truth, response, prob, classification_thresh) {
#'
#' rescale <- function(x, old_max, old_min, new_max, new_min) {
#' return(((new_max - new_min) / (old_max - old_min)) * (x - old_max) + new_max)
#' }
#'
#' classes <- levels(truth)
#' response <- ifelse(prob[, classes[1]] >= classification_thresh,
#' yes = classes[1], no = classes[2])
#' prob[, classes[1]] <- ifelse(prob[, classes[1]] <= classification_thresh,
#' yes = rescale(x = prob[, classes[1]],
#' old_max = classification_thresh,
#' old_min = 0,
#' new_max = 0.5,
#' new_min = 0),
#' no = rescale(x = prob[, classes[1]],
#' old_max = 1,
#' old_min = classification_thresh,
#' new_max = 1,
#' new_min = 0.5))
#'
#' prob[, classes[2]] <- 1 - prob[, classes[1]]
#'
#' return(list(truth = truth, response = response, prob = prob))
#' }
#'
#' # Data prep
#' data = palmerpenguins::penguins
#' data$species = factor(ifelse(data$species == "Adelie", "1", "0"))
#' keep_cols <- c("species", "bill_length_mm", "bill_depth_mm",
#' "flipper_length_mm", "body_mass_g")
#' data <- data[, keep_cols]
#' data <- data[complete.cases(data), ]
#'
#' cpi(task = TaskClassif$new(id = "penguins.binary", backend = data,
#' target = "species", positive = "1"),
#' learner = lrn("classif.ranger", predict_type = "prob"),
#' resampling = rsmp("holdout"),
#' measure = "classif.logloss", test = "t",
#' modify_trp = rescale_prob,
#' classification_thresh = 0.3)
#' }
#'


cpi <- function(task, learner,
resampling = NULL,
test_data = NULL,
Expand All @@ -146,7 +252,9 @@ cpi <- function(task, learner,
x_tilde = NULL,
knockoff_fun = function(x) knockoff::create.second_order(as.matrix(x)),
groups = NULL,
verbose = FALSE) {
verbose = FALSE,
modify_trp = FALSE,
...) {

# Set verbose level (and save old state)
old_logger_treshold <- lgr::get_logger("mlr3")$threshold
Expand Down Expand Up @@ -220,7 +328,7 @@ cpi <- function(task, learner,
fit_full <- fit_learner(learner = learner, task = task, resampling = resampling,
measure = measure, test_data = test_data, verbose = verbose)
pred_full <- predict_learner(fit_full, task, resampling = resampling, test_data = test_data)
err_full <- compute_loss(pred_full, measure)
err_full <- compute_loss(pred_full, measure, modify_trp, ...)

# Generate knockoff data
if (is.null(x_tilde)) {
Expand Down Expand Up @@ -265,7 +373,7 @@ cpi <- function(task, learner,

# Predict with knockoff data
pred_reduced <- predict_learner(fit_full, reduced_task, resampling = resampling, test_data = reduced_test_data)
err_reduced <- compute_loss(pred_reduced, measure)
err_reduced <- compute_loss(pred_reduced, measure, modify_trp, ...)
if (log) {
dif <- log(err_reduced / err_full)
} else {
Expand Down
112 changes: 110 additions & 2 deletions man/cpi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.