Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose new stan args #932

Merged
merged 15 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ CmdStanArgs <- R6::R6Class(
sig_figs = NULL,
opencl_ids = NULL,
model_variables = NULL,
num_threads = NULL) {
num_threads = NULL,
save_cmdstan_config = NULL) {

self$model_name <- model_name
self$stan_code <- stan_code
Expand All @@ -60,6 +61,7 @@ CmdStanArgs <- R6::R6Class(
self$save_latent_dynamics <- save_latent_dynamics
self$using_tempdir <- is.null(output_dir)
self$model_variables <- model_variables
self$save_cmdstan_config <- save_cmdstan_config
if (os_is_wsl()) {
# Want to ensure that any files under WSL are written to a tempdir within
# WSL to avoid IO performance issues
Expand Down Expand Up @@ -87,6 +89,9 @@ CmdStanArgs <- R6::R6Class(
self$opencl_ids <- opencl_ids
self$num_threads = NULL
self$method_args$validate(num_procs = length(self$proc_ids))
if (is.logical(self$save_cmdstan_config)) {
self$save_cmdstan_config <- as.integer(self$save_cmdstan_config)
}
self$validate()
},
validate = function() {
Expand All @@ -111,7 +116,7 @@ CmdStanArgs <- R6::R6Class(
} else if (type == "profile") {
basename <- paste0(basename, "-profile")
}
if (type == "output" && !is.null(self$output_basename)) {
if (type == "output" && !is.null(self$output_basename)) {
basename <- self$output_basename
}
generate_file_names(
Expand Down Expand Up @@ -180,6 +185,9 @@ CmdStanArgs <- R6::R6Class(
if (!is.null(profile_file)) {
args$output <- c(args$output, paste0("profile_file=", wsl_safe_path(profile_file)))
}
if (!is.null(self$save_cmdstan_config)) {
args$output <- c(args$output, paste0("save_cmdstan_config=", self$save_cmdstan_config))
}
if (!is.null(self$opencl_ids)) {
args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2]))
}
Expand Down Expand Up @@ -218,7 +226,8 @@ SampleArgs <- R6::R6Class(
term_buffer = NULL,
window = NULL,
fixed_param = FALSE,
diagnostics = NULL) {
diagnostics = NULL,
save_metric = NULL) {

self$iter_warmup <- iter_warmup
self$iter_sampling <- iter_sampling
Expand All @@ -232,6 +241,7 @@ SampleArgs <- R6::R6Class(
self$inv_metric <- inv_metric
self$fixed_param <- fixed_param
self$diagnostics <- diagnostics
self$save_metric <- save_metric
if (identical(self$diagnostics, "")) {
self$diagnostics <- NULL
}
Expand Down Expand Up @@ -275,6 +285,9 @@ SampleArgs <- R6::R6Class(
if (is.logical(self$save_warmup)) {
self$save_warmup <- as.integer(self$save_warmup)
}
if (is.logical(self$save_metric)) {
self$save_metric <- as.integer(self$save_metric)
}
invisible(self)
},
validate = function(num_procs) {
Expand Down Expand Up @@ -314,7 +327,8 @@ SampleArgs <- R6::R6Class(
.make_arg("adapt_engaged"),
.make_arg("init_buffer"),
.make_arg("term_buffer"),
.make_arg("window")
.make_arg("window"),
.make_arg("save_metric")
)
} else {
new_args <- list(
Expand All @@ -335,7 +349,8 @@ SampleArgs <- R6::R6Class(
.make_arg("adapt_engaged"),
.make_arg("init_buffer"),
.make_arg("term_buffer"),
.make_arg("window")
.make_arg("window"),
.make_arg("save_metric")
)
}
new_args <- do.call(c, new_args)
Expand Down Expand Up @@ -682,6 +697,7 @@ validate_cmdstan_args <- function(self) {
checkmate::assert_flag(self$save_latent_dynamics)
checkmate::assert_integerish(self$refresh, lower = 0, null.ok = TRUE)
checkmate::assert_integerish(self$sig_figs, lower = 1, upper = 18, null.ok = TRUE)
checkmate::assert_integerish(self$save_cmdstan_config, lower = 0, upper = 1, len = 1, null.ok = TRUE)
if (!is.null(self$sig_figs) && cmdstan_version() < "2.25") {
warning("The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!", call. = FALSE)
}
Expand Down Expand Up @@ -799,6 +815,15 @@ validate_sample_args <- function(self, num_procs) {
checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_hmc_diagnostics())
}

checkmate::assert_integerish(self$save_metric,
lower = 0, upper = 1,
len = 1,
null.ok = TRUE)

if (is.null(self$adapt_engaged) || (!self$adapt_engaged && !is.null(self$save_metric))) {
self$save_metric <- 0
}

invisible(TRUE)
}

Expand Down
50 changes: 46 additions & 4 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -898,10 +898,13 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
#' Save output and data files
#'
#' @name fit-method-save_output_files
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files fit-method-save_profile_files
#' fit-method-output_files fit-method-data_file fit-method-latent_dynamics_files fit-method-profile_files
#' save_output_files save_data_file save_latent_dynamics_files save_profile_files
#' output_files data_file latent_dynamics_files profile_files
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files
#' fit-method-save_profile_files fit-method-output_files fit-method-data_file
#' fit-method-latent_dynamics_files fit-method-profile_files
#' fit-method-save_config_files fit-method-save_metric_files save_output_files
#' save_data_file save_latent_dynamics_files save_profile_files
#' save_config_files save_metric_files output_files data_file
#' latent_dynamics_files profile_files config_files metric_files
#'
#' @description All fitted model objects have methods for saving (moving to a
#' specified location) the files created by CmdStanR to hold CmdStan output
Expand Down Expand Up @@ -936,6 +939,14 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
#' `$save_output_files()` except `"-profile-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_metric_files()` everything is the same as for
#' `$save_output_files()` except `"-metric-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_config_files()` everything is the same as for
#' `$save_output_files()` except `"-config-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_data_file()` no `id` is included in the file name because even
#' with multiple MCMC chains the data file is the same.
#'
Expand Down Expand Up @@ -998,6 +1009,26 @@ save_data_file <- function(dir = ".",
}
CmdStanFit$set("public", name = "save_data_file", value = save_data_file)

#' @rdname fit-method-save_output_files
save_config_files <- function(dir = ".",
basename = NULL,
timestamp = TRUE,
random = TRUE) {
self$runset$save_config_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_config_files", value = save_config_files)

#' @rdname fit-method-save_output_files
save_metric_files <- function(dir = ".",
basename = NULL,
timestamp = TRUE,
random = TRUE) {
self$runset$save_metric_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_metric_files", value = save_metric_files)



#' @rdname fit-method-save_output_files
#' @param include_failed (logical) Should CmdStan runs that failed also be
#' included? The default is `FALSE.`
Expand All @@ -1024,6 +1055,17 @@ data_file <- function() {
}
CmdStanFit$set("public", name = "data_file", value = data_file)

#' @rdname fit-method-save_output_files
config_files <- function(include_failed = FALSE) {
self$runset$config_files(include_failed)
}
CmdStanFit$set("public", name = "config_files", value = config_files)

#' @rdname fit-method-save_output_files
metric_files <- function(include_failed = FALSE) {
self$runset$metric_files(include_failed)
}
CmdStanFit$set("public", name = "metric_files", value = metric_files)

#' Report timing of CmdStan runs
#'
Expand Down
36 changes: 25 additions & 11 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,8 @@ sample <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_metric = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
# deprecated
cores = NULL,
num_cores = NULL,
Expand Down Expand Up @@ -1240,7 +1242,8 @@ sample <- function(data = NULL,
term_buffer = term_buffer,
window = window,
fixed_param = fixed_param,
diagnostics = diagnostics
diagnostics = diagnostics,
save_metric = save_metric
)
args <- CmdStanArgs$new(
method_args = sample_args,
Expand All @@ -1260,7 +1263,8 @@ sample <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1357,6 +1361,7 @@ sample_mpi <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
# deprecated
validate_csv = TRUE) {

Expand Down Expand Up @@ -1420,7 +1425,8 @@ sample_mpi <- function(data = NULL,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan_mpi(mpi_cmd, mpi_args)
Expand Down Expand Up @@ -1500,7 +1506,8 @@ optimize <- function(data = NULL,
tol_param = NULL,
history_size = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1541,7 +1548,8 @@ optimize <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1632,7 +1640,8 @@ laplace <- function(data = NULL,
jacobian = TRUE, # different default than for optimize!
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
if (cmdstan_version() < "2.32") {
stop("This method is only available in cmdstan >= 2.32", call. = FALSE)
}
Expand Down Expand Up @@ -1706,7 +1715,8 @@ laplace <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1786,7 +1796,8 @@ variational <- function(data = NULL,
output_samples = NULL,
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1827,7 +1838,8 @@ variational <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1929,7 +1941,8 @@ pathfinder <- function(data = NULL,
psis_resample = NULL,
calculate_lp = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1976,7 +1989,8 @@ pathfinder <- function(data = NULL,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables,
num_threads = num_threads
num_threads = num_threads,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down
Loading
Loading