Skip to content

Commit

Permalink
Merge pull request #886 from stan-dev/unconstrain-draws
Browse files Browse the repository at this point in the history
`$unconstrain_draws()` returns draws format
  • Loading branch information
andrjohns authored Jan 10, 2024
2 parents 2bec769 + a43a178 commit 1b77cf4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 17 deletions.
22 changes: 14 additions & 8 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
#' @param files (character vector) The paths to the CmdStan CSV files. These can
#' be files generated by running CmdStanR or running CmdStan directly.
#' @param draws A `posterior::draws_*` object.
#' @param format (string) The format of the returned draws. Must be a valid
#' format from the \pkg{posterior} package.
#'
#' @examples
#' \dontrun{
Expand All @@ -562,7 +564,8 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#' [hessian()]
#'
unconstrain_draws <- function(files = NULL, draws = NULL) {
unconstrain_draws <- function(files = NULL, draws = NULL,
format = getOption("cmdstanr_draws_format", "draws_array")) {
if (!is.null(files) || !is.null(draws)) {
if (!is.null(files) && !is.null(draws)) {
stop("Either a list of CSV files or a draws object can be passed, not both",
Expand Down Expand Up @@ -600,13 +603,16 @@ unconstrain_draws <- function(files = NULL, draws = NULL) {
skeleton <- self$variable_skeleton(transformed_parameters = FALSE,
generated_quantities = FALSE)
par_columns <- !(names(draws) %in% c(".chain", ".iteration", ".draw"))
unconstrained <- lapply(split(draws, f = draws$.chain), function(chain) {
lapply(asplit(chain, 1), function(draw) {
par_list <- utils::relist(as.numeric(draw[par_columns]), skeleton)
self$unconstrain_variables(variables = par_list)
})
meta_columns <- !par_columns
unconstrained <- lapply(asplit(draws, 1), function(draw) {
par_list <- utils::relist(as.numeric(draw[par_columns]), skeleton)
self$unconstrain_variables(variables = par_list)
})
unconstrained

unconstrained <- do.call(rbind.data.frame, unconstrained)
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
names(unconstrained) <- repair_variable_names(uncon_names)
maybe_convert_draws_format(cbind.data.frame(unconstrained, draws[,meta_columns]), format)
}
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)

Expand Down Expand Up @@ -1546,7 +1552,7 @@ loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...)
loo = loo_result,
post_draws = function(x, ...) { x$draws(format = "draws_matrix") },
log_lik_i = log_lik_i,
unconstrain_pars = function(x, pars, ...) { do.call(rbind, lapply(x$unconstrain_draws(), function(chain) { do.call(rbind, chain) })) },
unconstrain_pars = function(x, pars, ...) { x$unconstrain_draws(format = "draws_matrix") },
log_prob_upars = function(x, upars, ...) { apply(upars, 1, x$log_prob) },
log_lik_i_upars = log_lik_i_upars,
...
Expand Down
19 changes: 19 additions & 0 deletions inst/include/model_methods.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <Rcpp.h>
#include <stan/model/model_base.hpp>
#include <stan/model/log_prob_grad.hpp>
#include <stan/model/log_prob_propto.hpp>
#include <boost/random/additive_combine.hpp>
Expand Down Expand Up @@ -115,3 +116,21 @@ std::vector<double> constrain_variables(SEXP ext_model_ptr, SEXP base_rng,
ptr->write_array(*rng.get(), upars, params_i, vars, return_trans_pars, return_gen_quants);
return vars;
}

// [[Rcpp::export]]
std::vector<std::string> unconstrained_param_names(SEXP ext_model_ptr, bool return_trans_pars, bool return_gen_quants) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<std::string> rtn_names;
ptr->unconstrained_param_names(rtn_names, return_trans_pars, return_gen_quants);
return rtn_names;
}

// [[Rcpp::export]]
std::vector<std::string> constrained_param_names(SEXP ext_model_ptr,
bool return_trans_pars,
bool return_gen_quants) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
std::vector<std::string> rtn_names;
ptr->constrained_param_names(rtn_names, return_trans_pars, return_gen_quants);
return rtn_names;
}
9 changes: 8 additions & 1 deletion man/fit-method-unconstrain_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,20 @@ test_that("unconstrain_draws returns correct values", {
mod <- cmdstan_model(write_stan_file(model_code),
compile_model_methods = TRUE,
force_recompile = TRUE)
fit <- mod$sample(data = list(N = 0), chains = 1)
fit <- mod$sample(data = list(N = 0), chains = 2)

x_draws <- fit$draws(format = "draws_df")$x

# Unconstrain all internal draws
unconstrained_internal_draws <- fit$unconstrain_draws()[[1]]
unconstrained_internal_draws <- fit$unconstrain_draws()
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_internal_draws))

# Unconstrain external CmdStan CSV files
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())[[1]]
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_csv))

# Unconstrain existing draws object
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())[[1]]
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_draws))

# With a lower-bounded constraint, the parameter draws should be the
Expand All @@ -253,19 +253,19 @@ test_that("unconstrain_draws returns correct values", {
mod <- cmdstan_model(write_stan_file(model_code),
compile_model_methods = TRUE,
force_recompile = TRUE)
fit <- mod$sample(data = list(N = 0), chains = 1)
fit <- mod$sample(data = list(N = 0), chains = 2)

x_draws <- fit$draws(format = "draws_df")$x

unconstrained_internal_draws <- fit$unconstrain_draws()[[1]]
unconstrained_internal_draws <- fit$unconstrain_draws()
expect_equal(as.numeric(x_draws), exp(as.numeric(unconstrained_internal_draws)))

# Unconstrain external CmdStan CSV files
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())[[1]]
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())
expect_equal(as.numeric(x_draws), exp(as.numeric(unconstrained_csv)))

# Unconstrain existing draws object
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())[[1]]
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())
expect_equal(as.numeric(x_draws), exp(as.numeric(unconstrained_draws)))
})

Expand Down

0 comments on commit 1b77cf4

Please sign in to comment.