diff --git a/R/summarize_change.R b/R/summarize_change.R index 6cca7d36e9..e6fd1d5fd7 100644 --- a/R/summarize_change.R +++ b/R/summarize_change.R @@ -70,7 +70,7 @@ a_change_from_baseline <- function(df, custom_stat_functions <- default_and_custom_stats_list$custom_stats # Adding automatically extra parameters to the statistic function (see ?rtables::additional_fun_params) - extra_afun_params <- names(get_additional_analysis_fun_parameters(add_alt_df = FALSE)) + extra_afun_params <- names(list(...)$.additional_fun_parameters) x_stats <- .apply_stat_functions( default_stat_fnc = s_change_from_baseline, custom_stat_fnc_list = custom_stat_functions, @@ -82,7 +82,10 @@ a_change_from_baseline <- function(df, ) # Fill in with formatting defaults if needed - .stats <- get_stats("analyze_vars_numeric", stats_in = .stats) + .stats <- c( + get_stats("analyze_vars_numeric", stats_in = .stats), + names(custom_stat_functions) # Additional stats from custom functions + ) .formats <- get_formats_from_stats(.stats, .formats) .labels <- get_labels_from_stats(.stats, .labels) .indent_mods <- get_indents_from_stats(.stats, .indent_mods) @@ -177,9 +180,10 @@ summarize_change <- function(lyt, extra_args <- c(extra_args, "variables" = list(variables), ...) # Adding all additional information from layout to analysis functions (see ?rtables::additional_fun_params) + extra_args[[".additional_fun_parameters"]] <- get_additional_analysis_fun_parameters(add_alt_df = FALSE) formals(a_change_from_baseline) <- c( formals(a_change_from_baseline), - get_additional_analysis_fun_parameters() + extra_args[[".additional_fun_parameters"]] ) # Main analysis call - Nothing with .* -> these should be dedicated to the analysis function diff --git a/R/utils_default_stats_formats_labels.R b/R/utils_default_stats_formats_labels.R index 606571c2e1..2b47ca88c6 100644 --- a/R/utils_default_stats_formats_labels.R +++ b/R/utils_default_stats_formats_labels.R @@ -126,6 +126,7 @@ get_stats <- function(method_groups = "analyze_vars_numeric", stats_in = NULL, a out <- list(default_stats = NULL, custom_stats = NULL) if (is.list(stats_in)) { is_custom_fnc <- sapply(stats_in, is.function) + checkmate::assert_list(stats_in[is_custom_fnc], types = "function", names = "named") out[["custom_stats"]] <- stats_in[is_custom_fnc] out[["default_stats"]] <- unlist(stats_in[!is_custom_fnc]) } else { @@ -139,7 +140,7 @@ get_stats <- function(method_groups = "analyze_vars_numeric", stats_in = NULL, a .apply_stat_functions <- function(default_stat_fnc, custom_stat_fnc_list, args_list) { # Default checks checkmate::assert_function(default_stat_fnc) - checkmate::assert_list(custom_stat_fnc_list, types = "function", null.ok = TRUE) + checkmate::assert_list(custom_stat_fnc_list, types = "function", null.ok = TRUE, names = "named") checkmate::assert_list(args_list) # Checking custom stats have same formals diff --git a/tests/testthat/_snaps/summarize_change.md b/tests/testthat/_snaps/summarize_change.md index b0e8774d89..9752712464 100644 --- a/tests/testthat/_snaps/summarize_change.md +++ b/tests/testthat/_snaps/summarize_change.md @@ -499,3 +499,20 @@ Median -2.00 Min - Max -2.00 - -2.00 +# summarize_change works with custom statistical functions + + Code + res + Output + all obs + ——————————————————— + V1 + n 3 + my_stat 1.00 + V2 + n 3 + my_stat 0.83 + V3 + n 3 + my_stat 0.67 + diff --git a/tests/testthat/test-summarize_change.R b/tests/testthat/test-summarize_change.R index 2b3b457a81..adf60ffc68 100644 --- a/tests/testthat/test-summarize_change.R +++ b/tests/testthat/test-summarize_change.R @@ -123,7 +123,15 @@ testthat::test_that("summarize_change works with custom statistical functions", summarize_change( "CHG", variables = list(value = "AVAL", baseline_flag = "ABLFLL"), - .stats = c("n", "my_stat" = function(df, ...) mean(df$AVISIT, na.rm = TRUE)) + .stats = c("n", "my_stat" = function(df, ...) { + a <- mean(df$AVAL, na.rm = TRUE) + b <- list(...)$.N_row + a / b + }), + .formats = c("my_stat" = function(x, ...) sprintf("%.2f", x)) ) %>% build_table(dta_test) + + res <- testthat::expect_silent(result) + testthat::expect_snapshot(res) }) diff --git a/vignettes/tern_functions_guide.Rmd b/vignettes/tern_functions_guide.Rmd index 200d1764fc..929c106ff2 100644 --- a/vignettes/tern_functions_guide.Rmd +++ b/vignettes/tern_functions_guide.Rmd @@ -115,17 +115,21 @@ fix_layout %>% print() ``` -Adding a custom statistic: +Adding a custom statistic (and custom format): ```{r} # changing n count format and label and indentation fix_layout %>% - summarize_change("CHG", - variables = list(value = "AVAL", baseline_flag = "ABLFLL"), - .stats = c("n", "mean", "arg" = function(x, ...) mean(x)), - .formats = c(n = function(x, ...) as.character(x * 100)) - ) %>% # Note you need ...!!! - build_table(dta_test) %>% - print() + summarize_change( + "CHG", + variables = list(value = "AVAL", baseline_flag = "ABLFLL"), + .stats = c("n", "my_stat" = function(df, ...) { + a <- mean(df$AVAL, na.rm = TRUE) + b <- list(...)$.N_row # It has access at all `?rtables::additional_fun_params` + a / b + }), + .formats = c("my_stat" = function(x, ...) sprintf("%.2f", x)) + ) %>% + build_table(dta_test) ```