diff --git a/DESCRIPTION b/DESCRIPTION index 34dc7d4e..4464e5c5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: cmdstanr Title: R Interface to 'CmdStan' -Version: 0.8.1 +Version: 0.8.1.9000 Date: 2024-06-06 Authors@R: c(person(given = "Jonah", family = "Gabry", role = "aut", diff --git a/R/data.R b/R/data.R index a1e4ab7c..ce5b9cf6 100644 --- a/R/data.R +++ b/R/data.R @@ -186,7 +186,18 @@ process_data <- function(data, model_variables = NULL) { # generating a decimal point in write_stan_json if (data_variables[[var_name]]$type == "int" && !is.integer(data[[var_name]])) { - mode(data[[var_name]]) <- "integer" + if (!isTRUE(all(is_wholenumber(data[[var_name]])))) { + # Don't warn for NULL/NA, as different warnings are used for those + if (!isTRUE(any(is.na(data[[var_name]])))) { + warning("A non-integer value was supplied for '", var_name, "'!", + " It will be truncated to an integer.", call. = FALSE) + } + mode(data[[var_name]]) <- "integer" + } else { + # Round before setting mode to integer to avoid floating point errors + data[[var_name]] <- round(data[[var_name]]) + mode(data[[var_name]]) <- "integer" + } } } } diff --git a/R/utils.R b/R/utils.R index ef36d0c7..ecac323c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -36,6 +36,10 @@ matching_variables <- function(variable_filters, variables) { ) } +is_wholenumber <- function(x, tol = sqrt(.Machine$double.eps)) { + abs(x - round(x)) < tol +} + # checks for OS and hardware ---------------------------------------------- os_is_windows <- function() { diff --git a/tests/testthat/test-data.R b/tests/testthat/test-data.R index 01fa7291..61ca79e3 100644 --- a/tests/testthat/test-data.R +++ b/tests/testthat/test-data.R @@ -355,3 +355,50 @@ test_that("process_data() corrrectly casts integers and floating point numbers", fixed = TRUE ) }) + +test_that("process_data warns on int coercion", { + stan_file <- write_stan_file(" + data { + int a; + real b; + } + ") + mod <- cmdstan_model(stan_file, compile = FALSE) + expect_warning( + process_data(list(a = 1.1, b = 2.1), model_variables = mod$variables()), + "A non-integer value was supplied for 'a'! It will be truncated to an integer." + ) + + stan_file <- write_stan_file(" + data { + array[3] int a; + } + ") + mod <- cmdstan_model(stan_file, compile = FALSE) + expect_warning( + process_data(list(a = c(1, 2.1, 3)), model_variables = mod$variables()), + "A non-integer value was supplied for 'a'! It will be truncated to an integer." + ) + + expect_no_warning( + process_data(list(a = c(1, 2, 3)), model_variables = mod$variables()) + ) +}) + +test_that("Floating-point differences do not cause truncation towards 0", { + stan_file <- write_stan_file(" + data { + int a; + real b; + } + ") + mod <- cmdstan_model(stan_file, compile = FALSE) + a <- 10*(3-2.7) + expect_false(is.integer(a)) + test_file <- process_data(list(a = a, b = 2.0), model_variables = mod$variables()) + expect_match( + " \"a\": 3,", + readLines(test_file)[2], + fixed = TRUE + ) +})