diff --git a/R/fit.R b/R/fit.R index 30e7310a..0f76b4d0 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1616,6 +1616,9 @@ CmdStanMCMC$set("public", name = "loo", value = loo) #' } #' sampler_diagnostics <- function(inc_warmup = FALSE, format = getOption("cmdstanr_draws_format", "draws_array")) { + if (isTRUE(private$metadata_$algorithm == "fixed_param")) { + stop("There are no sampler diagnostics when fixed_param = TRUE.", call. = FALSE) + } if (is.null(private$sampler_diagnostics_) && !length(self$output_files(include_failed = FALSE))) { stop("No chains finished successfully. Unable to retrieve the sampler diagnostics.", call. = FALSE) diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index 122dc503..f7e227d4 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -19,6 +19,12 @@ fit_mcmc_3 <- testing_fit("logistic", method = "sample", iter_sampling = 0, save_warmup = 1, refresh = 0, metric = "dense_e") +fit_mcmc_fixed_param <- testing_fit("logistic", method = "sample", + seed = 1234, chains = 1, + iter_warmup = 100, + iter_sampling = 0, + save_warmup = 1, + refresh = 0, fixed_param = TRUE) PARAM_NAMES <- c("alpha", "beta[1]", "beta[2]", "beta[3]") test_that("draws() stops for unkown variables", { @@ -399,3 +405,10 @@ test_that("metadata()$time has chains rowss", { expect_equal(nrow(fit_mcmc_2$metadata()$time), fit_mcmc_2$num_chains()) expect_equal(nrow(fit_mcmc_3$metadata()$time), fit_mcmc_3$num_chains()) }) + +test_that("sampler_diagnostics() throws informative error when fixed_param=TRUE", { + expect_error( + fit_mcmc_fixed_param$sampler_diagnostics(), + "There are no sampler diagnostics when fixed_param = TRUE" + ) +})