Create inits from fit and draws objects #937

Merged: 34 commits, May 4, 2024
Showing changes from 8 commits

Commits (34)
ae46447
adds feature to make inits from fit and draws objects
SteveBronder Mar 21, 2024
a89f31f
update to import stats::aggregate
SteveBronder Mar 21, 2024
0ebe16d
remove use of \() anonymous function
SteveBronder Mar 21, 2024
099a352
fix pareto_smooth call with explicit argument
avehtari Mar 28, 2024
e4ec68e
Merge branch 'master' into feature/fit-inits
jgabry Mar 28, 2024
6c1092b
update stevebronder from ctb to auth
SteveBronder Apr 3, 2024
7da1f17
remove extra files
SteveBronder Apr 3, 2024
872405f
remove extra file
SteveBronder Apr 3, 2024
d85f3b5
add tests for partial variable matching and allow draws objs with les…
SteveBronder Apr 4, 2024
6cd3a63
update r-setup action version
SteveBronder Apr 4, 2024
a7c3400
update rtools version for github actions
SteveBronder Apr 4, 2024
feee313
update rtools for github workflow
SteveBronder Apr 4, 2024
1197ad4
fix init test
SteveBronder Apr 4, 2024
a8297d3
register s3 methods
jgabry Apr 15, 2024
c9a74c1
Merge branch 'master' into feature/fit-inits
jgabry Apr 15, 2024
67f029a
try to fix r cmd check warning about s3 methods
jgabry Apr 15, 2024
e2d3c41
update docs for init object
SteveBronder Apr 17, 2024
238e9bb
Merge remote-tracking branch 'origin' into feature/fit-inits
SteveBronder Apr 17, 2024
9821640
update docs and failing test
SteveBronder Apr 18, 2024
1b2ab35
kickoff again
SteveBronder Apr 18, 2024
20dfe3d
add context to test inits
SteveBronder Apr 18, 2024
2bbfbbf
update docs for inits
SteveBronder Apr 18, 2024
e867c24
add Sys.sleep(1) to test_inits and remove capture.output
SteveBronder Apr 18, 2024
12974f7
Merge branch 'master' into feature/fit-inits
andrjohns Apr 19, 2024
392941e
testing
SteveBronder Apr 19, 2024
e03a643
Merge remote-tracking branch 'origin/feature/fit-inits' into feature/…
SteveBronder Apr 19, 2024
df7617d
fix merge conflicts
jgabry Apr 22, 2024
f14641e
Merge branch 'master' into feature/fit-inits
jgabry Apr 22, 2024
e711265
remove set.seed() in each init test
SteveBronder May 3, 2024
213cfe7
update
SteveBronder May 3, 2024
25d0a74
update for stringi temporary fix
SteveBronder May 3, 2024
cd4eedd
trying workaround for stringi
SteveBronder May 3, 2024
d984931
...
SteveBronder May 3, 2024
ef9e718
remove things to try to fix stringi cli error
SteveBronder May 3, 2024
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -17,7 +17,7 @@ Authors@R:
email = "[email protected]", comment = c(ORCID = "0000-0003-1878-3253")),
person(given = "Jacob", family = "Socolar", role = "ctb"),
person(given = "Martin", family = "Modrák", role = "ctb"),
person(given = "Steve", family = "Bronder", role = "ctb"))
person(given = "Steve", family = "Bronder", role = "aut"))
Description: A lightweight interface to 'Stan' <https://mc-stan.org>.
The 'CmdStanR' interface is an alternative to 'RStan' that calls the command
line interface for compilation and running algorithms instead of interfacing
@@ -29,7 +29,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org
BugReports: https://github.com/stan-dev/cmdstanr/issues
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.3.0
RoxygenNote: 7.3.1
Roxygen: list(markdown = TRUE, r6 = FALSE)
SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan)
Depends:
1 change: 1 addition & 0 deletions NAMESPACE
@@ -33,3 +33,4 @@ export(write_stan_json)
export(write_stan_tempfile)
import(R6)
importFrom(posterior,as_draws)
importFrom(stats,aggregate)
231 changes: 219 additions & 12 deletions R/args.R
@@ -77,11 +77,12 @@ CmdStanArgs <- R6::R6Class(
}
self$output_dir <- repair_path(self$output_dir)
self$output_basename <- output_basename
if (is.function(init)) {
init <- process_init_function(init, length(self$proc_ids), model_variables)
} else if (is.list(init) && !is.data.frame(init)) {
init <- process_init_list(init, length(self$proc_ids), model_variables)
if (inherits(self$method_args, "PathfinderArgs")) {
num_inits <- self$method_args$num_paths
} else {
num_inits <- length(self$proc_ids)
}
init <- process_init(init, num_inits, model_variables)
self$init <- init
self$opencl_ids <- opencl_ids
self$num_threads = NULL
@@ -691,7 +692,12 @@ validate_cmdstan_args <- function(self) {
assert_file_exists(self$data_file, access = "r")
}
num_procs <- length(self$proc_ids)
validate_init(self$init, num_procs)
if (inherits(self$method_args, "PathfinderArgs")) {
num_inits <- self$method_args$num_paths
} else {
num_inits <- length(self$proc_ids)
}
validate_init(self$init, num_inits)
validate_seed(self$seed, num_procs)
if (!is.null(self$opencl_ids)) {
if (cmdstan_version() < "2.26") {
@@ -1018,17 +1024,63 @@ validate_exe_file <- function(exe_file) {
invisible(TRUE)
}


#' Generic for processing inits
#' @noRd
process_init <- function(...) {
UseMethod("process_init")
}

#' Default method
#' @noRd
process_init.default <- function(x, ...) {
return(x)
}
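The generic above relies on S3 dispatch on the class of its first argument, with a pass-through default for anything that has no dedicated method. A minimal sketch of that pattern, using a hypothetical generic `describe_init()` (not part of cmdstanr) so it runs without the package:

```r
# Hypothetical generic illustrating the dispatch pattern; the name
# describe_init() and its return strings are made up for this sketch.
describe_init <- function(x, ...) {
  UseMethod("describe_init")
}

# Fallback: anything without its own method is passed through untouched.
describe_init.default <- function(x, ...) {
  "passed through unchanged"
}

# Lists dispatch here via their implicit class "list".
describe_init.list <- function(x, ...) {
  "list of lists"
}

# Functions dispatch here via their implicit class "function".
describe_init.function <- function(x, ...) {
  "function returning a list"
}

kinds <- c(
  describe_init(list(list(mu = 0))),
  describe_init(function() list(mu = 0)),
  describe_init(0.5)
)
print(kinds)
```

Adding support for a new input type then only requires defining one more method, which is how the fit-object methods later in this diff plug in.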

#' Write initial values to files if provided as posterior `draws` object
#' @noRd
#' @param init A type that inherits the `posterior::draws` class.
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.draws <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
if (!is.null(model_variables)) {
variable_names = names(model_variables$parameters)
} else {
variable_names = grep("__", posterior::variables(init), value = TRUE, invert = TRUE)
}
draws <- posterior::subset_draws(init, variable = variable_names)
draws <- posterior::resample_draws(draws, ndraws = num_procs,
method ="simple_no_replace")
draws_rvar = posterior::as_draws_rvars(draws)
inits = lapply(1:num_procs, function(draw_iter) {
init_i = lapply(variable_names, function(var_name) {
x = drop(posterior::draws_of(drop(
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter))))
return(x)
})
names(init_i) = variable_names
return(init_i)
})
return(process_init(inits, num_procs, model_variables, warn_partial))
}

#' Write initial values to files if provided as list of lists
#' @noRd
#' @param init List of init lists.
#' @param num_procs Number of CmdStan processes.
#' @param num_procs Number of inits needed.
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init_list <- function(init, num_procs, model_variables = NULL,
process_init.list <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
if (!all(sapply(init, function(x) is.list(x) && !is.data.frame(x)))) {
stop("If 'init' is a list it must be a list of lists.", call. = FALSE)
@@ -1083,10 +1135,11 @@ process_init_list <- function(init, num_procs, model_variables = NULL,
}
init_paths <-
tempfile(
pattern = paste0("init-", seq_along(init), "-"),
pattern = "init-",
tmpdir = cmdstan_tempdir(),
fileext = ".json"
fileext = ""
)
init_paths <- paste0(init_paths, "_", seq_along(init), ".json")
for (i in seq_along(init)) {
write_stan_json(init[[i]], init_paths[i])
}
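The hunk above switches from one tempfile pattern per init to a single shared base path with an index suffix. A minimal sketch of the resulting file names, assuming base R's `tempdir()` as a stand-in for `cmdstan_tempdir()` and a hypothetical count of four inits:

```r
# Assumption: tempdir() replaces cmdstan_tempdir() so this runs without cmdstanr.
num_inits <- 4  # hypothetical number of inits

# One random base name shared by every init file.
base_path <- tempfile(pattern = "init-", tmpdir = tempdir(), fileext = "")

# Append "_<index>.json" so the files differ only in their numeric suffix.
init_paths <- paste0(base_path, "_", seq_len(num_inits), ".json")
print(basename(init_paths))
```

All paths share the same random stem, so the set of files for one run can be identified from any single path.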
@@ -1096,11 +1149,12 @@ process_init_list <- function(init, num_procs, model_variables = NULL,
#' Write initial values to files if provided as function
#' @noRd
#' @param init Function generating a single list of initial values.
#' @param num_procs Number of CmdStan processes.
#' @param num_procs Number of inits needed.
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @return A character vector of file paths.
process_init_function <- function(init, num_procs, model_variables = NULL) {
process_init.function <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
args <- formals(init)
if (is.null(args)) {
fn_test <- init()
@@ -1116,7 +1170,160 @@ process_init_function <- function(init, num_procs, model_variables = NULL) {
if (!is.list(fn_test) || is.data.frame(fn_test)) {
stop("If 'init' is a function it must return a single list.")
}
process_init_list(init_list, num_procs, model_variables)
process_init(init_list, num_procs, model_variables)
}

#' Write initial values to files if provided as a `CmdStanMCMC` class
#' @noRd
#' @param init A `CmdStanMCMC` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
Review comment (Contributor):
The output of model$variables() also includes variables other than parameters (data, transformed parameters, and generated quantities), so "output of model$variables()$parameters" would be more accurate. The same wording is repeated in many of the function docs. Because generated quantities often have much higher dimensionality than the parameters, including them in init is not useful.

Reply (Collaborator Author):
👍
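To illustrate the reviewer's point, here is a small sketch of restricting init candidates to the `parameters` block. The `vars` list is a hypothetical stand-in shaped like the description in the comment; the variable names and the `draw_columns` vector are invented for illustration.

```r
# Hypothetical stand-in for a model's variable listing; names are made up.
vars <- list(
  data = list(N = list(type = "int", dimensions = 0)),
  parameters = list(
    mu = list(type = "real", dimensions = 0),
    tau = list(type = "real", dimensions = 0)
  ),
  transformed_parameters = list(theta = list(type = "real", dimensions = 1)),
  generated_quantities = list(y_rep = list(type = "real", dimensions = 1))
)

# Hypothetical variable names found in an earlier fit.
draw_columns <- c("lp__", "mu", "tau", "theta", "y_rep")

# Only names listed under $parameters are useful as initial values.
init_names <- names(vars$parameters)
init_columns <- intersect(draw_columns, init_names)
print(init_columns)
```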

#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.CmdStanMCMC <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
# Convert from data.table to data.frame
if (all(init$return_codes() == 1)) {
stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
} else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
}
draws_df = init$draws(format = "df")
if (is.null(model_variables)) {
model_variables = list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)])
}
init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs,
method = "simple_no_replace")
init_draws_lst = process_init(init_draws_df,
num_procs = num_procs, model_variables = model_variables)
return(init_draws_lst)
}

#' Performs PSIS resampling on the draws from an approximation method for inits.
#' @noRd
#' @param init A set of draws with `lp__` and `lp_approx__` columns.
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
#' @importFrom stats aggregate
process_init_approx <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
# Convert from data.table to data.frame
if (init$return_codes() == 1) {
stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
} else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
}
draws_df = init$draws(format = "df")
if (is.null(model_variables)) {
model_variables = list(parameters = colnames(draws_df)[3:(length(colnames(draws_df)) - 3)])
}
draws_df$lw = draws_df$lp__ - draws_df$lp_approx__
Review comment (Contributor):
What happens if calculate_lp was FALSE?

Reply (Collaborator Author):
I had to rework this. Now, if calculate_lp is FALSE, the weights are all set to 1. Is that alright?

Reply (Contributor):
Yes
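The fallback agreed on in this thread can be sketched as follows. This is not the code from the PR: `log_weights()` is a hypothetical helper, and it assumes that missing log densities show up either as absent columns or as NA/NaN values.

```r
# Hypothetical helper: log importance weights with a uniform fallback.
# Assumption: when log densities were not computed, lp__ is absent or NA/NaN.
log_weights <- function(draws_df) {
  has_lp <- all(c("lp__", "lp_approx__") %in% names(draws_df)) &&
    !anyNA(draws_df$lp__) && !anyNA(draws_df$lp_approx__)
  if (has_lp) {
    # Usual case: log target density minus log approximation density.
    draws_df$lp__ - draws_df$lp_approx__
  } else {
    # Fallback: log weight 0 for every draw, i.e. all weights equal to 1.
    rep(0, nrow(draws_df))
  }
}

with_lp <- data.frame(lp__ = c(-1, -2), lp_approx__ = c(-1.5, -1), mu = c(0.1, 0.2))
without_lp <- data.frame(lp__ = c(NaN, NaN), lp_approx__ = c(-1.5, -1), mu = c(0.1, 0.2))

print(log_weights(with_lp))
print(exp(log_weights(without_lp)))
```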

# Calculate unique draws based on 'lw' using base R functions
unique_draws = length(unique(draws_df$lw))
if (num_procs > unique_draws) {
if (inherits(init, "CmdStanPathfinder")) {
stop(paste0("Not enough distinct draws (", num_procs, ") in pathfinder fit to create inits. Try running Pathfinder with psis_resample=FALSE"))
jgabry marked this conversation as resolved.
} else {
stop(paste0("Not enough distinct draws (", num_procs, ") to create inits."))
}
}
if (unique_draws < (0.95 * nrow(draws_df))) {
temp_df = stats::aggregate(.draw ~ lw, data = draws_df, FUN = min)
draws_df = posterior::as_draws_df(merge(temp_df, draws_df, by = 'lw'))
draws_df$pareto_weight = exp(draws_df$lw - max(draws_df$lw))
} else {
draws_df$pareto_weight = posterior::pareto_smooth(
exp(draws_df$lw - max(draws_df$lw)), tail = "right", return_k=FALSE)
}
init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs,
weights = draws_df$pareto_weight, method = "simple_no_replace")
Review comment (Contributor):
If I read this correctly, the resampling here is always done, even if Stan has already resampled. If model$pathfinder(..., psis_resample=TRUE) was used and there are enough unique draws, the resampling here should use uniform weights.

Reply (Collaborator Author):
Makes sense!
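A sketch of the suggested behavior using only base R. The helper `pick_rows()` and its `already_resampled` flag are hypothetical names, not the PR's implementation: draws that were already importance-resampled are picked uniformly, and weights are only applied otherwise.

```r
# Hypothetical helper: choose which draws become inits, without replacement.
pick_rows <- function(n_draws, n_inits, weights = NULL, already_resampled = FALSE) {
  if (already_resampled || is.null(weights)) {
    # Draws were already resampled upstream: every draw is equally likely.
    sample.int(n_draws, size = n_inits, replace = FALSE)
  } else {
    # Raw draws: selection probability follows the importance weights.
    sample.int(n_draws, size = n_inits, replace = FALSE, prob = weights)
  }
}

set.seed(1)
w <- runif(10)
uniform_rows <- pick_rows(10, 4, weights = w, already_resampled = TRUE)
weighted_rows <- pick_rows(10, 4, weights = w)
print(uniform_rows)
print(weighted_rows)
```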

init_draws_lst = process_init(init_draws_df,
num_procs = num_procs, model_variables = model_variables, warn_partial)
return(init_draws_lst)
}


#' Write initial values to files if provided as a `CmdStanPathfinder` class
#' @noRd
#' @param init A `CmdStanPathfinder` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.CmdStanPathfinder <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
process_init_approx(init, num_procs, model_variables, warn_partial)
}

#' Write initial values to files if provided as a `CmdStanVB` class
#' @noRd
#' @param init A `CmdStanVB` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.CmdStanVB <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
process_init_approx(init, num_procs, model_variables, warn_partial)
}

#' Write initial values to files if provided as a `CmdStanLaplace` class
#' @noRd
#' @param init A `CmdStanLaplace` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.CmdStanLaplace <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
process_init_approx(init, num_procs, model_variables, warn_partial)
}


#' Write initial values to files if provided as a `CmdStanMLE` class
#' @noRd
#' @param init A `CmdStanMLE` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init.CmdStanMLE <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
# Convert from data.table to data.frame
if (init$return_codes() == 1) {
stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
} else if (!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
}
draws_df = init$draws(format = "df")
if (is.null(model_variables)) {
model_variables = list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)])
}
init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs,
method = "simple")
Review comment (Contributor):
CmdStanMLE should only have one draw. Resampling just copies that one draw several times. There is no performance hit, but it looks strange.

Reply (Collaborator Author):
Yeah, I'll fix this to just rep the draw a bunch of times.
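The proposed fix amounts to replicating the single point estimate once per requested init. A minimal sketch, with a hypothetical `mle_draw` list and init count standing in for the real fit:

```r
# Hypothetical single point estimate, one element per parameter.
mle_draw <- list(mu = 1.2, tau = 0.4)
num_inits <- 4  # hypothetical number of inits requested

# Wrap the draw in a list so rep() repeats the whole draw rather than
# interleaving its elements; the result is a list of identical init lists.
inits <- rep(list(mle_draw), num_inits)
str(inits, max.level = 1)
```

This produces the list-of-lists shape that the list method expects, without going through a resampling call.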

init_draws_lst_lst = process_init(init_draws_df,
num_procs = num_procs, model_variables = model_variables, warn_partial)
return(init_draws_lst_lst)
}

#' Validate initial values
2 changes: 1 addition & 1 deletion R/cmdstanr-package.R
@@ -30,6 +30,6 @@
#' @inherit cmdstan_model examples
#' @import R6
#'
NULL
"_PACKAGE"

if (getRversion() >= "2.15.1") utils::globalVariables(c("self", "private", "super"))
6 changes: 3 additions & 3 deletions R/fit.R
@@ -518,11 +518,11 @@ unconstrain_variables <- function(variables) {
" not provided!", call. = FALSE)
}

# Remove zero-length parameters from model_variables, otherwise process_init_list
# Remove zero-length parameters from model_variables, otherwise process_init
# warns about missing inputs
model_variables$parameters <- model_variables$parameters[nonzero_length_params]

stan_pars <- process_init_list(list(variables), num_procs = 1, model_variables)
stan_pars <- process_init(list(variables), num_procs = 1, model_variables)
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, stan_pars)
}
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)
@@ -594,7 +594,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
# but not in metadata()$variables
nonzero_length_params <- names(model_variables$parameters) %in% model_par_names

# Remove zero-length parameters from model_variables, otherwise process_init_list
# Remove zero-length parameters from model_variables, otherwise process_init
# warns about missing inputs
pars <- names(model_variables$parameters[nonzero_length_params])

2 changes: 1 addition & 1 deletion man/model-method-check_syntax.Rd
2 changes: 1 addition & 1 deletion man/model-method-compile.Rd
2 changes: 1 addition & 1 deletion man/model-method-diagnose.Rd
2 changes: 1 addition & 1 deletion man/model-method-expose_functions.Rd
2 changes: 1 addition & 1 deletion man/model-method-format.Rd