From a45d4f7d686aa6b57ce25f342a71eea79507f01c Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sat, 15 Jun 2024 18:22:33 +0300 Subject: [PATCH] Fix factor arg handling (#999) --- R/data.R | 21 +++++++++++---------- tests/testthat/test-data.R | 3 +++ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/R/data.R b/R/data.R index ce5b9cf6..f483ba4c 100644 --- a/R/data.R +++ b/R/data.R @@ -186,18 +186,19 @@ process_data <- function(data, model_variables = NULL) { # generating a decimal point in write_stan_json if (data_variables[[var_name]]$type == "int" && !is.integer(data[[var_name]])) { - if (!isTRUE(all(is_wholenumber(data[[var_name]])))) { - # Don't warn for NULL/NA, as different warnings are used for those - if (!isTRUE(any(is.na(data[[var_name]])))) { - warning("A non-integer value was supplied for '", var_name, "'!", - " It will be truncated to an integer.", call. = FALSE) + if (!is.factor(data[[var_name]])) { + if (!isTRUE(all(is_wholenumber(data[[var_name]])))) { + # Don't warn for NULL/NA, as different warnings are used for those + if (!isTRUE(any(is.na(data[[var_name]])))) { + warning("A non-integer value was supplied for '", var_name, "'!", + " It will be truncated to an integer.", call. = FALSE) + } + } else { + # Round before setting mode to integer to avoid floating point errors + data[[var_name]] <- round(data[[var_name]]) } - mode(data[[var_name]]) <- "integer" - } else { - # Round before setting mode to integer to avoid floating point errors - data[[var_name]] <- round(data[[var_name]]) - mode(data[[var_name]]) <- "integer" } + mode(data[[var_name]]) <- "integer" } } } diff --git a/tests/testthat/test-data.R b/tests/testthat/test-data.R index 61ca79e3..c7283d9b 100644 --- a/tests/testthat/test-data.R +++ b/tests/testthat/test-data.R @@ -383,6 +383,9 @@ test_that("process_data warns on int coercion", { expect_no_warning( process_data(list(a = c(1, 2, 3)), model_variables = mod$variables()) ) + expect_no_warning( + process_data(list(a = factor(c("a", "b", "c"))), model_variables = mod$variables()) + ) }) test_that("Floating-point differences do not cause truncation towards 0", {