Skip to content

Commit

Permalink
Merge pull request #170 from microsoft/mitokic/1025/2024/global-model…
Browse files Browse the repository at this point in the history
…s-update

Mitokic/1025/2024/global models update
  • Loading branch information
mitokic authored Oct 29, 2024
2 parents f5825bd + b1fe9ed commit 370d625
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 14 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: finnts
Title: Microsoft Finance Time Series Forecasting Framework
Version: 0.5.0
Version: 0.5.0.9001
Authors@R:
c(person(given = "Mike",
family = "Tokic",
Expand Down
12 changes: 12 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# finnts 0.5.9001 (development version)

## Improvements

- Shortened global model list to just xgboost
- Faster xgboost model training for larger datasets
- Faster feature selection for global model training

## Bug Fixes

- Error in formatting of training data for global models

# finnts 0.5.0

## Improvements
Expand Down
2 changes: 1 addition & 1 deletion R/ensemble_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ ensemble_models <- function(run_info,
avail_arg_list <- list(
"train_data" = prep_ensemble_tbl %>% dplyr::select(-Train_Test_ID),
"model_type" = "ensemble",
"pca" = FALSE,
"pca" = FALSE,
"multistep" = FALSE
)

Expand Down
14 changes: 10 additions & 4 deletions R/feature_selection.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ run_feature_selection <- function(input_data,
return(fs_list)
}

# check for multiple time series
if (length(unique(input_data$Combo)) > 1) {
global <- TRUE
} else {
global <- FALSE
}

# check for external regressors future values
future_xregs <- multi_future_xreg_check(
input_data,
Expand Down Expand Up @@ -80,11 +87,11 @@ run_feature_selection <- function(input_data,
}

# run feature selection
if (date_type %in% c("day", "week")) {
if (date_type %in% c("day", "week") | global) {
# number of votes needed for feature to be selected
votes_needed <- 3

# don't run leave one feature out process for daily and weekly data
# don't run leave one feature out process for daily, weekly, or global model data
lofo_results <- tibble::tibble()

# target correlation
Expand All @@ -96,7 +103,7 @@ run_feature_selection <- function(input_data,
) %>%
dplyr::select(Feature, Vote, Auto_Accept)

# don't run boruta process for daily and weekly data
# don't run boruta for daily, weekly, or global model data
boruta_results <- tibble::tibble()
} else {
if (!fast) { # full implementation
Expand Down Expand Up @@ -253,7 +260,6 @@ run_feature_selection <- function(input_data,

return(setNames(list(fs_list), element_name))
}

return(fs_list_final)
}

Expand Down
5 changes: 3 additions & 2 deletions R/models.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ list_r2_models <- function() {
#' @return list of models
#' @noRd
list_global_models <- function() {
list <- c("cubist", "glmnet", "mars", "svm-poly", "svm-rbf", "xgboost")
list <- c("xgboost")

return(list)
}
Expand All @@ -69,7 +69,8 @@ list_global_models <- function() {
#' @noRd
list_multivariate_models <- function() {
list <- c(
list_global_models(), "arima-boost", "arimax", "prophet-boost", "prophet-xregs",
"cubist", "glmnet", "mars", "svm-poly", "svm-rbf", "xgboost",
"arima-boost", "arimax", "prophet-boost", "prophet-xregs",
"nnetar-xregs"
)

Expand Down
50 changes: 50 additions & 0 deletions R/train_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,21 @@ train_models <- function(run_info,

workflow <- workflow$Model_Workflow[[1]]

if (nrow(prep_data) > 500 & model == "xgboost") {
# update xgboost model to use 'hist' tree method to speed up training
workflow <- workflows::update_model(workflow,
workflows::extract_spec_parsnip(workflow) %>%
parsnip::set_args(tree_method = "hist"))
}

if (combo_hash == "All-Data") {
# adjust column types to match original data
prep_data <- adjust_column_types(
prep_data,
workflows::extract_recipe(workflow, estimated = FALSE)
)
}

if (feature_selection & model %in% fs_model_list) {
# update model workflow to only use features from feature selection process
if (data_prep_recipe == "R1") {
Expand Down Expand Up @@ -987,3 +1002,38 @@ undifference_recipe <- function(recipe_data,

return(final_recipe_data)
}

#' Function to enforce correct column formatting
#'
#' @param data data
#' @param recipe recipe
#'
#' @return tbl with correct column types
#' @noRd
adjust_column_types <- function(data, recipe) {
# Extract the required column types from the recipe
expected_types <- recipe$var_info %>%
dplyr::select(variable, type) %>%
dplyr::mutate(type = purrr::map_chr(type, ~ .x[[1]]))

# Identify and coerce mismatched columns
for (i in seq_len(nrow(expected_types))) {
col_name <- expected_types$variable[i]
expected_type <- expected_types$type[i]

# Check if column exists and type mismatch
if (col_name %in% names(data)) {
actual_type <- class(data[[col_name]])[1]

# Convert if types are different
if (expected_type == "string" && actual_type != "character") {
data[[col_name]] <- as.character(data[[col_name]])
} else if (expected_type %in% c("numeric", "double") && actual_type != "numeric") {
data[[col_name]] <- as.numeric(data[[col_name]])
} else if (expected_type == "date" && actual_type != "Date") {
data[[col_name]] <- as.Date(data[[col_name]])
}
}
}
return(data)
}
2 changes: 1 addition & 1 deletion R/utility.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ utils::globalVariables(c(
"term", "Column", "Box_Cox_Lambda", "get_recipie_configurable", "Agg", "Unique", "Var",
"Var_Combo", "regressor", "regressor_tbl", "value_level_iter", ".actual", ".fitted",
"forecast_horizon", "lag", "new_data", "object", "fit", "Row_Num", "Run_Number", "weight",
"Total", "Weight", "batch"
"Total", "Weight", "batch", "variable", "type"
))

#' @importFrom magrittr %>%
Expand Down
10 changes: 5 additions & 5 deletions vignettes/models-used-in-finnts.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ reactable::reactable(
rbind(data.frame(Model = "arima", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Regression model that is based on finding relationships between lagged values of the target variable you are trying to forecast.")) %>%
rbind(data.frame(Model = "arima-boost", Type = "multivariate, local", Underlying.Package = "modeltime, forecast, xgboost", Description = "Arima model (refer to arima) that models the trend compoent of target variable, then uses xgboost model (refer to xgboost) to train on the remaining residuals.")) %>%
rbind(data.frame(Model = "arimax", Type = "multivariate, local", Underlying.Package = "modeltime, forecast", Description = "ARIMA model that incorporates external regressors and other engineered features.")) %>%
rbind(data.frame(Model = "cubist", Type = "multivariate, local, global, ensemble", Underlying.Package = "rules", Description = "Hybrid of tree based and linear regression approach. Many decision trees are built, but regression coefficients are used at each terminal node instead of averging values in other tree based approaches.")) %>%
rbind(data.frame(Model = "cubist", Type = "multivariate, local, ensemble", Underlying.Package = "rules", Description = "Hybrid of tree based and linear regression approach. Many decision trees are built, but regression coefficients are used at each terminal node instead of averging values in other tree based approaches.")) %>%
rbind(data.frame(Model = "croston", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Useful for intermittent demand forecasting, aka when there are a lot of periods of zero values. Involves simple exponential smoothing on non-zero values of target variable and another application of seasonal exponential smoothing on periods between non-zero elements of the target variable. Refer to ets for more details on exponential smoothing.")) %>%
rbind(data.frame(Model = "ets", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Forecasts produced using exponential smoothing methods are weighted averages of past observations, with the weights decaying exponentially as the observations get older. Exponential smoothing models try to forecast the components of a time series which can be broken down in to error, trend, and seasonality. These components can be forecasted separately then either added or multiplied together to get the final forecast output.")) %>%
rbind(data.frame(Model = "glmnet", Type = "multivariate, local, global, ensemble", Underlying.Package = "parsnip, glmnet", Description = "Linear regression (line of best fit) with regularization to help prevent overfitting and built in variable selection.")) %>%
rbind(data.frame(Model = "mars", Type = "multivariate, local, global", Underlying.Package = "parsnip, earth", Description = "An extension to linear regression that captures nonlinearities and interactions between variables.")) %>%
rbind(data.frame(Model = "glmnet", Type = "multivariate, local, ensemble", Underlying.Package = "parsnip, glmnet", Description = "Linear regression (line of best fit) with regularization to help prevent overfitting and built in variable selection.")) %>%
rbind(data.frame(Model = "mars", Type = "multivariate, local", Underlying.Package = "parsnip, earth", Description = "An extension to linear regression that captures nonlinearities and interactions between variables.")) %>%
rbind(data.frame(Model = "meanf", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Simple average of previous year of target variable values.")) %>%
rbind(data.frame(Model = "nnetar", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "A neural network autoregression model is a traditional feed forward neural network (sometimes called an perceptron) that is fed by lagged values of the historical data set (similar to ARIMA).")) %>%
rbind(data.frame(Model = "nnetar-xregs", Type = "multivariate, local", Underlying.Package = "modeltime, forecast", Description = "Same approach as nnetar but can incorporate other features in addition to the target variable, like external regressors and date features.")) %>%
Expand All @@ -37,8 +37,8 @@ reactable::reactable(
rbind(data.frame(Model = "snaive", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Simple model that takes the value from the same period in the previous year.")) %>%
rbind(data.frame(Model = "stlm-arima", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Applies an STL decomposition (breaks out target variable into seasonal, trend, and error/residual/remainder components), models the seasonally adjusted data, reseasonalizes, and returns the forecasts. An arima model (refer to arima) is used in forecasting the seasonaly adjusted data.")) %>%
rbind(data.frame(Model = "stlm-ets", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Applies an STL decomposition (breaks out target variable into seasonal, trend, and error/residual/remainder components), models the seasonally adjusted data, reseasonalizes, and returns the forecasts. An ets model (refer to ets) is used in forecasting the seasonaly adjusted data.")) %>%
rbind(data.frame(Model = "svm-poly", Type = "multivariate, local, global, ensemble", Underlying.Package = "parsnip, kernlab", Description = "Uses a nonlinear function, specifically a polynomial function, to create a regression line of the target variable.")) %>%
rbind(data.frame(Model = "svm-rbf", Type = "multivariate, local, global, ensemble", Underlying.Package = "parsnip, kernlab", Description = "Uses a nonlinear function, specifically a radial basis function, to create a regression line of the target variable.")) %>%
rbind(data.frame(Model = "svm-poly", Type = "multivariate, local, ensemble", Underlying.Package = "parsnip, kernlab", Description = "Uses a nonlinear function, specifically a polynomial function, to create a regression line of the target variable.")) %>%
rbind(data.frame(Model = "svm-rbf", Type = "multivariate, local, ensemble", Underlying.Package = "parsnip, kernlab", Description = "Uses a nonlinear function, specifically a radial basis function, to create a regression line of the target variable.")) %>%
rbind(data.frame(Model = "tbats", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "A spin off of the traditional ets model (refer to ets), with some additional components to capture multiple seasonalities.")) %>%
rbind(data.frame(Model = "theta", Type = "univariate, local", Underlying.Package = "modeltime, forecast", Description = "Theta is similar to exponential smoothing (refer to ets) but with another component called drift. Adding drift to exponential smoothing allows the forecast to increase or decrease over time, where the amount of change over time (called the drift) is set to be the average change seen within the historical data.")) %>%
rbind(data.frame(Model = "xgboost", Type = "multivariate, local, global, ensemble", Underlying.Package = "parsnip, xgboost", Description = "Builds many decision trees (similar to random forests), but predictions that are initially inaccurate are applied more weight in subsequent training rounds to increase accuracy across all predictions.")),
Expand Down

0 comments on commit 370d625

Please sign in to comment.