#
# Copyright (C) 2024 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional

import gluonts
import mlflow
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.torch.model.predictor import PyTorchPredictor
from mlflow.utils.environment import _mlflow_conda_env

from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model

# Conda environment logged with the model; pins gluonts and pandas to the
# versions that are installed when this module is imported.
DEEPAR_CONDA_ENV = _mlflow_conda_env(
    additional_pip_deps=[
        f"gluonts[torch]=={gluonts.__version__}",
        f"pandas=={pd.__version__}",
    ]
)


class DeepARModel(ForecastModel):
    """
    DeepAR mlflow model wrapper for forecasting.
    """

    def __init__(self, model: PyTorchPredictor, horizon: int, num_samples: int,
                 target_col: str, time_col: str,
                 id_cols: Optional[List[str]] = None) -> None:
        """
        Initialize the DeepAR mlflow Python model wrapper
        :param model: DeepAR model
        :param horizon: the number of periods to forecast forward
        :param num_samples: the number of samples to draw from the distribution
        :param target_col: the target column name
        :param time_col: the time column name
        :param id_cols: the column names of the identity columns for multi-series time series; None for single series
        :raises NotImplementedError: if more than one identity column is given
        """

        # TODO: combine id_cols in predict() to ts_id when there are multiple id_cols
        if id_cols and len(id_cols) > 1:
            raise NotImplementedError("Logging multiple id_cols for DeepAR in AutoML are not supported yet")

        super().__init__()
        self._model = model
        self._horizon = horizon
        self._num_samples = num_samples
        self._target_col = target_col
        self._time_col = time_col
        # Copy so that later mutation of the caller's list cannot change this model.
        self._id_cols = list(id_cols) if id_cols is not None else None

    @property
    def model_env(self):
        """Return the conda environment definition used when logging this model."""
        return DEEPAR_CONDA_ENV

    def _validate_input(self, model_input: pd.DataFrame) -> None:
        """Check that the target, time and (if configured) identity columns are present."""
        required_cols = [self._target_col, self._time_col]
        if self._id_cols:
            required_cols += self._id_cols
        self._validate_cols(model_input, required_cols)

    def predict(self,
                context: mlflow.pyfunc.model.PythonModelContext,
                model_input: pd.DataFrame) -> pd.DataFrame:
        """
        Predict the future dataframe given the history dataframe
        :param context: A :class:`~PythonModelContext` instance containing artifacts that the model
                        can use to perform inference.
        :param model_input: Input dataframe that contains the history data
        :return: predicted pd.DataFrame that starts after the last timestamp in the input dataframe,
                 and predicts the horizon using the mean of the samples
        """
        # predict_samples validates the required columns before doing any work.
        forecast_sample_list = self.predict_samples(model_input, num_samples=self._num_samples)

        # One frame per forecast: mean forecast as 'yhat', tagged with its item id.
        per_series_frames = [
            forecast.mean_ts.rename('yhat').reset_index().assign(item_id=forecast.item_id)
            for forecast in forecast_sample_list
        ]
        pred_df = pd.concat(per_series_frames, ignore_index=True)

        # reset_index() names the former index column 'index'.
        pred_df = pred_df.rename(columns={'index': self._time_col})
        if self._id_cols:
            pred_df = pred_df.rename(columns={'item_id': self._id_cols[0]})
        else:
            # Single series: the item id carries no information.
            pred_df = pred_df.drop(columns='item_id')

        # Convert the time column from periods to timestamps.
        pred_df[self._time_col] = pred_df[self._time_col].dt.to_timestamp()

        return pred_df

    def predict_samples(self,
                        model_input: pd.DataFrame,
                        num_samples: Optional[int] = None) -> List[gluonts.model.forecast.SampleForecast]:
        """
        Predict the future samples given the history dataframe
        :param model_input: Input dataframe that contains the history data
        :param num_samples: the number of samples to draw from the distribution;
                            None to use the value given at construction time
        :return: List of SampleForecast, where each SampleForecast contains num_samples sampled forecasts
        """
        # Validate here (not only in predict) so that direct callers also get a
        # clear error for missing columns instead of a failure inside set_index/gluonts.
        self._validate_input(model_input)

        if num_samples is None:
            num_samples = self._num_samples

        # set_index returns a new frame, so the caller's dataframe is left untouched.
        indexed_input = model_input.set_index(self._time_col)
        if self._id_cols:
            test_ds = PandasDataset.from_long_dataframe(indexed_input, target=self._target_col,
                                                        item_id=self._id_cols[0], unchecked=True)
        else:
            test_ds = PandasDataset(indexed_input, target=self._target_col)

        # The predictor's result is iterated once and materialised into a list.
        return list(self._model.predict(test_ds, num_samples=num_samples))


def mlflow_deepar_log_model(deepar_model: DeepARModel,
                            sample_input: Optional[pd.DataFrame] = None) -> None:
    """
    Log the DeepAR model to mlflow
    :param deepar_model: DeepAR mlflow PythonModel wrapper
    :param sample_input: sample input Dataframes for model inference
    """
    mlflow_forecast_log_model(deepar_model, sample_input)
category_encoders +gluonts holidays hyperopt mlflow @@ -13,4 +14,6 @@ prophet pyarrow requests scikit-learn +torch +lightning wrapt diff --git a/runtime/setup.py b/runtime/setup.py index d379e94e..17f162c7 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -46,6 +46,7 @@ "databricks", "databricks.automl_runtime", "databricks.automl_runtime.forecast", + "databricks.automl_runtime.forecast.deepar", "databricks.automl_runtime.forecast.pmdarima", "databricks.automl_runtime.forecast.prophet", "databricks.automl_runtime.hyperopt", diff --git a/runtime/tests/automl_runtime/forecast/deepar/__init__.py b/runtime/tests/automl_runtime/forecast/deepar/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runtime/tests/automl_runtime/forecast/deepar/model_test.py b/runtime/tests/automl_runtime/forecast/deepar/model_test.py new file mode 100644 index 00000000..f896688a --- /dev/null +++ b/runtime/tests/automl_runtime/forecast/deepar/model_test.py @@ -0,0 +1,154 @@ +# +# Copyright (C) 2024 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

import unittest

import mlflow
import pandas as pd
import torch
import torch.nn as nn
from gluonts.dataset.field_names import FieldName
from gluonts.transform import InstanceSplitter, TestSplitSampler
from gluonts.torch.model.predictor import PyTorchPredictor

from databricks.automl_runtime.forecast.deepar.model import (
    DeepARModel, mlflow_deepar_log_model,
)


class TestDeepARModel(unittest.TestCase):
    """Round-trip tests: log a DeepARModel to mlflow, load it back and predict."""

    @classmethod
    def setUpClass(cls) -> None:
        # Adapted from https://github.com/awslabs/gluonts/blob/dev/test/torch/model/test_torch_predictor.py
        class RandomNetwork(nn.Module):
            """Single linear layer with uniformly initialised weights, used as the prediction network."""

            def __init__(
                self,
                prediction_length: int,
                context_length: int,
            ) -> None:
                super().__init__()
                self.prediction_length = prediction_length
                self.context_length = context_length
                self.net = nn.Linear(context_length, prediction_length)
                torch.nn.init.uniform_(self.net.weight, -1.0, 1.0)

            def forward(self, past_target):
                out = self.net(past_target.float())
                # Insert an extra dimension at position 1 of the network output.
                return out.unsqueeze(1)

        cls.context_length = 5
        cls.prediction_length = 5

        # Pass the prediction length explicitly; the two lengths only happen to be equal.
        cls.pred_net = RandomNetwork(
            prediction_length=cls.prediction_length, context_length=cls.context_length
        )

        cls.transformation = InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=TestSplitSampler(),
            past_length=cls.context_length,
            future_length=cls.prediction_length,
        )

        cls.model = PyTorchPredictor(
            prediction_length=cls.prediction_length,
            input_names=["past_target"],
            prediction_net=cls.pred_net,
            batch_size=16,
            input_transform=cls.transformation,
            device="cpu",
        )

    @staticmethod
    def _make_sample_input(time_col: str, target_col: str, num_rows: int = 10) -> pd.DataFrame:
        """Build a single-series frame with one row every 3 days starting 2020-10-01."""
        return pd.DataFrame(
            {
                time_col: pd.date_range("2020-10-01", periods=num_rows, freq="3D"),
                target_col: range(num_rows),
            }
        )

    @staticmethod
    def _log_and_load(deepar_model: DeepARModel, sample_input: pd.DataFrame):
        """Log the model in a fresh mlflow run and load it back as a pyfunc model."""
        with mlflow.start_run() as run:
            mlflow_deepar_log_model(deepar_model, sample_input)

        run_id = run.info.run_id
        return mlflow.pyfunc.load_model(f"runs:/{run_id}/model")

    def test_model_save_and_load_single_series(self):
        target_col = "sales"
        time_col = "date"

        deepar_model = DeepARModel(
            model=self.model,
            horizon=self.prediction_length,
            num_samples=1,
            target_col=target_col,
            time_col=time_col,
        )

        sample_input = self._make_sample_input(time_col, target_col)
        loaded_model = self._log_and_load(deepar_model, sample_input)

        pred_df = loaded_model.predict(sample_input)

        self.assertEqual(pred_df.columns.tolist(), [time_col, "yhat"])
        self.assertEqual(len(pred_df), self.prediction_length)
        # Every predicted timestamp must lie after the end of the history.
        self.assertGreater(pred_df[time_col].min(), sample_input[time_col].max())

    def test_model_save_and_load_multi_series(self):
        target_col = "sales"
        time_col = "date"
        id_col = "store"

        deepar_model = DeepARModel(
            model=self.model,
            horizon=self.prediction_length,
            num_samples=1,
            target_col=target_col,
            time_col=time_col,
            id_cols=[id_col],
        )

        num_rows = 10
        sample_input_base = self._make_sample_input(time_col, target_col, num_rows)
        # Two identical series distinguished only by the id column.
        sample_input = pd.concat([sample_input_base.copy(), sample_input_base.copy()], ignore_index=True)
        sample_input[id_col] = [1] * num_rows + [2] * num_rows

        loaded_model = self._log_and_load(deepar_model, sample_input)

        pred_df = loaded_model.predict(sample_input)

        self.assertEqual(pred_df.columns.tolist(), [time_col, "yhat", id_col])
        self.assertEqual(len(pred_df), self.prediction_length * 2)
        self.assertGreater(pred_df[time_col].min(), sample_input[time_col].max())