From fda57df775fd5f61222253270bb8d7c5d3f235b9 Mon Sep 17 00:00:00 2001 From: Mike Tokic Date: Wed, 24 Apr 2024 09:25:15 -0700 Subject: [PATCH] adjust how many sub models are built for each date type --- R/prep_data.R | 9 +++++++++ tests/testthat/test-multistep_horizon.R | 21 +++++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/R/prep_data.R b/R/prep_data.R index 459d24d6..3a0f6e17 100644 --- a/R/prep_data.R +++ b/R/prep_data.R @@ -947,7 +947,16 @@ get_lag_periods <- function(lag_periods, "day" = c(7, 14, 21, 28, 60, 90, 180, 365) ) + # change multistep horizons to run based on date type if (multistep_horizon) { + if (date_type == "day") { + oplist <- c(28, 90, 180) + } else if (date_type == "week") { + oplist <- c(4, 12, 24) + } else if (date_type == "month") { + c(1, 2, 3, 6, 12) + } + if (max(oplist) < forecast_horizon) { lag_periods <- c(oplist, forecast_horizon) } else { diff --git a/tests/testthat/test-multistep_horizon.R b/tests/testthat/test-multistep_horizon.R index 36943191..38780ddc 100644 --- a/tests/testthat/test-multistep_horizon.R +++ b/tests/testthat/test-multistep_horizon.R @@ -26,7 +26,8 @@ test_that("multistep_horizon yearly data", { back_test_scenarios = 4, models_to_run = "xgboost", run_ensemble_models = FALSE, - num_hyperparameters = 1 + num_hyperparameters = 1, + pca = TRUE ) # train models @@ -69,7 +70,8 @@ test_that("multistep_horizon quarterly data", { back_test_scenarios = 6, models_to_run = "mars", run_ensemble_models = FALSE, - num_hyperparameters = 1 + num_hyperparameters = 1, + pca = TRUE ) # train models @@ -118,7 +120,8 @@ test_that("multistep_horizon monthly data", { back_test_scenarios = 2, models_to_run = "cubist", run_ensemble_models = FALSE, - num_hyperparameters = 1 + num_hyperparameters = 1, + pca = TRUE ) # train models @@ -164,7 +167,8 @@ test_that("multistep_horizon weekly data", { back_test_scenarios = 4, models_to_run = "glmnet", run_ensemble_models = FALSE, - num_hyperparameters = 1 + num_hyperparameters = 1, + pca = TRUE ) # train models @@ -177,7 +181,7 @@ test_that("multistep_horizon weekly data", { model_length <- length(workflow_tbl$Model_Fit[[1]]$fit$fit$fit$models) # Assertions - expect_equal(model_length, 4) + expect_equal(model_length, 1) }) test_that("multistep_horizon daily data", { @@ -200,7 +204,7 @@ test_that("multistep_horizon daily data", { combo_variables = c("id"), target_variable = "value", date_type = "day", - forecast_horizon = 28, + forecast_horizon = 30, recipes_to_run = "R1", multistep_horizon = TRUE ) @@ -211,7 +215,8 @@ test_that("multistep_horizon daily data", { back_test_spacing = 7, models_to_run = "glmnet", run_ensemble_models = FALSE, - num_hyperparameters = 1 + num_hyperparameters = 1, + pca = TRUE ) # train models @@ -224,5 +229,5 @@ test_that("multistep_horizon daily data", { model_length <- length(workflow_tbl$Model_Fit[[1]]$fit$fit$fit$models) # Assertions - expect_equal(model_length, 4) + expect_equal(model_length, 2) })