From d729b32d3de732921568bac38ab287360110ae96 Mon Sep 17 00:00:00 2001 From: Ven Popov Date: Thu, 14 Mar 2024 07:39:21 +0100 Subject: [PATCH] add generate_inits function --- DESCRIPTION | 5 +- NAMESPACE | 4 + R/inits.R | 216 ++++++++++++++++++++++++++++++++++++ man/generate_inits.Rd | 103 +++++++++++++++++ tests/testthat/test-inits.R | 34 ++++++ 5 files changed, 360 insertions(+), 2 deletions(-) create mode 100644 R/inits.R create mode 100644 man/generate_inits.Rd create mode 100644 tests/testthat/test-inits.R diff --git a/DESCRIPTION b/DESCRIPTION index 3d55cb48..e1550134 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org BugReports: https://github.com/stan-dev/cmdstanr/issues Encoding: UTF-8 LazyData: true -RoxygenNote: 7.3.0 +RoxygenNote: 7.3.1 Roxygen: list(markdown = TRUE, r6 = FALSE) SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan) Depends: @@ -42,7 +42,8 @@ Imports: processx (>= 3.5.0), R6 (>= 2.4.0), withr (>= 2.5.0), - rlang (>= 0.4.7) + rlang (>= 0.4.7), + glue Suggests: bayesplot, ggplot2, diff --git a/NAMESPACE b/NAMESPACE index c8a6217d..dfa8ff20 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,9 @@ S3method(as_draws,CmdStanMCMC) S3method(as_draws,CmdStanMLE) S3method(as_draws,CmdStanPathfinder) S3method(as_draws,CmdStanVB) +S3method(generate_inits,CmdStanMCMC) +S3method(generate_inits,character) +S3method(generate_inits,draws) export(as_cmdstan_fit) export(as_draws) export(as_mcmc.list) @@ -19,6 +22,7 @@ export(cmdstan_version) export(cmdstanr_example) export(draws_to_csv) export(eng_cmdstan) +export(generate_inits) export(install_cmdstan) export(num_threads) export(print_example_program) diff --git a/R/inits.R b/R/inits.R new file mode 100644 index 00000000..0e734e28 --- /dev/null +++ b/R/inits.R @@ -0,0 +1,216 @@ +#' Generate initial values for Stan Models +#' +#' The `generate_inits()` methods generate a list of lists of initial values for +#' each chain to be used in initializing a model fit with Stan +#' +#' @name generate_inits +#' @param object An object from which to generate initial values +#' @param ... Additional arguments to be passed to the specific methods +#' @details The `generate_inits()` method is generic function to which specific +#' methods for different classes of objects can be written. In the `cmdstanr` +#' package, the following objects are supported: +#' +#' * A `CmdStanMCMC` object, which is the result of sampling from a Stan model with cmdstanr +#' * A vector of file paths to the CSV files containing the draws from a Stan model +#' * A draws object from the `posterior` package +#' +#' For these objects, the function specified in \code{FUN} is applied to the +#' draws to generate the inits. This can be very flexible - any function works +#' as long as it returns a scalar and can be applied to a vector. +#' @return A list of lists of initial values for each chain +#' @export +#' @examples +#' \dontrun{ +#' # inits from a CmdStanMCMC object +#' stanfit <- cmdstanr::cmdstanr_example("logistic") +#' generate_inits(stanfit) +#' generate_inits(stanfit, FUN = mean) +#' generate_inits(stanfit, FUN = quantile, probs = 0.5) +#' generate_inits(stanfit, draws = "last") +#' +#' # inits from a vector of file paths +#' files <- stanfit$output_files() +#' generate_inits(files) +#' +#' # inits from a draws object +#' draws <- stanfit$draws() +#' generate_inits(draws) +#' +#' # warmup and then use the final draws for the inits of a separate sampling stage +#' warmup <- cmdstanr_example("logistic", parallel_chains = 4, iter_sampling = 0, save_warmup = T) +#' inits <- generate_inits(warmup, draws = "last") +#' mod <- cmdstan_model(exe_file = warmup$runset$exe_file()) +#' fit <- mod$sample(warmup$data_file(), +#' parallel_chains = 4, +#' init = inits, +#' iter_warmup = 0, +#' inv_metric = warmup$inv_metric(matrix = FALSE), +#' step_size = warmup$metadata()$step_size_adaptation, +#' adapt_engaged = FALSE) +#' +#' # compare with standard fitting with combined warmup and sampling +#' fit_standard <- mod$sample(warmup$data_file(), +#' parallel_chains = 4) +#' } +generate_inits <- function(object, ...) { + UseMethod("generate_inits") +} + +#' @rdname generate_inits +#' @param FUN A function to apply to the draws to generate the inits. Only used +#' if draws = "all" or "sampling". It should be a function name that takes a +#' vector as input and returns a scalar, such as mean or median. The function +#' will be applied to each parameter's draws to generate the inits. The +#' default is to sample 1 random draw from the posterior draws +#' @param variables A character vector of parameter names for which to generate +#' inits +#' @export +generate_inits.draws <- function(object, variables = NULL, FUN = sample1, ...) { + checkmate::assert_function(FUN) + checkmate::assert_character(variables, null.ok = TRUE) + draws <- posterior::as_draws_array(object) + checkmate::assert_scalar( + FUN(c(draws[,1,1]), ...), + .var.name = paste0('the return value of ', as.character(quote(FUN))) + ) + + # extract parameter information from draws + nchains <- length(dimnames(draws)$chain) + all_pars <- dimnames(draws)$variable + par_dims <- variable_dims(all_pars) + if (is.null(variables)) { + variables <- names(par_dims) + } else { + variables <- intersect(variables, names(par_dims)) + } + + # apply the function to the draws to select the inits + draws <- apply(draws, 2:3, FUN, ...) + + # prepare init list + out <- vector('list', nchains) + + for (i in 1:nchains) { + out[[i]] <- vector('list', length(variables)) + names(out[[i]]) <- variables + + # extract the draw for each parameter and store it in the proper format + for (par in variables) { + pattern <- paste0("^", par, "(\\[|$)") + idx <- grep(pattern, all_pars) + values <- draws[i,idx] + dims <- par_dims[[par]] + if (any(dims > 1)) { + out[[i]][[par]] <- array(values, dims) + } else { + out[[i]][[par]] <- as.numeric(values) + } + } + } + + out +} + +#' @param draws A character string. Either "last", "sampling" or "all". If +#' "last", only the last draw is used. If "sampling", all the draws from the +#' sampling phase are used. If "all", all the draws, including warmup, are +#' used. +#' @export +#' @rdname generate_inits +generate_inits.CmdStanMCMC <- function(object, variables = NULL, FUN = sample1, + draws = "sampling", ...) { + draws <- match.arg(draws, c("last", "sampling", "all")) + pars <- names(object$runset$args$model_variables$parameters) + if (!is.null(variables)) { + pars <- intersect(variables, pars) + } + + # get the draws array + if (draws == "last") { + draws <- read_last_draws(object$output_files()) + dimnames(draws)$variable <- object$metadata()$model_params + } else { + draws <- object$draws(variables = pars, inc_warmup = draws == "all", + format = "draws_array") + } + + generate_inits(draws, FUN = FUN, variables = pars, ...) +} + +#' @export +#' @rdname generate_inits +generate_inits.character <- function(object, variables = NULL, FUN = sample1, + draws = "sampling", ...) { + checkmate::assert_file_exists(object) + checkmate::assert_function(FUN) + draws <- match.arg(draws, c("sampling", "all")) + stanfit <- as_cmdstan_fit(object, format = "draws_array") + generate_inits(stanfit, variables = variables, FUN = FUN, draws = draws, ...) +} + + + +sample1 <- function(x) { + if (length(x) == 1) { + return(x) + } + base::sample(x, size = 1) +} + + +# efficiently extract the last complete draw recorded in a cmdstan csv file +# @param csv_file A character string of the file path +# @param par_names A logical. If TRUE, the parameter names are included. +# @return A draws array +read_last_draws <- function(csv_files, par_names = FALSE) { + checkmate::assert_file_exists(csv_files) + checkmate::assert_logical(par_names) + + out <- vector("list", length(csv_files)) + for (i in seq_along(csv_files)) { + file <- csv_files[i] + tmpfile <- tempfile() + cmd <- glue::glue("tail -12 {file} | grep \"^[0-9-]\" | tail -2 > {tmpfile}") + switch(.Platform$OS.type, + windows = shell(cmd), + unix = system(cmd)) + lines <- readLines(tmpfile) + + if (length(lines) == 0) { + stop("No draws found in ", csv_files[i], call. = FALSE) + } + + lines <- strsplit(lines, ",") + if (length(lines[[2]]) < length(lines[[1]])) { + message("The last draw is incomplete. The last complete draw will be used.") + res <- lines[[1]] + } else { + res <- lines[[2]] + } + out[[i]] <- as.data.frame(t(as.numeric(unlist(res)))) + } + + out <- do.call(posterior::as_draws_array, list(out)) + + if (par_names) { + tmpfile <- tempfile() + cmd <- glue::glue('grep \"^[a-zA-Z]\" {csv_files[i]} > {tmpfile}') + switch(.Platform$OS.type, + windows = shell(cmd), + unix = system(cmd)) + pars <- readLines(tmpfile) + if (length(pars) == 0) { + message("No parameter names found in ", csv_files[i]) + } else if (length(pars) > 1) { + stop("Could not identify the parameter names in ", csv_files[i], call. = FALSE) + } else { + pars <- strsplit(pars, ",")[[1]] + dimnames(out)$variable <- repair_variable_names(pars) + } + } + + # remove diagnostics + out <- out[,,-c(2:7)] + + out +} diff --git a/man/generate_inits.Rd b/man/generate_inits.Rd new file mode 100644 index 00000000..4abbd8c7 --- /dev/null +++ b/man/generate_inits.Rd @@ -0,0 +1,103 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/inits.R +\name{generate_inits} +\alias{generate_inits} +\alias{generate_inits.draws} +\alias{generate_inits.CmdStanMCMC} +\alias{generate_inits.character} +\title{Generate initial values for Stan Models} +\usage{ +generate_inits(object, ...) + +\method{generate_inits}{draws}(object, variables = NULL, FUN = sample1, ...) + +\method{generate_inits}{CmdStanMCMC}( + object, + variables = NULL, + FUN = sample1, + draws = "sampling", + ... +) + +\method{generate_inits}{character}( + object, + variables = NULL, + FUN = sample1, + draws = "sampling", + ... +) +} +\arguments{ +\item{object}{An object from which to generate initial values} + +\item{...}{Additional arguments to be passed to the specific methods} + +\item{variables}{A character vector of parameter names for which to generate +inits} + +\item{FUN}{A function to apply to the draws to generate the inits. Only used +if draws = "all" or "sampling". It should be a function name that takes a +vector as input and returns a scalar, such as mean or median. The function +will be applied to each parameter's draws to generate the inits. The +default is to sample 1 random draw from the posterior draws} + +\item{draws}{A character string. Either "last", "sampling" or "all". If +"last", only the last draw is used. If "sampling", all the draws from the +sampling phase are used. If "all", all the draws, including warmup, are +used.} +} +\value{ +A list of lists of initial values for each chain +} +\description{ +The \code{generate_inits()} methods generate a list of lists of initial values for +each chain to be used in initializing a model fit with Stan +} +\details{ +The \code{generate_inits()} method is generic function to which specific +methods for different classes of objects can be written. In the \code{cmdstanr} +package, the following objects are supported: +\itemize{ +\item A \code{CmdStanMCMC} object, which is the result of sampling from a Stan model with cmdstanr +\item A vector of file paths to the CSV files containing the draws from a Stan model +\item A draws object from the \code{posterior} package +} + +For these objects, the function specified in \code{FUN} is applied to the +draws to generate the inits. This can be very flexible - any function works +as long as it returns a scalar and can be applied to a vector. +} +\examples{ +\dontrun{ +# inits from a CmdStanMCMC object +stanfit <- cmdstanr::cmdstanr_example("logistic") +generate_inits(stanfit) +generate_inits(stanfit, FUN = mean) +generate_inits(stanfit, FUN = quantile, probs = 0.5) +generate_inits(stanfit, draws = "last") + +# inits from a vector of file paths +files <- stanfit$output_files() +generate_inits(files) + +# inits from a draws object +draws <- stanfit$draws() +generate_inits(draws) + +# warmup and then use the final draws for the inits of a separate sampling stage +warmup <- cmdstanr_example("logistic", parallel_chains = 4, iter_sampling = 0, save_warmup = T) +inits <- generate_inits(warmup, draws = "last") +mod <- cmdstan_model(exe_file = warmup$runset$exe_file()) +fit <- mod$sample(warmup$data_file(), + parallel_chains = 4, + init = inits, + iter_warmup = 0, + inv_metric = warmup$inv_metric(matrix = FALSE), + step_size = warmup$metadata()$step_size_adaptation, + adapt_engaged = FALSE) + +# compare with standard fitting with combined warmup and sampling +fit_standard <- mod$sample(warmup$data_file(), + parallel_chains = 4) +} +} diff --git a/tests/testthat/test-inits.R b/tests/testthat/test-inits.R new file mode 100644 index 00000000..91b4ec7a --- /dev/null +++ b/tests/testthat/test-inits.R @@ -0,0 +1,34 @@ +test_that("generate inits works", { + fit_mcmc <- cmdstanr_example("logistic", chains = 2) + inits1 <- generate_inits(fit_mcmc) + inits2 <- generate_inits(fit_mcmc, draws = "last") + inits3 <- generate_inits(fit_mcmc, FUN = median) + inits4 <- generate_inits(fit_mcmc, FUN = quantile, probs = 0.5) + + draws <- fit_mcmc$draws() + inits5 <- generate_inits(draws) + inits6 <- generate_inits(draws, variables = c('beta')) + + files <- fit_mcmc$output_files() + inits7 <- generate_inits(files) + + expect_length(inits1, 2) + expect_length(inits2, 2) + expect_length(inits3, 2) + expect_length(inits4, 2) + expect_length(inits5, 2) + expect_length(inits6, 2) + expect_length(inits7, 2) + + expect_equal(names(inits1[[1]]), c('alpha','beta')) + expect_equal(names(inits2[[1]]), c('alpha','beta')) + expect_equal(names(inits3[[1]]), c('alpha','beta')) + expect_equal(names(inits4[[1]]), c('alpha','beta')) + expect_equal(names(inits5[[1]]), c('lp__','alpha','beta','log_lik')) + expect_equal(names(inits6[[1]]), c('beta')) + expect_equal(names(inits5[[1]]), c('lp__','alpha','beta','log_lik')) + + dims <- variable_dims(fit_mcmc$metadata()$variables) + expect_equal(length(inits5[[1]]$alpha), dims$alpha) + expect_equal(length(inits5[[1]]$beta), dims$beta) +})