
Varying calibration / prediction results #22

Closed
GitHunter0 opened this issue Mar 17, 2021 · 4 comments

@GitHunter0

Every calibration / prediction run for the DeepAR model gives a slightly different result. In the code below, forecast1 and forecast2 are never exactly the same.

library(modeltime.gluonts)
library(tidymodels)
library(tidyverse)

# The splits were not shown in the original report; they are assumed to be
# created along these lines (24-month assessment to match prediction_length):
m750_splits <- timetk::time_series_split(m750, assess = "2 years", cumulative = TRUE)

model_fit <- modeltime.gluonts::deep_ar(
    id                    = "id",
    freq                  = "M",
    prediction_length     = 24,
    lookback_length       = 36,
    epochs                = 10,
    num_batches_per_epoch = 50,
    learn_rate            = 0.001,
    num_layers            = 2,
    dropout               = 0.10
  ) %>%
  set_engine("gluonts_deepar") %>%
  fit(formula = value ~ ., data = training(m750_splits))
  
forecast1 <-
  modeltime_table(model_fit) %>%
  modeltime_forecast(new_data    = testing(m750_splits),
                     actual_data = m750) %>%
  filter(.key == "prediction") %>%
  select(.index, .value)

forecast2 <-
  modeltime_table(model_fit) %>%
  modeltime_forecast(new_data    = testing(m750_splits),
                     actual_data = m750) %>%
  filter(.key == "prediction") %>%
  select(.index, .value)
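
A quick way to confirm the two runs differ (a small check added for illustration; not part of the original report):

# Both runs share the same fitted model, yet the point forecasts differ slightly
all.equal(forecast1$.value, forecast2$.value)
max(abs(forecast1$.value - forecast2$.value))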

I was aware that deep learning model estimates cannot be exactly replicated, but I did not know the forecast itself would vary each time too, even with the same parameters and estimates. Do you know the reason?
A disclaimer about this in the documentation might be useful for non-experts in deep learning like myself.

Thanks again

I saw that you just launched modeltime.h2o, very cool stuff, I will explore it now.

@mdancho84
Contributor

mdancho84 commented Mar 17, 2021

This variation is actually a feature of GluonTS. But there is a solution: set the MXNet seed (see below).

Why is variation a feature?

Because GluonTS bills itself as "Probabilistic Forecasting Software". The specifics vary by algorithm, but for DeepAR you get a probabilistic forecast.

What's happening under the hood that makes it probabilistic?

We are actually making a mean prediction from the many sample paths that DeepAR generates. Modeltime only uses the mean, but if you forecast with GluonTS DeepAR in Python you can also get quantiles around the forecast.

We don't use that feature since it doesn't fit into the Modeltime framework (currently we use quantiles around a calibration forecast; refer to modeltime_calibrate()), but there may be benefits to exposing extras from the GluonTS model, like the quantiles it can return.

That's for another day... for now, just recognize that the variance in your forecast is a result of how DeepAR forecasts and not an error.
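
To make that concrete, here is a toy R sketch (an illustration of the idea, not the GluonTS internals): the point forecast is the mean over many simulated sample paths, and quantiles of those same paths give the prediction interval.

set.seed(123)

horizon <- 24   # forecast horizon (e.g. 24 months)
n_paths <- 100  # number of Monte Carlo sample paths

# Each row stands in for one simulated future path from the model
paths <- matrix(rnorm(n_paths * horizon, mean = 10, sd = 2),
                nrow = n_paths, ncol = horizon)

point_forecast <- colMeans(paths)                    # the mean that Modeltime returns
lower_80 <- apply(paths, 2, quantile, probs = 0.1)   # quantiles GluonTS can expose
upper_80 <- apply(paths, 2, quantile, probs = 0.9)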

Solution

GluonTS internally uses the random number generator from the MXNet library. We can set its seed via reticulate (note that seed() is a function call, not a field to assign):

mxnet <- reticulate::import("mxnet")
mxnet$random$seed(123L)
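
Since DeepAR draws fresh sample paths at prediction time, the seed likely needs to be re-set immediately before each forecast call, not just once after import. A minimal sketch, assuming the MXNet seed alone governs GluonTS's sampling:

mxnet <- reticulate::import("mxnet")

# Re-seed right before each forecast so both draws start from the same RNG state
mxnet$random$seed(123L)
forecast1 <- modeltime_table(model_fit) %>%
  modeltime_forecast(new_data = testing(m750_splits), actual_data = m750)

mxnet$random$seed(123L)
forecast2 <- modeltime_table(model_fit) %>%
  modeltime_forecast(new_data = testing(m750_splits), actual_data = m750)

all.equal(forecast1, forecast2)  # TRUE if the seeding assumption holds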

@GitHunter0
Author

Hey Matt, I was surprised to see that even the forecast is probabilistic. I tried the MXNet seed solution, but it still gives different forecasts (forecast1 != forecast2). Since the variation is always small, it is not a big deal. Feel free to close this issue. Thanks, man.

@mdancho84
Contributor

OK, that's odd. I would have thought the MXNet random seed would solve it. I'll close this for now, but I can reopen it if something is amiss.

@mesdi

mesdi commented Sep 23, 2023

I think the inability to retrieve reproducible results will pose a significant obstacle to using this package.
