Skip to content

Commit

Permalink
recursive chunk size - modeltime#197
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Oct 17, 2022
1 parent e237622 commit f39e1c9
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 55 deletions.
156 changes: 103 additions & 53 deletions R/modeltime_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,14 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data,

.transform <- object$spec[["transform"]]
train_tail <- object$spec$train_tail
chunk_size <- object$spec$chunk_size

idx_sets <- split(x = seq_len(nrow(new_data)),
f = (seq_len(nrow(new_data)) - 1) %/% chunk_size)

# LOOP LOGIC ----
.first_slice <- new_data %>%
dplyr::slice_head(n = 1)
dplyr::slice_head(n = chunk_size)

.forecasts <- modeltime::mdl_time_forecast(
object,
Expand All @@ -414,37 +418,48 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data,
.forecast_from_model <- .forecasts %>%
dplyr::filter(.key == "prediction")

new_data[1, y_var] <- .forecast_from_model$.value
new_data[idx_sets[[1]], y_var] <- .forecast_from_model$.value

for (i in 2:nrow(new_data)) {
.temp_new_data <- dplyr::bind_rows(
train_tail,
new_data
)

.temp_new_data <- dplyr::bind_rows(
train_tail,
new_data
)
n_train_tail <- nrow(train_tail)

.nth_slice <- .transform(.temp_new_data, nrow(new_data), i)
if (length(idx_sets) > 1){
for (i in 2:length(idx_sets)) {

.nth_forecast <- modeltime::mdl_time_forecast(
object,
new_data = .nth_slice,
h = h,
actual_data = actual_data,
keep_data = keep_data,
arrange_index = arrange_index,
...
)
transform_window_start <- min(idx_sets[[i]])
transform_window_end <- max(idx_sets[[i]]) + n_train_tail

.nth_forecast_from_model <-
.nth_forecast %>%
dplyr::filter(.key == "prediction") %>%
.[1,]
# .nth_slice <- .transform(.temp_new_data, nrow(new_data), i)
.nth_slice <- .transform(.temp_new_data[transform_window_start:transform_window_end,], length(idx_sets[[i]]))

.forecasts <- dplyr::bind_rows(
.forecasts, .nth_forecast_from_model
)
# print(.nth_slice)

new_data[i, y_var] <- .nth_forecast_from_model$.value
.nth_forecast <- modeltime::mdl_time_forecast(
object,
new_data = .nth_slice,
h = h,
actual_data = actual_data,
keep_data = keep_data,
arrange_index = arrange_index,
...
)

# print(.nth_forecast)

.nth_forecast_from_model <- .nth_forecast %>%
dplyr::filter(.key == "prediction") %>%
.[1,]

.forecasts <- dplyr::bind_rows(
.forecasts, .nth_forecast_from_model
)

new_data[idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value
}
}

return(.forecasts)
Expand All @@ -463,9 +478,24 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,
.transform <- object$spec[["transform"]]
train_tail <- object$spec$train_tail
id <- object$spec$id
chunk_size <- object$spec$chunk_size

.id <- dplyr::ensym(id)

unique_id_new_data <- new_data %>% dplyr::select(!! .id) %>% unique() %>% dplyr::pull()

unique_id_train_tail <- train_tail %>% dplyr::select(!! .id) %>% unique() %>% dplyr::pull()

if (length(dplyr::setdiff(unique_id_train_tail, unique_id_new_data)) >= 1){
train_tail <- train_tail %>% dplyr::filter(!! .id %in% unique_id_new_data)
}

n_groups <- dplyr::n_distinct(new_data[[id]])
group_size <- max(table(new_data[[id]]))

idx_sets <- split(x = seq_len(group_size),
f = (seq_len(group_size) - 1) %/% chunk_size)

# LOOP LOGIC ----

.preds <- tibble::tibble(.id = new_data %>% dplyr::pull(!! .id),
Expand All @@ -476,7 +506,7 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,

.first_slice <- new_data %>%
dplyr::group_by(!! .id) %>%
dplyr::slice_head(n = 1) %>%
dplyr::slice_head(n = chunk_size) %>%
dplyr::ungroup()

new_data <- new_data %>%
Expand All @@ -502,7 +532,7 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,
.forecast_from_model <- .forecasts %>%
dplyr::filter(.key == "prediction")

.preds[.preds$rowid.. == 1, 2] <- new_data[new_data$rowid.. == 1, y_var] <- .forecast_from_model$.value
.preds[.preds$rowid.. %in% idx_sets[[1]], 2] <- new_data[.preds$rowid.. %in% idx_sets[[1]], y_var] <- .forecast_from_model$.value

.groups <- new_data %>%
dplyr::group_by(!! .id) %>%
Expand All @@ -512,41 +542,61 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,

new_data_size <- nrow(.preds)/.groups

for (i in 2:new_data_size) {
.temp_new_data <- dplyr::bind_rows(
train_tail,
new_data
)

.temp_new_data <- dplyr::bind_rows(
train_tail,
new_data
)
n_train_tail <- max(table(train_tail[[id]]))

.nth_slice <- .transform(.temp_new_data, new_data_size, i, id)
if (length(idx_sets) > 1){
for (i in 2:length(idx_sets)) {

if ("rowid.." %in% names(.first_slice)) {
.first_slice <- .first_slice %>% dplyr::select(-rowid..)
}
transform_window_start <- min(idx_sets[[i]])
transform_window_end <- max(idx_sets[[i]]) + n_train_tail

.nth_slice <- .nth_slice[names(.first_slice)]

.nth_forecast <- modeltime::mdl_time_forecast(
object,
new_data = .nth_slice,
h = h,
actual_data = actual_data,
keep_data = keep_data,
arrange_index = arrange_index,
...
) %>%
dplyr::filter(!is.na(.value))

.nth_forecast_from_model <- .nth_forecast %>%
dplyr::filter(.key == "prediction")
.nth_slice <- .transform(.temp_new_data %>%
dplyr::group_by(!! .id) %>%
dplyr::slice(transform_window_start:transform_window_end),
idx_sets[[i]], id)

# Fix - when the ID column has been dummy-encoded (spec$remove_id is TRUE), drop it from the slice
if (!is.null(object$spec$remove_id)) {
if (object$spec$remove_id) {
.nth_slice <- .nth_slice %>%
dplyr::select(-(!! .id))
}
}

if ("rowid.." %in% names(.nth_slice)) {
.nth_slice <- .nth_slice %>% dplyr::select(-rowid..)
}

.forecasts <- dplyr::bind_rows(
.forecasts, .nth_forecast_from_model
)
.nth_slice <- .nth_slice[names(.first_slice)]

.nth_forecast <- modeltime::mdl_time_forecast(
object,
new_data = .nth_slice,
h = h,
actual_data = actual_data,
keep_data = keep_data,
arrange_index = arrange_index,
...
) %>%
dplyr::filter(!is.na(.value))

.preds[.preds$rowid.. == i, 2] <- new_data[new_data$rowid.. == i, y_var] <- .nth_forecast_from_model$.value
.nth_forecast_from_model <- .nth_forecast %>%
dplyr::filter(.key == "prediction")

.forecasts <- dplyr::bind_rows(
.forecasts, .nth_forecast_from_model
)


.preds[.preds$rowid.. %in% idx_sets[[i]], 2] <- .temp_new_data[.temp_new_data$rowid.. %in% idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value
}
}

return(.forecasts)
Expand Down
5 changes: 3 additions & 2 deletions R/modeltime_recursive.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@

#' @export
#' @importFrom modeltime recursive
recursive.mdl_time_ensemble <- function(object, transform, train_tail, id = NULL, ...){
recursive.mdl_time_ensemble <- function(object, transform, train_tail, id = NULL, chunk_size = 1, ...){

.class_obj <- if(!is.null(id)){"recursive_panel"} else {"recursive"}

object$spec[["forecast"]] <- .class_obj
object$spec[["transform"]] <- if(!is.null(id)){.prepare_panel_transform(transform)} else {.prepare_transform(transform)}
object$spec[["train_tail"]] <- train_tail
object$spec[["chunk_size"]] <- as.integer(chunk_size)

# Workflow and Model Fit Objects store y_var (outcome) differently
model_1 <- object$model_tbl$.model[[1]]
if (inherits(model_1, "workflow")) {
mld <- model_1 %>% workflows::pull_workflow_mold()
object$spec[["y_var"]] <- names(mld$outcomes)
} else if (inherits(model_1, "model_fit")) {
object$spec[["y_var"]] <- object$model_tbl$.model[[1]]$preproc$y_var
object$spec[["y_var"]] <- object$model_tbl$.model[[1]]$preproc$y_var
} else {
rlang::abort("Recursive does not currently support multi-stacked ensembles.")
}
Expand Down

0 comments on commit f39e1c9

Please sign in to comment.