Skip to content

Commit

Permalink
Update doc and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed May 24, 2024
1 parent 515bc18 commit bec155d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 31 deletions.
70 changes: 42 additions & 28 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1588,8 +1588,20 @@ CmdStanMCMC$set("public", name = "loo", value = loo)
#' approximation using bridge sampling. This method requires the
#' \pkg{bridgesampling} package.
#'
#' @param ... Arguments (e.g., `repetitions`, `cores`, `maxiter`, etc.)
#' passed to [bridgesampling::bridge_sampler()].
#' @param method (character) The method to use for bridge sampling. Options are
#' `"normal"` (default) or `"warp3"`.
#' @param repetitions (integer) The number of repetitions for bridge sampling.
#' @param cores (integer) The number of cores to be used by the \pkg{bridgesampling}
#' package. Defaults to `1`. See the \pkg{bridgesampling} package documentation
#' for more details.
#' @param use_neff (logical) Whether to use the effective sample size (ESS) in
#' the optimal bridge function. Default is TRUE. If FALSE, the number of samples
#' is used instead.
#' @param maxiter (integer) The maximum number of iterations for bridge sampling.
#' @param silent (logical) Whether to suppress output from the bridge sampling
#' algorithm. Defaults to `FALSE`.
#' @param verbose (logical) Whether to print verbose output. Defaults to `FALSE`.
#' @param ... Other arguments passed to the bridge sampling function.
#'
#' @return The object returned by the bridge sampling function.
#'
Expand All @@ -1604,45 +1616,47 @@ CmdStanMCMC$set("public", name = "loo", value = loo)
#' print(bridge_result)
#' }
#'
bridge_sampler <- function(...) {
bridge_sampler <- function(method = "normal", repetitions = 1, cores = 1,
use_neff = TRUE, maxiter = 1000, silent = FALSE,
verbose = FALSE, ...) {
require_suggested_package("bridgesampling")
self$init_model_methods()

upars_cr <- self$unconstrain_draws(format = "draws_array")
nr_cr <- posterior::niterations(upars_cr)
half_iter <- nr_cr %/% 2
upars <- self$unconstrain_draws(format = "draws_array")
nr <- posterior::niterations(upars)
half_iter <- nr %/% 2

samples_4_iter_cr <- posterior::subset_draws(upars_cr, iteration = seq.int(from=half_iter + 1, nr_cr))
neff_cr <- median(posterior::summarise_draws(samples_4_iter_cr,"ess_median")$ess_median)
samples_4_iter <- posterior::subset_draws(upars, iteration = seq.int(from=half_iter + 1, nr))
par_ess <- posterior::summarise_draws(samples_4_iter, "ess_median")$ess_median
neff <- posterior::quantile2(par_ess, 0.5)

parameters_cr <- attributes(upars_cr)$dimnames$variable
transTypes_cr <- rep("unbounded", length(parameters_cr))
names(transTypes_cr) <- parameters_cr
lb_cr <- rep(-Inf, length(parameters_cr))
ub_cr <- rep(Inf, length(parameters_cr))
names(lb_cr) <- names(ub_cr) <- parameters_cr
parameters <- attributes(upars)$dimnames$variable
transTypes <- rep("unbounded", length(parameters))
names(transTypes) <- parameters
lb <- rep(-Inf, length(parameters))
ub <- rep(Inf, length(parameters))
names(lb) <- names(ub) <- parameters

samples_4_fit_cr <- posterior::subset_draws(upars_cr, iteration = seq_len(half_iter))
samples_4_fit_cr <- posterior::as_draws_matrix(samples_4_fit_cr)
samples_4_iter_cr <- posterior::as_draws_matrix(samples_4_iter_cr)

colnames(samples_4_fit_cr) <- paste0("trans_", parameters_cr)
colnames(samples_4_iter_cr) <- paste0("trans_", parameters_cr)
samples_4_fit <- posterior::subset_draws(upars, iteration = seq_len(half_iter))
samples_4_fit <- posterior::as_draws_matrix(samples_4_fit)
samples_4_iter <- posterior::as_draws_matrix(samples_4_iter)

colnames(samples_4_fit) <- paste0("trans_", parameters)
colnames(samples_4_iter) <- paste0("trans_", parameters)

do.call(rlang::ns_env("bridgesampling")[[paste0(".bridge.sampler.", method)]],
args = list(samples_4_fit = samples_4_fit_cr,
samples_4_iter = samples_4_iter_cr,
neff = neff_cr,
args = list(samples_4_fit = samples_4_fit,
samples_4_iter = samples_4_iter,
neff = neff,
log_posterior = function(s.row, data) { data$fitobj$log_prob(s.row) },
data = list(fitobj = self),
lb = lb_cr, ub = ub_cr,
param_types = rep("real", ncol(samples_4_fit_cr)),
transTypes = transTypes_cr,
lb = lb, ub = ub,
param_types = rep("real", ncol(samples_4_fit)),
transTypes = transTypes,
repetitions = repetitions, cores = cores,
maxiter = maxiter, silent = silent,
verbose = verbose,
r0 = 0.5, tol1 = 1e-10, tol2 = 1e-4))
verbose = verbose, r0 = 0.5, tol1 = 1e-10, tol2 = 1e-4,
...))
}
CmdStanMCMC$set("public", name = "bridge_sampler", value = bridge_sampler)

Expand Down
34 changes: 31 additions & 3 deletions man/fit-method-bridge_sampler.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit bec155d

Please sign in to comment.