
Commit

Add training class with hyperparameter tuning for Prophet forecast (#9)
Following the style of hyperopt-sklearn, add ProphetHyperoptEstimator as a wrapper for hyperparameter tuning with Prophet.
lu-wang-dl authored Sep 10, 2021
1 parent d502f60 commit 5bafcf1
Showing 9 changed files with 308 additions and 4 deletions.
170 changes: 170 additions & 0 deletions runtime/databricks/automl_runtime/forecast/prophet/forecast.py
@@ -0,0 +1,170 @@
#
# Copyright (C) 2021 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from enum import Enum
from functools import partial
from typing import Any, Dict, Optional

import hyperopt
import numpy as np
import pandas as pd

from databricks.automl_runtime.forecast.prophet.diagnostics import generate_cutoffs


class ProphetHyperParams(Enum):
CHANGEPOINT_PRIOR_SCALE = "changepoint_prior_scale"
SEASONALITY_PRIOR_SCALE = "seasonality_prior_scale"
HOLIDAYS_PRIOR_SCALE = "holidays_prior_scale"
SEASONALITY_MODE = "seasonality_mode"


def _prophet_fit_predict(params: Dict[str, Any], history_pd: pd.DataFrame,
horizon: int, frequency: str, num_folds: int,
interval_width: float, primary_metric: str,
country_holidays: Optional[str] = None) -> Dict[str, Any]:
"""
Training function for hyperparameter tuning with hyperopt
:param params: Input hyperparameters
:param history_pd: pd.DataFrame containing the history. Must have columns ds (date
type) and y, the time series
:param horizon: Forecast horizon
:param frequency: Frequency of the time series
:param num_folds: Number of folds for cross validation
:param interval_width: Width of the uncertainty intervals provided for the forecast
:param primary_metric: Metric that will be optimized across trials
:param country_holidays: Built-in holidays for the specified country
:return: Dictionary in the format expected by hyperopt
"""
import pandas as pd
from prophet import Prophet
from prophet.diagnostics import cross_validation, performance_metrics

model = Prophet(interval_width=interval_width, **params)
if country_holidays:
model.add_country_holidays(country_name=country_holidays)
model.fit(history_pd, iter=200)

# Evaluate Metrics
horizon_timedelta = pd.to_timedelta(horizon, unit=frequency)
cutoffs = generate_cutoffs(model, horizon=horizon_timedelta, num_folds=num_folds)
# Disable tqdm to make it work with the ipykernel and reduce the output size
df_cv = cross_validation(model, horizon=horizon_timedelta, cutoffs=cutoffs, disable_tqdm=True)
df_metrics = performance_metrics(df_cv)

metrics = df_metrics.mean().drop("horizon").to_dict()

return {"loss": metrics[primary_metric], "metrics": metrics, "status": hyperopt.STATUS_OK}


class ProphetHyperoptEstimator(ABC):
"""
Class to perform hyperparameter tuning for Prophet with hyperopt
"""
SUPPORTED_METRICS = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]

def __init__(self, horizon: int, frequency_unit: str, metric: str, interval_width: float,
country_holidays: str, search_space: Dict[str, Any],
algo=hyperopt.tpe.suggest, num_folds: int = 5,
max_eval: int = 10, trial_timeout: Optional[int] = None,
random_state: int = 0, is_parallel: bool = True) -> None:
"""
Initialization
:param horizon: Number of periods to forecast forward
:param frequency_unit: Frequency of the time series
:param metric: Metric that will be optimized across trials
:param interval_width: Width of the uncertainty intervals provided for the forecast
:param country_holidays: Built-in holidays for the specified country
:param search_space: Search space for hyperparameter tuning with hyperopt
:param algo: Search algorithm
:param num_folds: Number of folds for cross validation
:param max_eval: Max number of trials generated in hyperopt
:param trial_timeout: Timeout for hyperopt, in seconds
:param random_state: Random seed for hyperopt
:param is_parallel: Whether to run hyperopt trials in parallel with SparkTrials
"""
self._horizon = horizon
self._frequency_unit = frequency_unit
self._metric = metric
self._interval_width = interval_width
self._country_holidays = country_holidays
self._search_space = search_space
self._algo = algo
self._num_folds = num_folds
self._random_state = np.random.RandomState(random_state)
self._max_eval = max_eval
self._timeout = trial_timeout
self._is_parallel = is_parallel

def fit(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Fit the Prophet model with hyperparameter tuning
:param df: pd.DataFrame containing the history. Must have columns ds (date
type) and y
:return: DataFrame with model json and metrics in cross validation
"""
import pandas as pd
from prophet import Prophet
from prophet.serialize import model_to_json
from hyperopt import fmin, Trials, SparkTrials

seasonality_mode = ["additive", "multiplicative"]
search_space = self._search_space
algo = self._algo

train_fn = partial(_prophet_fit_predict, history_pd=df, horizon=self._horizon,
frequency=self._frequency_unit, num_folds=self._num_folds,
interval_width=self._interval_width,
primary_metric=self._metric, country_holidays=self._country_holidays)

if self._is_parallel:
trials = SparkTrials() # pragma: no cover
else:
trials = Trials()

best_result = fmin(
fn=train_fn,
space=search_space,
algo=algo,
max_evals=self._max_eval,
trials=trials,
timeout=self._timeout,
rstate=self._random_state)

# Retrain the model with all history data.
model = Prophet(changepoint_prior_scale=best_result.get(ProphetHyperParams.CHANGEPOINT_PRIOR_SCALE.value, 0.05),
seasonality_prior_scale=best_result.get(ProphetHyperParams.SEASONALITY_PRIOR_SCALE.value, 10.0),
holidays_prior_scale=best_result.get(ProphetHyperParams.HOLIDAYS_PRIOR_SCALE.value, 10.0),
seasonality_mode=seasonality_mode[best_result.get(ProphetHyperParams.SEASONALITY_MODE.value, 0)],
interval_width=self._interval_width)

if self._country_holidays:
model.add_country_holidays(country_name=self._country_holidays)

model.fit(df)

model_json = model_to_json(model)
metrics = trials.best_trial["result"]["metrics"]

results_pd = pd.DataFrame({"model_json": model_json}, index=[0])
results_pd.reset_index(level=0, inplace=True)
for metric in self.SUPPORTED_METRICS:
results_pd[metric] = metrics[metric]
results_pd["prophet_params"] = str(best_result)

return results_pd
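
For context, a minimal usage sketch of the new estimator follows. It is not part of the commit: the history data, search-space range, and argument values are illustrative assumptions modeled on the test added later in this change.

import pandas as pd
from hyperopt import hp

from databricks.automl_runtime.forecast.prophet.forecast import ProphetHyperoptEstimator

# Toy daily history with the required "ds" (date) and "y" (value) columns.
history_df = pd.DataFrame({
    "ds": pd.date_range("2020-07-01", periods=30, freq="D"),
    "y": range(30),
})

# Illustrative search space; the key matches ProphetHyperParams.CHANGEPOINT_PRIOR_SCALE.
search_space = {
    "changepoint_prior_scale": hp.loguniform("changepoint_prior_scale", -6.9, -0.69),
}

estimator = ProphetHyperoptEstimator(
    horizon=1,
    frequency_unit="d",
    metric="smape",
    interval_width=0.8,
    country_holidays="US",
    search_space=search_space,
    num_folds=2,
    max_eval=4,
    trial_timeout=600,
    random_state=0,
    is_parallel=False,  # use Trials() instead of SparkTrials()
)

# Returns a one-row DataFrame with the serialized model JSON,
# the cross-validation metrics, and the best hyperparameters found.
results_df = estimator.fit(history_df)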
24 changes: 22 additions & 2 deletions runtime/databricks/automl_runtime/forecast/prophet/model.py
@@ -20,6 +20,26 @@
import pandas as pd
import prophet

OFFSET_ALIAS_MAP = {
"W": "W",
"d": "D",
"D": "D",
"days": "D",
"day": "D",
"hours": "H",
"hour": "H",
"hr": "H",
"h": "H",
"m": "min",
"minute": "min",
"min": "min",
"minutes": "min",
"T": "T",
"S": "S",
"seconds": "S",
"sec": "S",
"second": "S"
}

PROPHET_CONDA_ENV = {
"channels": ["conda-forge"],
@@ -28,7 +48,7 @@
"pip": [
f"prophet=={prophet.__version__}",
f"cloudpickle=={cloudpickle.__version__}",
f"databricks-automl-runtime==0.1.0"
f"databricks-automl-runtime==0.2.0"
]
}
],
@@ -146,7 +166,7 @@ def _make_future_dataframe(self, id: str, horizon: int) -> pd.DataFrame:
date_rng = pd.date_range(
start=start_time,
end=end_time + pd.Timedelta(value=horizon, unit=self._frequency),
freq=self._frequency
freq=OFFSET_ALIAS_MAP[self._frequency]
)
return pd.DataFrame(date_rng, columns=["ds"])

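As a quick illustration of the new OFFSET_ALIAS_MAP, user-facing frequency strings such as "days" or "min" are normalized to pandas offset aliases before being handed to pd.date_range. The dates below are made up for illustration.

import pandas as pd

from databricks.automl_runtime.forecast.prophet.model import OFFSET_ALIAS_MAP

# Normalize the user-facing frequency string to a pandas offset alias.
freq = OFFSET_ALIAS_MAP["days"]  # -> "D"

# Same call shape as _make_future_dataframe: a range of future timestamps
# at the normalized frequency.
future_ds = pd.date_range(start="2020-07-25", periods=3, freq=freq)
# DatetimeIndex(['2020-07-25', '2020-07-26', '2020-07-27'], freq='D')
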
30 changes: 30 additions & 0 deletions runtime/databricks/automl_runtime/utils.py
@@ -0,0 +1,30 @@
#
# Copyright (C) 2021 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import wrapt


def fail_safe_with_default(default_result):
"""
Decorator to ensure that individual failures don't fail training
"""
@wrapt.decorator
def fail_safe(func, self, args, kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
print(f"Encountered an exception: {repr(e)}")
return default_result
return fail_safe
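
A short usage sketch of the decorator; the wrapped helper below is hypothetical and not part of this commit.

from databricks.automl_runtime.utils import fail_safe_with_default

@fail_safe_with_default(default_result={})
def optional_summary(df):
    # Hypothetical optional step: if this raises, the exception is printed
    # and the default result ({}) is returned instead of aborting training.
    return {"num_rows": len(df)}
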
2 changes: 1 addition & 1 deletion runtime/databricks/automl_runtime/version.py
@@ -14,4 +14,4 @@
# limitations under the License.
#

__version__ = "0.1.0" # pragma: no cover
__version__ = "0.2.0" # pragma: no cover
1 change: 1 addition & 0 deletions runtime/environment.txt
@@ -3,6 +3,7 @@
# * Keep dependencies sorted.

holidays==0.10.5.2
hyperopt==0.2.5
koalas==1.8.1
mlflow==1.20.1
numpy==1.19.2
1 change: 1 addition & 0 deletions runtime/requirements.txt
@@ -5,6 +5,7 @@
# * Keep dependencies sorted.

holidays
hyperopt
mlflow
numpy
pandas
53 changes: 53 additions & 0 deletions runtime/tests/automl_runtime/prophet/forecast_test.py
@@ -0,0 +1,53 @@
#
# Copyright (C) 2021 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
import pandas as pd
from hyperopt import hp

from databricks.automl_runtime.forecast.prophet.forecast import ProphetHyperoptEstimator


class TestProphetHyperoptEstimator(unittest.TestCase):

def setUp(self) -> None:
num_rows = 12
self.df = pd.concat([
pd.to_datetime(pd.Series(range(num_rows), name="ds").apply(lambda i: f"2020-07-{i+1}")),
pd.Series(range(num_rows), name="y")
], axis=1)
self.search_space = {"changepoint_prior_scale": hp.loguniform("changepoint_prior_scale", -6.9, -0.69)}

def test_sequential_training(self):
hyperopt_estim = ProphetHyperoptEstimator(horizon=1,
frequency_unit="d",
metric="smape",
interval_width=0.8,
country_holidays="US",
search_space=self.search_space,
num_folds=2,
trial_timeout=1000,
random_state=0,
is_parallel=False)

results = hyperopt_estim.fit(self.df)
self.assertAlmostEqual(results["mse"][0], 0)
self.assertAlmostEqual(results["rmse"][0], 0)
self.assertAlmostEqual(results["mae"][0], 0)
self.assertAlmostEqual(results["mape"][0], 0)
self.assertAlmostEqual(results["mdape"][0], 0)
self.assertAlmostEqual(results["smape"][0], 0)
self.assertAlmostEqual(results["coverage"][0], 1)
2 changes: 1 addition & 1 deletion runtime/tests/automl_runtime/prophet/model_test.py
@@ -60,7 +60,7 @@ def test_model_save_and_load_multi_series(self):
multi_series_model_json = {"1": model_json, "2": model_json}
multi_series_start = {"1": pd.Timestamp("2020-07-01"), "2": pd.Timestamp("2020-07-01")}
prophet_model = MultiSeriesProphetModel(multi_series_model_json, multi_series_start,
"2020-07-25", 1, "d")
"2020-07-25", 1, "days")
with mlflow.start_run() as run:
mlflow_prophet_log_model(prophet_model)

29 changes: 29 additions & 0 deletions runtime/tests/automl_runtime/utils_test.py
@@ -0,0 +1,29 @@
#
# Copyright (C) 2021 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from databricks.automl_runtime.utils import fail_safe_with_default


@fail_safe_with_default(1)
def failed_function():
raise Exception()


class TestUtilFunctions(unittest.TestCase):
def test_failed_functions(self):
result = failed_function()
self.assertEqual(result, 1)
