Skip to content

Commit

Permalink
add arimax model
Browse files Browse the repository at this point in the history
  • Loading branch information
mitokic committed Sep 12, 2023
1 parent aaad702 commit bd8f177
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 20 deletions.
37 changes: 18 additions & 19 deletions R/models.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @export
list_models <- function() {
list <- c(
"arima", "arima-boost", "cubist", "croston", "ets", "glmnet", "mars", "meanf",
"arima", "arima-boost", "arimax", "cubist", "croston", "ets", "glmnet", "mars", "meanf",
"nnetar", "nnetar-xregs", "prophet", "prophet-boost", "prophet-xregs", "snaive",
"stlm-arima", "stlm-ets", "svm-poly", "svm-rbf", "tbats", "theta", "xgboost"
)
Expand Down Expand Up @@ -112,17 +112,17 @@ get_recipe_combo <- function(train_data) {
#' @noRd

get_recipe_configurable <- function(train_data,
mutate_adj_half = FALSE,
rm_date = "plain",
step_nzv = "zv",
norm_date_adj_year = FALSE,
dummy_one_hot = TRUE,
character_factor = FALSE,
center_scale = FALSE,
one_hot = FALSE,
pca = TRUE,
corr = FALSE,
lincomb = FALSE) {
mutate_adj_half = FALSE,
rm_date = "plain",
step_nzv = "zv",
norm_date_adj_year = FALSE,
dummy_one_hot = TRUE,
character_factor = FALSE,
center_scale = FALSE,
one_hot = FALSE,
pca = TRUE,
corr = FALSE,
lincomb = FALSE) {
mutate_adj_half_fn <- function(df) {
if (mutate_adj_half) {
df %>%
Expand All @@ -146,7 +146,7 @@ get_recipe_configurable <- function(train_data,
"none" = df
)
}

corr_fn <- function(df) {
if (corr) {
df %>%
Expand Down Expand Up @@ -215,7 +215,7 @@ get_recipe_configurable <- function(train_data,
rm_lincomb_fn <- function(df) {
if (lincomb) {
df %>%
recipes::step_lincomb(recipes::all_numeric_predictors(), id = "remove_linear_combs")
recipes::step_lincomb(recipes::all_numeric_predictors(), id = "remove_linear_combs")
} else {
df
}
Expand Down Expand Up @@ -491,9 +491,8 @@ arima <- function(train_data,
#' @return Get the ARIMAX based model
#' @noRd
arimax <- function(train_data,
frequency,
pca) {

frequency,
pca) {
recipe_spec_arimax <- train_data %>%
get_recipe_configurable(
step_nzv = "zv",
Expand All @@ -506,12 +505,12 @@ arimax <- function(train_data,
seasonal_period = frequency
) %>%
parsnip::set_engine("auto_arima")

wflw_spec <- get_workflow_simple(
model_spec_arima,
recipe_spec_arimax
)

return(wflw_spec)
}

Expand Down
2 changes: 1 addition & 1 deletion R/run_info.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ set_run_info <- function(experiment_name = "finn_fcst",
fs::dir_create(tempdir(), models_folder)
fs::dir_create(tempdir(), forecasts_folder)
fs::dir_create(tempdir(), logs_folder)
} else if (is.null(storage_object) & substr(path, 1, 6) == "/synfs") {
} else if (is.null(storage_object) & substr(path, 1, 6) == "/synfs") {
temp_path <- stringr::str_replace(path, "/synfs/", "synfs:/")

if (!dir.exists(fs::path(path, prep_data_folder) %>% as.character())) {
Expand Down

0 comments on commit bd8f177

Please sign in to comment.