diff --git a/R/models.R b/R/models.R index 1a12da6a..d4a54abb 100644 --- a/R/models.R +++ b/R/models.R @@ -4,7 +4,7 @@ #' @export list_models <- function() { list <- c( - "arima", "arima-boost", "cubist", "croston", "ets", "glmnet", "mars", "meanf", + "arima", "arima-boost", "arimax", "cubist", "croston", "ets", "glmnet", "mars", "meanf", "nnetar", "nnetar-xregs", "prophet", "prophet-boost", "prophet-xregs", "snaive", "stlm-arima", "stlm-ets", "svm-poly", "svm-rbf", "tbats", "theta", "xgboost" ) @@ -112,17 +112,17 @@ get_recipe_combo <- function(train_data) { #' @noRd get_recipe_configurable <- function(train_data, - mutate_adj_half = FALSE, - rm_date = "plain", - step_nzv = "zv", - norm_date_adj_year = FALSE, - dummy_one_hot = TRUE, - character_factor = FALSE, - center_scale = FALSE, - one_hot = FALSE, - pca = TRUE, - corr = FALSE, - lincomb = FALSE) { + mutate_adj_half = FALSE, + rm_date = "plain", + step_nzv = "zv", + norm_date_adj_year = FALSE, + dummy_one_hot = TRUE, + character_factor = FALSE, + center_scale = FALSE, + one_hot = FALSE, + pca = TRUE, + corr = FALSE, + lincomb = FALSE) { mutate_adj_half_fn <- function(df) { if (mutate_adj_half) { df %>% @@ -146,7 +146,7 @@ get_recipe_configurable <- function(train_data, "none" = df ) } - + corr_fn <- function(df) { if (corr) { df %>% @@ -215,7 +215,7 @@ get_recipe_configurable <- function(train_data, rm_lincomb_fn <- function(df) { if (lincomb) { df %>% - recipes::step_lincomb(recipes::all_numeric_predictors(), id = "remove_linear_combs") + recipes::step_lincomb(recipes::all_numeric_predictors(), id = "remove_linear_combs") } else { df } @@ -491,9 +491,8 @@ arima <- function(train_data, #' @return Get the ARIMAX based model #' @noRd arimax <- function(train_data, - frequency, - pca) { - + frequency, + pca) { recipe_spec_arimax <- train_data %>% get_recipe_configurable( step_nzv = "zv", @@ -506,12 +505,12 @@ arimax <- function(train_data, seasonal_period = frequency ) %>% parsnip::set_engine("auto_arima") - + wflw_spec <- get_workflow_simple( model_spec_arima, recipe_spec_arimax ) - + return(wflw_spec) } diff --git a/R/run_info.R b/R/run_info.R index 774d1f3f..bf9acc5b 100644 --- a/R/run_info.R +++ b/R/run_info.R @@ -80,7 +80,7 @@ set_run_info <- function(experiment_name = "finn_fcst", fs::dir_create(tempdir(), models_folder) fs::dir_create(tempdir(), forecasts_folder) fs::dir_create(tempdir(), logs_folder) - } else if (is.null(storage_object) & substr(path, 1, 6) == "/synfs") { + } else if (is.null(storage_object) & substr(path, 1, 6) == "/synfs") { temp_path <- stringr::str_replace(path, "/synfs/", "synfs:/") if (!dir.exists(fs::path(path, prep_data_folder) %>% as.character())) {