Skip to content

Commit

Permalink
[ML-44835] Add DeepAR mlflow model logging (#146)
Browse files Browse the repository at this point in the history
* Create mlflow pythonmodel for DeepAR

* Fix predict bugs

* Unit tests

* Remove unused imports

* Add dependencies to environment.txt

* Fix torch version
  • Loading branch information
es94129 authored Sep 25, 2024
1 parent 90a5cc3 commit 0810c44
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 1 deletion.
Empty file.
137 changes: 137 additions & 0 deletions runtime/databricks/automl_runtime/forecast/deepar/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#
# Copyright (C) 2024 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional

import gluonts
import mlflow
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.torch.model.predictor import PyTorchPredictor
from mlflow.utils.environment import _mlflow_conda_env

from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model

DEEPAR_CONDA_ENV = _mlflow_conda_env(
additional_pip_deps=[
f"gluonts[torch]=={gluonts.__version__}",
f"pandas=={pd.__version__}",
]
)


class DeepARModel(ForecastModel):
"""
DeepAR mlflow model wrapper for forecasting.
"""

def __init__(self, model: PyTorchPredictor, horizon: int, num_samples: int,
target_col: str, time_col: str,
id_cols: Optional[List[str]] = None) -> None:
"""
Initialize the DeepAR mlflow Python model wrapper
:param model: DeepAR model
:param horizon: the number of periods to forecast forward
:param num_samples: the number of samples to draw from the distribution
:param target_col: the target column name
:param time_col: the time column name
:param id_cols: the column names of the identity columns for multi-series time series; None for single series
"""

# TODO: combine id_cols in predict() to ts_id when there are multiple id_cols
if id_cols and len(id_cols) > 1:
raise NotImplementedError("Logging multiple id_cols for DeepAR in AutoML are not supported yet")

super().__init__()
self._model = model
self._horizon = horizon
self._num_samples = num_samples
self._target_col = target_col
self._time_col = time_col
self._id_cols = id_cols

@property
def model_env(self):
return DEEPAR_CONDA_ENV

def predict(self,
context: mlflow.pyfunc.model.PythonModelContext,
model_input: pd.DataFrame) -> pd.DataFrame:
"""
Predict the future dataframe given the history dataframe
:param context: A :class:`~PythonModelContext` instance containing artifacts that the model
can use to perform inference.
:param model_input: Input dataframe that contains the history data
:return: predicted pd.DataFrame that starts after the last timestamp in the input dataframe,
and predicts the horizon using the mean of the samples
"""
required_cols = [self._target_col, self._time_col]
if self._id_cols:
required_cols += self._id_cols
self._validate_cols(model_input, required_cols)

forecast_sample_list = self.predict_samples(model_input, num_samples=self._num_samples)

pred_df = pd.concat(
[
forecast.mean_ts.rename('yhat').reset_index().assign(item_id=forecast.item_id)
for forecast in forecast_sample_list
],
ignore_index=True
)

pred_df = pred_df.rename(columns={'index': self._time_col})
if self._id_cols:
pred_df = pred_df.rename(columns={'item_id': self._id_cols[0]})
else:
pred_df = pred_df.drop(columns='item_id')

pred_df[self._time_col] = pred_df[self._time_col].dt.to_timestamp()

return pred_df

def predict_samples(self,
model_input: pd.DataFrame,
num_samples: int = None) -> List[gluonts.model.forecast.SampleForecast]:
"""
Predict the future samples given the history dataframe
:param model_input: Input dataframe that contains the history data
:param num_samples: the number of samples to draw from the distribution
:return: List of SampleForecast, where each SampleForecast contains num_samples sampled forecasts
"""
if num_samples is None:
num_samples = self._num_samples

model_input = model_input.set_index(self._time_col)
if self._id_cols:
test_ds = PandasDataset.from_long_dataframe(model_input, target=self._target_col,
item_id=self._id_cols[0], unchecked=True)
else:
test_ds = PandasDataset(model_input, target=self._target_col)

forecast_iter = self._model.predict(test_ds, num_samples=num_samples)
forecast_sample_list = list(forecast_iter)

return forecast_sample_list


def mlflow_deepar_log_model(deepar_model: DeepARModel,
sample_input: pd.DataFrame = None) -> None:
"""
Log the DeepAR model to mlflow
:param deepar_model: DeepAR mlflow PythonModel wrapper
:param sample_input: sample input Dataframes for model inference
"""
mlflow_forecast_log_model(deepar_model, sample_input)
2 changes: 1 addition & 1 deletion runtime/databricks/automl_runtime/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#

__version__ = "0.2.20.1" # pragma: no cover
__version__ = "0.2.20.2.dev0" # pragma: no cover
3 changes: 3 additions & 0 deletions runtime/environment.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
# * Keep dependencies sorted.

category_encoders==2.6.0
gluonts[torch]==0.15.1
holidays==0.28
hyperopt==0.2.7
lightning==2.0.1
mlflow==2.0.1
numpy==1.23.5
pandas==1.5.3
pmdarima==2.0.3
prophet==1.1.4
pyarrow==8.0.0
scikit-learn==1.1.1
torch==2.0.1
wrapt==1.14.1
3 changes: 3 additions & 0 deletions runtime/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# * Keep dependencies sorted.

category_encoders
gluonts
holidays
hyperopt
mlflow
Expand All @@ -13,4 +14,6 @@ prophet
pyarrow
requests
scikit-learn
torch
lightning
wrapt
1 change: 1 addition & 0 deletions runtime/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"databricks",
"databricks.automl_runtime",
"databricks.automl_runtime.forecast",
"databricks.automl_runtime.forecast.deepar",
"databricks.automl_runtime.forecast.pmdarima",
"databricks.automl_runtime.forecast.prophet",
"databricks.automl_runtime.hyperopt",
Expand Down
Empty file.
154 changes: 154 additions & 0 deletions runtime/tests/automl_runtime/forecast/deepar/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#
# Copyright (C) 2024 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

import mlflow
import pandas as pd
import torch
import torch.nn as nn
from gluonts.dataset.field_names import FieldName
from gluonts.transform import InstanceSplitter, TestSplitSampler
from gluonts.torch.model.predictor import PyTorchPredictor

from databricks.automl_runtime.forecast.deepar.model import (
DeepARModel, mlflow_deepar_log_model,
)


class TestDeepARModel(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
# Adapted from https://github.com/awslabs/gluonts/blob/dev/test/torch/model/test_torch_predictor.py
class RandomNetwork(nn.Module):
def __init__(
self,
prediction_length: int,
context_length: int,
) -> None:
super().__init__()
self.prediction_length = prediction_length
self.context_length = context_length
self.net = nn.Linear(context_length, prediction_length)
torch.nn.init.uniform_(self.net.weight, -1.0, 1.0)

def forward(self, past_target):
out = self.net(past_target.float())
return out.unsqueeze(1)

cls.context_length = 5
cls.prediction_length = 5

cls.pred_net = RandomNetwork(
prediction_length=cls.context_length, context_length=cls.context_length
)

cls.transformation = InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
instance_sampler=TestSplitSampler(),
past_length=cls.context_length,
future_length=cls.prediction_length,
)

cls.model = PyTorchPredictor(
prediction_length=cls.prediction_length,
input_names=["past_target"],
prediction_net=cls.pred_net,
batch_size=16,
input_transform=cls.transformation,
device="cpu",
)

def test_model_save_and_load_single_series(self):
target_col = "sales"
time_col = "date"

deepar_model = DeepARModel(
model=self.model,
horizon=self.prediction_length,
num_samples=1,
target_col=target_col,
time_col=time_col,
)

num_rows = 10
sample_input = pd.concat(
[
pd.to_datetime(
pd.Series(range(num_rows), name=time_col).apply(
lambda i: f"2020-10-{3 * i + 1}"
)
),
pd.Series(range(num_rows), name=target_col),
],
axis=1,
)

with mlflow.start_run() as run:
mlflow_deepar_log_model(deepar_model, sample_input)

run_id = run.info.run_id
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")

pred_df = loaded_model.predict(sample_input)

assert pred_df.columns.tolist() == [time_col, "yhat"]
assert len(pred_df) == self.prediction_length
assert pred_df[time_col].min() > sample_input[time_col].max()

def test_model_save_and_load_multi_series(self):
target_col = "sales"
time_col = "date"
id_col = "store"

deepar_model = DeepARModel(
model=self.model,
horizon=self.prediction_length,
num_samples=1,
target_col=target_col,
time_col=time_col,
id_cols=[id_col],
)

num_rows = 10
sample_input_base = pd.concat(
[
pd.to_datetime(
pd.Series(range(num_rows), name=time_col).apply(
lambda i: f"2020-10-{3 * i + 1}"
)
),
pd.Series(range(num_rows), name=target_col),
],
axis=1,
)
sample_input = pd.concat([sample_input_base.copy(), sample_input_base.copy()], ignore_index=True)
sample_input[id_col] = [1] * num_rows + [2] * num_rows

with mlflow.start_run() as run:
mlflow_deepar_log_model(deepar_model, sample_input)

run_id = run.info.run_id
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")

pred_df = loaded_model.predict(sample_input)

assert pred_df.columns.tolist() == [time_col, "yhat", id_col]
assert len(pred_df) == self.prediction_length * 2
assert pred_df[time_col].min() > sample_input[time_col].max()

0 comments on commit 0810c44

Please sign in to comment.