Skip to content

Commit

Permalink
#22 recursive chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Oct 17, 2022
1 parent f39e1c9 commit 624d44d
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions R/modeltime_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,14 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data,
f = (seq_len(nrow(new_data)) - 1) %/% chunk_size)

# LOOP LOGIC ----

# print(new_data)

.first_slice <- new_data %>%
dplyr::slice_head(n = chunk_size)

# print(.first_slice)

.forecasts <- modeltime::mdl_time_forecast(
object,
new_data = .first_slice,
Expand All @@ -418,8 +423,12 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data,
.forecast_from_model <- .forecasts %>%
dplyr::filter(.key == "prediction")

# print(.forecast_from_model)

new_data[idx_sets[[1]], y_var] <- .forecast_from_model$.value

# print(new_data)

.temp_new_data <- dplyr::bind_rows(
train_tail,
new_data
Expand All @@ -433,6 +442,8 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data,
transform_window_start <- min(idx_sets[[i]])
transform_window_end <- max(idx_sets[[i]]) + n_train_tail

# print(.temp_new_data[transform_window_start:transform_window_end,])

# .nth_slice <- .transform(.temp_new_data, nrow(new_data), i)
.nth_slice <- .transform(.temp_new_data[transform_window_start:transform_window_end,], length(idx_sets[[i]]))

Expand All @@ -448,20 +459,22 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data,
...
)

# print(.nth_forecast)

.nth_forecast_from_model <- .nth_forecast %>%
dplyr::filter(.key == "prediction") %>%
.[1,]

# print(.nth_forecast_from_model)

.forecasts <- dplyr::bind_rows(
.forecasts, .nth_forecast_from_model
)

new_data[idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value
.temp_new_data[idx_sets[[i]] + n_train_tail, y_var] <- .nth_forecast_from_model$.value
}
}

# print(.forecasts)

return(.forecasts)
}

Expand Down Expand Up @@ -514,6 +527,14 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,
dplyr::mutate(rowid.. = dplyr::row_number()) %>%
dplyr::ungroup()

# Fix - When ID is dummied
if (!is.null(object$spec$remove_id)) {
if (object$spec$remove_id) {
.first_slice <- .first_slice %>%
dplyr::select(-(!! .id))
}
}

if ("rowid.." %in% names(.first_slice)) {
.first_slice <- .first_slice %>% dplyr::select(-rowid..)
}
Expand All @@ -532,7 +553,7 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,
.forecast_from_model <- .forecasts %>%
dplyr::filter(.key == "prediction")

.preds[.preds$rowid.. %in% idx_sets[[1]], 2] <- new_data[.preds$rowid.. %in% idx_sets[[1]], y_var] <- .forecast_from_model$.value
new_data[.preds$rowid.. %in% idx_sets[[1]], y_var] <- .forecast_from_model$.value

.groups <- new_data %>%
dplyr::group_by(!! .id) %>%
Expand All @@ -556,7 +577,6 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,
transform_window_end <- max(idx_sets[[i]]) + n_train_tail



.nth_slice <- .transform(.temp_new_data %>%
dplyr::group_by(!! .id) %>%
dplyr::slice(transform_window_start:transform_window_end),
Expand Down Expand Up @@ -594,8 +614,7 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data,
.forecasts, .nth_forecast_from_model
)


.preds[.preds$rowid.. %in% idx_sets[[i]], 2] <- .temp_new_data[.temp_new_data$rowid.. %in% idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value
.temp_new_data[.temp_new_data$rowid.. %in% idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value
}
}

Expand Down

0 comments on commit 624d44d

Please sign in to comment.