Skip to content

Commit

Permalink
Warn on input data int conversion (#994)
Browse files Browse the repository at this point in the history
* Warn on input data int conversion

* Better robustness

* Add handling and tests for floating-pointness

* Bumo dev version

* Handling for NULL/NA

* Update data.R
  • Loading branch information
andrjohns authored Jun 8, 2024
1 parent 02259ef commit 356fa04
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cmdstanr
Title: R Interface to 'CmdStan'
Version: 0.8.1
Version: 0.8.1.9000
Date: 2024-06-06
Authors@R:
c(person(given = "Jonah", family = "Gabry", role = "aut",
Expand Down
13 changes: 12 additions & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,18 @@ process_data <- function(data, model_variables = NULL) {
# generating a decimal point in write_stan_json
if (data_variables[[var_name]]$type == "int"
&& !is.integer(data[[var_name]])) {
mode(data[[var_name]]) <- "integer"
if (!isTRUE(all(is_wholenumber(data[[var_name]])))) {
# Don't warn for NULL/NA, as different warnings are used for those
if (!isTRUE(any(is.na(data[[var_name]])))) {
warning("A non-integer value was supplied for '", var_name, "'!",
" It will be truncated to an integer.", call. = FALSE)
}
mode(data[[var_name]]) <- "integer"
} else {
# Round before setting mode to integer to avoid floating point errors
data[[var_name]] <- round(data[[var_name]])
mode(data[[var_name]]) <- "integer"
}
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ matching_variables <- function(variable_filters, variables) {
)
}

is_wholenumber <- function(x, tol = sqrt(.Machine$double.eps)) {
abs(x - round(x)) < tol
}

# checks for OS and hardware ----------------------------------------------

os_is_windows <- function() {
Expand Down
47 changes: 47 additions & 0 deletions tests/testthat/test-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,50 @@ test_that("process_data() corrrectly casts integers and floating point numbers",
fixed = TRUE
)
})

test_that("process_data warns on int coercion", {
stan_file <- write_stan_file("
data {
int a;
real b;
}
")
mod <- cmdstan_model(stan_file, compile = FALSE)
expect_warning(
process_data(list(a = 1.1, b = 2.1), model_variables = mod$variables()),
"A non-integer value was supplied for 'a'! It will be truncated to an integer."
)

stan_file <- write_stan_file("
data {
array[3] int a;
}
")
mod <- cmdstan_model(stan_file, compile = FALSE)
expect_warning(
process_data(list(a = c(1, 2.1, 3)), model_variables = mod$variables()),
"A non-integer value was supplied for 'a'! It will be truncated to an integer."
)

expect_no_warning(
process_data(list(a = c(1, 2, 3)), model_variables = mod$variables())
)
})

test_that("Floating-point differences do not cause truncation towards 0", {
stan_file <- write_stan_file("
data {
int a;
real b;
}
")
mod <- cmdstan_model(stan_file, compile = FALSE)
a <- 10*(3-2.7)
expect_false(is.integer(a))
test_file <- process_data(list(a = a, b = 2.0), model_variables = mod$variables())
expect_match(
" \"a\": 3,",
readLines(test_file)[2],
fixed = TRUE
)
})

0 comments on commit 356fa04

Please sign in to comment.