Skip to content

Commit

Permalink
Merge pull request #161 from microsoft/mitokic/05302024/global_model_fixes
Browse files (browse the repository at this point in the history)

global model bug fix for multistep horizon forecasting
  • Loading branch information
mitokic authored May 30, 2024
2 parents d757e25 + 5b5bc0b commit 03eabf3
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck, vip=?ignore-before-r=4.1.0, Boruta=?ignore-before-r=4.1.0
extra-packages: any::rcmdcheck, vip=?ignore-before-r=4.1.0, Boruta=?ignore-before-r=4.1.0, corrr=?ignore-before-r=4.1.0
needs: check

- uses: r-lib/actions/check-r-package@v2
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: finnts
Title: Microsoft Finance Time Series Forecasting Framework
Version: 0.4.0.9003
Version: 0.4.0.9004
Authors@R:
c(person(given = "Mike",
family = "Tokic",
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# finnts 0.4.0.9003 (DEVELOPMENT VERSION)
# finnts 0.4.0.9004 (DEVELOPMENT VERSION)

## Improvements

- Added support for hierarchical forecasting with external regressors
- Allow global models for hierarchical forecasts
- Multistep horizon forecasts for R1 recipe, listed as `multistep_horizon` within `prep_data()`

## Bug Fixes
Expand Down
30 changes: 25 additions & 5 deletions R/multistep_cubist.R
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,13 @@ predict.cubist_multistep_fit_impl <- function(object, new_data, ...) {
#' @export
cubist_multistep_predict_impl <- function(object, new_data, ...) {

# Date Mapping Table
date_tbl <- new_data %>%
dplyr::select(Date, Date_index.num) %>%
dplyr::distinct() %>%
dplyr::arrange(Date) %>%
dplyr::mutate(Run_Number = dplyr::row_number())

# PREPARE INPUTS
xreg_recipe <- object$extras$xreg_recipe
h_horizon <- nrow(new_data)
Expand All @@ -447,14 +454,16 @@ cubist_multistep_predict_impl <- function(object, new_data, ...) {
xreg_tbl <- modeltime::bake_xreg_recipe(xreg_recipe,
new_data,
format = "tbl"
)
) %>%
dplyr::left_join(date_tbl, by = "Date_index.num") %>%
dplyr::mutate(Row_Num = dplyr::row_number())

# PREDICTIONS
final_prediction <- c()
final_prediction <- tibble::tibble()
start_val <- 1

for (model_name in names(object$models)) {
if (start_val > nrow(xreg_tbl)) {
if (start_val > nrow(date_tbl)) {
break
}

Expand All @@ -463,17 +472,28 @@ cubist_multistep_predict_impl <- function(object, new_data, ...) {
cubist_model <- object$models[[model_name]]

xreg_tbl_final <- xreg_tbl %>%
dplyr::slice(start_val:lag_number)
dplyr::filter(
Run_Number >= start_val,
Run_Number <= lag_number
)

if (!is.null(xreg_tbl)) {
preds_cubist <- predict(cubist_model, xreg_tbl_final)
} else {
preds_cubist <- rep(0, h_horizon)
}

preds_cubist <- tibble::tibble(.pred = preds_cubist) %>%
dplyr::mutate(Row_Num = xreg_tbl_final$Row_Num)

start_val <- as.numeric(lag_number) + 1
final_prediction <- c(final_prediction, preds_cubist)
final_prediction <- rbind(final_prediction, preds_cubist)
}

# Ensure it's sorted correctly for global models
final_prediction <- final_prediction %>%
dplyr::arrange(Row_Num) %>%
dplyr::pull(.pred)

return(final_prediction)
}
26 changes: 23 additions & 3 deletions R/multistep_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@ predict.glmnet_multistep_fit_impl <- function(object, new_data, ...) {
#' @export
glmnet_multistep_predict_impl <- function(object, new_data, ...) {

# Date Mapping Table
date_tbl <- new_data %>%
dplyr::select(Date, Date_index.num) %>%
dplyr::distinct() %>%
dplyr::arrange(Date) %>%
dplyr::mutate(Run_Number = dplyr::row_number())

# PREPARE INPUTS
xreg_recipe <- object$extras$xreg_recipe
h_horizon <- nrow(new_data)
Expand All @@ -435,14 +442,16 @@ glmnet_multistep_predict_impl <- function(object, new_data, ...) {
xreg_tbl <- modeltime::bake_xreg_recipe(xreg_recipe,
new_data,
format = "tbl"
)
) %>%
dplyr::left_join(date_tbl, by = "Date_index.num") %>%
dplyr::mutate(Row_Num = dplyr::row_number())

# PREDICTIONS
final_prediction <- tibble::tibble()
start_val <- 1

for (model_name in names(object$models)) {
if (start_val > nrow(xreg_tbl)) {
if (start_val > nrow(date_tbl)) {
break
}

Expand All @@ -451,17 +460,28 @@ glmnet_multistep_predict_impl <- function(object, new_data, ...) {
glmnet_model <- object$models[[model_name]]

xreg_tbl_final <- xreg_tbl %>%
dplyr::slice(start_val:lag_number)
dplyr::filter(
Run_Number >= start_val,
Run_Number <= lag_number
)

if (!is.null(xreg_tbl)) {
preds_glmnet <- predict(glmnet_model, xreg_tbl_final)
} else {
preds_glmnet <- rep(0, h_horizon)
}

preds_glmnet <- preds_glmnet %>%
dplyr::mutate(Row_Num = xreg_tbl_final$Row_Num)

start_val <- as.numeric(lag_number) + 1
final_prediction <- rbind(final_prediction, preds_glmnet)
}

# Ensure it's sorted correctly for global models
final_prediction <- final_prediction %>%
dplyr::arrange(Row_Num) %>%
dplyr::select(.pred)

return(final_prediction)
}
26 changes: 23 additions & 3 deletions R/multistep_mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,13 @@ predict.mars_multistep_fit_impl <- function(object, new_data, ...) {
#' @export
mars_multistep_predict_impl <- function(object, new_data, ...) {

# Date Mapping Table
date_tbl <- new_data %>%
dplyr::select(Date, Date_index.num) %>%
dplyr::distinct() %>%
dplyr::arrange(Date) %>%
dplyr::mutate(Run_Number = dplyr::row_number())

# PREPARE INPUTS
xreg_recipe <- object$extras$xreg_recipe
h_horizon <- nrow(new_data)
Expand All @@ -458,14 +465,16 @@ mars_multistep_predict_impl <- function(object, new_data, ...) {
xreg_tbl <- modeltime::bake_xreg_recipe(xreg_recipe,
new_data,
format = "tbl"
)
) %>%
dplyr::left_join(date_tbl, by = "Date_index.num") %>%
dplyr::mutate(Row_Num = dplyr::row_number())

# PREDICTIONS
final_prediction <- tibble::tibble()
start_val <- 1

for (model_name in names(object$models)) {
if (start_val > nrow(xreg_tbl)) {
if (start_val > nrow(date_tbl)) {
break
}

Expand All @@ -474,17 +483,28 @@ mars_multistep_predict_impl <- function(object, new_data, ...) {
mars_model <- object$models[[model_name]]

xreg_tbl_final <- xreg_tbl %>%
dplyr::slice(start_val:lag_number)
dplyr::filter(
Run_Number >= start_val,
Run_Number <= lag_number
)

if (!is.null(xreg_tbl)) {
preds_mars <- predict(mars_model, xreg_tbl_final)
} else {
preds_mars <- rep(0, h_horizon)
}

preds_mars <- preds_mars %>%
dplyr::mutate(Row_Num = xreg_tbl_final$Row_Num)

start_val <- as.numeric(lag_number) + 1
final_prediction <- rbind(final_prediction, preds_mars)
}

# Ensure it's sorted correctly for global models
final_prediction <- final_prediction %>%
dplyr::arrange(Row_Num) %>%
dplyr::select(.pred)

return(final_prediction)
}
26 changes: 23 additions & 3 deletions R/multistep_svm_poly.R
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,13 @@ predict.svm_poly_multistep_fit_impl <- function(object, new_data, ...) {
#' @export
svm_poly_multistep_predict_impl <- function(object, new_data, ...) {

# Date Mapping Table
date_tbl <- new_data %>%
dplyr::select(Date, Date_index.num) %>%
dplyr::distinct() %>%
dplyr::arrange(Date) %>%
dplyr::mutate(Run_Number = dplyr::row_number())

# PREPARE INPUTS
xreg_recipe <- object$extras$xreg_recipe
h_horizon <- nrow(new_data)
Expand All @@ -484,14 +491,16 @@ svm_poly_multistep_predict_impl <- function(object, new_data, ...) {
xreg_tbl <- modeltime::bake_xreg_recipe(xreg_recipe,
new_data,
format = "tbl"
)
) %>%
dplyr::left_join(date_tbl, by = "Date_index.num") %>%
dplyr::mutate(Row_Num = dplyr::row_number())

# PREDICTIONS
final_prediction <- tibble::tibble()
start_val <- 1

for (model_name in names(object$models)) {
if (start_val > nrow(xreg_tbl)) {
if (start_val > nrow(date_tbl)) {
break
}

Expand All @@ -500,17 +509,28 @@ svm_poly_multistep_predict_impl <- function(object, new_data, ...) {
svm_poly_model <- object$models[[model_name]]

xreg_tbl_final <- xreg_tbl %>%
dplyr::slice(start_val:lag_number)
dplyr::filter(
Run_Number >= start_val,
Run_Number <= lag_number
)

if (!is.null(xreg_tbl)) {
preds_svm_poly <- predict(svm_poly_model, xreg_tbl_final)
} else {
preds_svm_poly <- rep(0, h_horizon)
}

preds_svm_poly <- preds_svm_poly %>%
dplyr::mutate(Row_Num = xreg_tbl_final$Row_Num)

start_val <- as.numeric(lag_number) + 1
final_prediction <- rbind(final_prediction, preds_svm_poly)
}

# Ensure it's sorted correctly for global models
final_prediction <- final_prediction %>%
dplyr::arrange(Row_Num) %>%
dplyr::select(.pred)

return(final_prediction)
}
26 changes: 23 additions & 3 deletions R/multistep_svm_rbf.R
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,13 @@ predict.svm_rbf_multistep_fit_impl <- function(object, new_data, ...) {
#' @export
svm_rbf_multistep_predict_impl <- function(object, new_data, ...) {

# Date Mapping Table
date_tbl <- new_data %>%
dplyr::select(Date, Date_index.num) %>%
dplyr::distinct() %>%
dplyr::arrange(Date) %>%
dplyr::mutate(Run_Number = dplyr::row_number())

# PREPARE INPUTS
xreg_recipe <- object$extras$xreg_recipe
h_horizon <- nrow(new_data)
Expand All @@ -464,14 +471,16 @@ svm_rbf_multistep_predict_impl <- function(object, new_data, ...) {
xreg_tbl <- modeltime::bake_xreg_recipe(xreg_recipe,
new_data,
format = "tbl"
)
) %>%
dplyr::left_join(date_tbl, by = "Date_index.num") %>%
dplyr::mutate(Row_Num = dplyr::row_number())

# PREDICTIONS
final_prediction <- tibble::tibble()
start_val <- 1

for (model_name in names(object$models)) {
if (start_val > nrow(xreg_tbl)) {
if (start_val > nrow(date_tbl)) {
break
}

Expand All @@ -480,17 +489,28 @@ svm_rbf_multistep_predict_impl <- function(object, new_data, ...) {
svm_rbf_model <- object$models[[model_name]]

xreg_tbl_final <- xreg_tbl %>%
dplyr::slice(start_val:lag_number)
dplyr::filter(
Run_Number >= start_val,
Run_Number <= lag_number
)

if (!is.null(xreg_tbl)) {
preds_svm_rbf <- predict(svm_rbf_model, xreg_tbl_final)
} else {
preds_svm_rbf <- rep(0, h_horizon)
}

preds_svm_rbf <- preds_svm_rbf %>%
dplyr::mutate(Row_Num = xreg_tbl_final$Row_Num)

start_val <- as.numeric(lag_number) + 1
final_prediction <- rbind(final_prediction, preds_svm_rbf)
}

# Ensure it's sorted correctly for global models
final_prediction <- final_prediction %>%
dplyr::arrange(Row_Num) %>%
dplyr::select(.pred)

return(final_prediction)
}
Loading

0 comments on commit 03eabf3

Please sign in to comment.