Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

requiring bake helper function when skip = TRUE #3

Open
tedmoorman opened this issue Nov 19, 2021 · 1 comment
Open

requiring bake helper function when skip = TRUE #3

tedmoorman opened this issue Nov 19, 2021 · 1 comment

Comments

@tedmoorman
Copy link

I'm receiving an error saying that a bake helper function is required, even though I set skip = TRUE to indicate that the step should not be applied when baking.

step_custom_transformation(my_variable, prep_function = my_function, skip = TRUE)
Error in step_custom_transformation(., my_variable, prep_function = my_function,  : 
  No bake helper function ('bake_function') has been specified.

I guess I can work around it by supplying any function, but it might be better if the requirement were changed.

@tedmoorman
Copy link
Author

I'm trying to modify the code. I think I've got the following correct.

d_vars = selected_vars,
        skip = skip,
        id = id
      )
    )
    }

# constructor function.
#
# Low-level constructor: forwards every field unchanged to step(), tagging the
# result with the subclass "custom_transformation".
#
# Arguments:
#   terms          selectors identifying the variables the step works on.
#   role           role label stored on the step.
#   trained        logical flag; prep() sets it to TRUE.
#   prep_function  prep helper function, or NULL.
#   prep_options   list of extra arguments for the prep helper, or NULL.
#   prep_output    output of the prep helper, or NULL before training.
#   bake_function  bake helper function, or NULL.
#   bake_options   list of extra arguments for the bake helper, or NULL.
#   bake_how       "bind_cols" or "replace".
#   selected_vars  character vector of resolved variable names, or NULL.
#   skip           logical flag stored on the step.
#   id             identifier of the step.
#
# The defaults of 'prep_output' and 'id' used to refer to themselves
# ('prep_output = prep_output', 'id = id'), which makes R fail with
# "promise already under evaluation" whenever the argument is omitted. They
# now default to NULL and to a freshly generated id respectively.
#' @importFrom recipes step
step_custom_transformation_new <-
  function(terms = NULL,
           role = "predictor",
           trained = FALSE,
           prep_function = NULL,
           prep_options = NULL,
           prep_output = NULL,
           bake_function = NULL,
           bake_options = NULL,
           bake_how = "bind_cols",
           selected_vars = NULL,
           skip = FALSE,
           id = recipes::rand_id("custom_transformation")) {
    step(
      subclass = "custom_transformation",
      terms = terms,
      role = role,
      trained = trained,
      prep_function = prep_function,
      prep_options = prep_options,
      prep_output = prep_output,
      bake_function = bake_function,
      bake_options = bake_options,
      bake_how = bake_how,
      selected_vars = selected_vars,
      skip = skip,
      id = id
    )
  }

# prepare step (train step/estimate (any) parameters from initial data set).
#
# Resolves the step's selectors to column names, runs the optional prep helper
# on those columns of the training data and returns a new step object that is
# marked as trained and carries the helper's output in 'prep_output'.
#
# Arguments:
#   x         the untrained step object.
#   training  the training data set.
#   info      variable information passed on to terms_select().
#   ...       not used.
#
# Raises an error if the prep helper function itself fails.
#' @export
#' @importFrom recipes prep terms_select
#' @importFrom purrr invoke
prep.step_custom_transformation <- function(x, training, info = NULL, ...) {

  # selected vars as character vector.
  selected_vars <- terms_select(x$terms, info = info)

  # without a prep helper function there is nothing to estimate.
  prep_output <- NULL

  if (!is.null(x$prep_function)) {

    # mandatory argument 'x': the selected columns of the training data.
    # 'drop = FALSE' keeps a data frame even when a single column is selected
    # from a plain data.frame.
    args <- list(x = training[, selected_vars, drop = FALSE])

    # add additional arguments (if any).
    if (!is.null(x$prep_options)) {
      args <- append(args, x$prep_options)
    }

    # compute intermediate output from prep helper function. do.call() is the
    # base R equivalent of purrr::invoke(), which purrr has deprecated.
    prep_output <- tryCatch(
      do.call(x$prep_function, args),
      error = function(e) {
        stop("An error occured in the call to the prep helper function",
             " ('prep_function'). See details below: \n",
             e)
      }
    )

  }

  step_custom_transformation_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    prep_function = x$prep_function,
    prep_options = x$prep_options,
    prep_output = prep_output,
    bake_function = x$bake_function,
    bake_options = x$bake_options,
    bake_how = x$bake_how,
    selected_vars = selected_vars,
    skip = x$skip,
    id = x$id
  )

}

Here is where I start getting hung up.

# bake step (/apply transformation to new data set).
#
# Calls the bake helper function on the new data and merges its output back
# into the data, either by binding the new columns ('bake_how = "bind_cols"')
# or by replacing the selected variables ('bake_how = "replace"').
#
# Arguments:
#   object    the trained step object.
#   new_data  the data set to transform.
#   ...       not used.
#
# Returns the transformed data as a tibble. When the step has no bake helper
# function, the data is returned untransformed (as a tibble).
#
# Raises an error if the bake helper fails, if its output cannot be converted
# to a tibble, if the row counts differ, or if 'bake_how' is not recognised.
#' @export
#' @importFrom dplyr bind_cols select
#' @importFrom purrr invoke
#' @importFrom recipes bake
#' @importFrom tibble as_tibble
bake.step_custom_transformation <- function(object, new_data, ...) {

  # a step without a bake helper has nothing to apply. This matters for
  # prep-only steps created with skip = TRUE: skipping only affects baking of
  # new data, the step is still baked on the training set while the recipe is
  # prepped, so this method must not fail then.
  if (is.null(object$bake_function)) {
    return(as_tibble(new_data))
  }

  #### prepare arguments before calling the bake helper function.

  # add mandatory argument for 'x' - set to new data set.
  args <- list(x = new_data)

  # add intermediate output from the prep helper function.
  if (!is.null(object$prep_output)) {
    args <- append(args, list(prep_output = object$prep_output))
  }

  # add additional arguments (if any).
  if (!is.null(object$bake_options)) {
    args <- append(args, object$bake_options)
  }

  # invoke the bake helper function. do.call() is the base R equivalent of
  # purrr::invoke(), which purrr has deprecated.
  bake_function_output <- tryCatch(
    do.call(object$bake_function, args),
    error = function(e) {
      stop("An error occured in the call to the bake helper function",
           " ('bake_function'). See details below: \n",
           e)
    }
  )

  # convert to tibble.
  bake_function_output <- tryCatch(
    as_tibble(bake_function_output),
    error = function(e) {
      stop("Unable to convert output from bake helper function to tibble.")
    }
  )

  # check dimensions of output from bake helper function.
  if (nrow(bake_function_output) != nrow(new_data)) {
    stop("There was a mismatch between the number of rows ",
         "in the output from the bake helper function (",
         nrow(bake_function_output),
         ") and the number of rows of the input data (",
         nrow(new_data), ").")
  }

  # append transformed variables to new data set.
  switch(
    object$bake_how,

    # append output to input by binding columns.
    "bind_cols" = bind_cols(as_tibble(new_data), bake_function_output),

    # replace selected variables with output: drop the selected variables
    # first, then bind the output columns.
    "replace" = bind_cols(
      select(as_tibble(new_data), -c(object$selected_vars)),
      bake_function_output
    ),

    # anything else used to fall through and silently return NULL.
    stop("Unknown value for 'bake_how': '", object$bake_how,
         "'. Use \"bind_cols\" or \"replace\".")
  )
}

#' @export
print.step_custom_transformation <-
  function(x, width = max(20, options()$width - 30), ...) {

    # the header only announces that the variables get dropped when the step
    # is not skipped and replaces its inputs.
    replaces_inputs <- x$skip == FALSE && x$bake_how == "replace"
    header_tail <- ifelse(replaces_inputs,
                          "\n and will be dropped afterwards:\n ",
                          ":\n ")

    # header first, then the formatted selectors of the step.
    cat("The following variables are used for computing",
        " transformations", header_tail, sep = "")
    cat(format_selectors(x$terms, wdth = width))

    # hand the step back invisibly, as print methods conventionally do.
    invisible(x)

  }

#' @rdname step_custom_transformation
#' @param x A `step_custom_transformation` object.
#' @export
#' @importFrom generics tidy
#' @importFrom tibble tibble
#' @importFrom recipes sel2char
tidy.step_custom_transformation <- function(x, ...) {

  # one row per selector of the step, rendered as text.
  term_labels <- sel2char(x$terms)
  summary_tbl <- tibble(terms = term_labels)

  # attach the id of the step to every row.
  summary_tbl[["id"]] <- x$id

  summary_tbl

}

I'm not sure how to change the code so that it no longer raises errors about the bake helper function (or anything else bake-related) when no bake helper is supplied.

Here is the example I'm using for testing.

library(dplyr)
library(purrr)
library(tibble)
library(recipes)
library(generics)

# split 'mtcars' in two: the first 16 rows form the initial data set, all
# remaining rows the new data set.
cars_initial <- mtcars[seq_len(16), ]
cars_new <- mtcars[-seq_len(16), ]

# prep helper function: for each variable in 'x' (any number of numeric
# variables) compute its mean and its standard deviation. Returns one
# list(mean = ..., sd = ...) entry per variable.
compute_means_sd <- function(x) {

  # summary statistics of a single variable.
  summarise_variable <- function(values) {
    list(mean = mean(values), sd = sd(values))
  }

  map(.x = x, .f = summarise_variable)

}

# bake helper function: shifts each numeric variable so that it is centered
# around 'alpha' and rescales it to a standard deviation of 'beta', using the
# means and standard deviations stored in 'prep_output'.
center_scale <- function(x, prep_output, alpha, beta) {

  # transformation of a single variable given its stored mean and sd.
  rescale <- function(values, stats) {
    alpha + (values - stats$mean) * beta / stats$sd
  }

  # only the variables that have an entry in 'prep_output' are transformed;
  # they are paired with their statistics by position.
  map2(.x = select(x, names(prep_output)),
       .y = prep_output,
       .f = rescale)

}

# create recipe: a single custom step on 'mpg' and 'disp' that only has a
# prep helper function and is flagged with skip = TRUE.
rec <- step_custom_transformation(
  recipe(cars_initial),
  mpg, disp,
  prep_function = compute_means_sd,
  skip = TRUE
)

# prep recipe (train it on the initial data set).
rec_prep <- prep(rec)

# bake recipe (apply it to the new data set) and show the result.
rec_baked <- bake(rec_prep, cars_new)
rec_baked

# inspect output: the untrained recipe, the baked data and the tidy summaries
# of the recipe and of its first step, before and after prepping.
rec
rec_baked
tidy(rec)
tidy(rec, 1)
tidy(rec_prep)
tidy(rec_prep, 1)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant