-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML-44835] Add DeepAR mlflow model logging (#146)
* Create mlflow pythonmodel for DeepAR * Fix predict bugs * Unit tests * Remove unused imports * Add dependencies to environment.txt * Fix torch version
- Loading branch information
Showing
8 changed files
with
299 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
137 changes: 137 additions & 0 deletions
137
runtime/databricks/automl_runtime/forecast/deepar/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# | ||
# Copyright (C) 2024 Databricks, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from typing import List, Optional | ||
|
||
import gluonts | ||
import mlflow | ||
import pandas as pd | ||
from gluonts.dataset.pandas import PandasDataset | ||
from gluonts.torch.model.predictor import PyTorchPredictor | ||
from mlflow.utils.environment import _mlflow_conda_env | ||
|
||
from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model | ||
|
||
DEEPAR_CONDA_ENV = _mlflow_conda_env( | ||
additional_pip_deps=[ | ||
f"gluonts[torch]=={gluonts.__version__}", | ||
f"pandas=={pd.__version__}", | ||
] | ||
) | ||
|
||
|
||
class DeepARModel(ForecastModel): | ||
""" | ||
DeepAR mlflow model wrapper for forecasting. | ||
""" | ||
|
||
def __init__(self, model: PyTorchPredictor, horizon: int, num_samples: int, | ||
target_col: str, time_col: str, | ||
id_cols: Optional[List[str]] = None) -> None: | ||
""" | ||
Initialize the DeepAR mlflow Python model wrapper | ||
:param model: DeepAR model | ||
:param horizon: the number of periods to forecast forward | ||
:param num_samples: the number of samples to draw from the distribution | ||
:param target_col: the target column name | ||
:param time_col: the time column name | ||
:param id_cols: the column names of the identity columns for multi-series time series; None for single series | ||
""" | ||
|
||
# TODO: combine id_cols in predict() to ts_id when there are multiple id_cols | ||
if id_cols and len(id_cols) > 1: | ||
raise NotImplementedError("Logging multiple id_cols for DeepAR in AutoML are not supported yet") | ||
|
||
super().__init__() | ||
self._model = model | ||
self._horizon = horizon | ||
self._num_samples = num_samples | ||
self._target_col = target_col | ||
self._time_col = time_col | ||
self._id_cols = id_cols | ||
|
||
@property | ||
def model_env(self): | ||
return DEEPAR_CONDA_ENV | ||
|
||
def predict(self, | ||
context: mlflow.pyfunc.model.PythonModelContext, | ||
model_input: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Predict the future dataframe given the history dataframe | ||
:param context: A :class:`~PythonModelContext` instance containing artifacts that the model | ||
can use to perform inference. | ||
:param model_input: Input dataframe that contains the history data | ||
:return: predicted pd.DataFrame that starts after the last timestamp in the input dataframe, | ||
and predicts the horizon using the mean of the samples | ||
""" | ||
required_cols = [self._target_col, self._time_col] | ||
if self._id_cols: | ||
required_cols += self._id_cols | ||
self._validate_cols(model_input, required_cols) | ||
|
||
forecast_sample_list = self.predict_samples(model_input, num_samples=self._num_samples) | ||
|
||
pred_df = pd.concat( | ||
[ | ||
forecast.mean_ts.rename('yhat').reset_index().assign(item_id=forecast.item_id) | ||
for forecast in forecast_sample_list | ||
], | ||
ignore_index=True | ||
) | ||
|
||
pred_df = pred_df.rename(columns={'index': self._time_col}) | ||
if self._id_cols: | ||
pred_df = pred_df.rename(columns={'item_id': self._id_cols[0]}) | ||
else: | ||
pred_df = pred_df.drop(columns='item_id') | ||
|
||
pred_df[self._time_col] = pred_df[self._time_col].dt.to_timestamp() | ||
|
||
return pred_df | ||
|
||
def predict_samples(self, | ||
model_input: pd.DataFrame, | ||
num_samples: int = None) -> List[gluonts.model.forecast.SampleForecast]: | ||
""" | ||
Predict the future samples given the history dataframe | ||
:param model_input: Input dataframe that contains the history data | ||
:param num_samples: the number of samples to draw from the distribution | ||
:return: List of SampleForecast, where each SampleForecast contains num_samples sampled forecasts | ||
""" | ||
if num_samples is None: | ||
num_samples = self._num_samples | ||
|
||
model_input = model_input.set_index(self._time_col) | ||
if self._id_cols: | ||
test_ds = PandasDataset.from_long_dataframe(model_input, target=self._target_col, | ||
item_id=self._id_cols[0], unchecked=True) | ||
else: | ||
test_ds = PandasDataset(model_input, target=self._target_col) | ||
|
||
forecast_iter = self._model.predict(test_ds, num_samples=num_samples) | ||
forecast_sample_list = list(forecast_iter) | ||
|
||
return forecast_sample_list | ||
|
||
|
||
def mlflow_deepar_log_model(deepar_model: DeepARModel, | ||
sample_input: pd.DataFrame = None) -> None: | ||
""" | ||
Log the DeepAR model to mlflow | ||
:param deepar_model: DeepAR mlflow PythonModel wrapper | ||
:param sample_input: sample input Dataframes for model inference | ||
""" | ||
mlflow_forecast_log_model(deepar_model, sample_input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
154 changes: 154 additions & 0 deletions
154
runtime/tests/automl_runtime/forecast/deepar/model_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# | ||
# Copyright (C) 2024 Databricks, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import unittest | ||
|
||
import mlflow | ||
import pandas as pd | ||
import torch | ||
import torch.nn as nn | ||
from gluonts.dataset.field_names import FieldName | ||
from gluonts.transform import InstanceSplitter, TestSplitSampler | ||
from gluonts.torch.model.predictor import PyTorchPredictor | ||
|
||
from databricks.automl_runtime.forecast.deepar.model import ( | ||
DeepARModel, mlflow_deepar_log_model, | ||
) | ||
|
||
|
||
class TestDeepARModel(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls) -> None: | ||
# Adapted from https://github.com/awslabs/gluonts/blob/dev/test/torch/model/test_torch_predictor.py | ||
class RandomNetwork(nn.Module): | ||
def __init__( | ||
self, | ||
prediction_length: int, | ||
context_length: int, | ||
) -> None: | ||
super().__init__() | ||
self.prediction_length = prediction_length | ||
self.context_length = context_length | ||
self.net = nn.Linear(context_length, prediction_length) | ||
torch.nn.init.uniform_(self.net.weight, -1.0, 1.0) | ||
|
||
def forward(self, past_target): | ||
out = self.net(past_target.float()) | ||
return out.unsqueeze(1) | ||
|
||
cls.context_length = 5 | ||
cls.prediction_length = 5 | ||
|
||
cls.pred_net = RandomNetwork( | ||
prediction_length=cls.context_length, context_length=cls.context_length | ||
) | ||
|
||
cls.transformation = InstanceSplitter( | ||
target_field=FieldName.TARGET, | ||
is_pad_field=FieldName.IS_PAD, | ||
start_field=FieldName.START, | ||
forecast_start_field=FieldName.FORECAST_START, | ||
instance_sampler=TestSplitSampler(), | ||
past_length=cls.context_length, | ||
future_length=cls.prediction_length, | ||
) | ||
|
||
cls.model = PyTorchPredictor( | ||
prediction_length=cls.prediction_length, | ||
input_names=["past_target"], | ||
prediction_net=cls.pred_net, | ||
batch_size=16, | ||
input_transform=cls.transformation, | ||
device="cpu", | ||
) | ||
|
||
def test_model_save_and_load_single_series(self): | ||
target_col = "sales" | ||
time_col = "date" | ||
|
||
deepar_model = DeepARModel( | ||
model=self.model, | ||
horizon=self.prediction_length, | ||
num_samples=1, | ||
target_col=target_col, | ||
time_col=time_col, | ||
) | ||
|
||
num_rows = 10 | ||
sample_input = pd.concat( | ||
[ | ||
pd.to_datetime( | ||
pd.Series(range(num_rows), name=time_col).apply( | ||
lambda i: f"2020-10-{3 * i + 1}" | ||
) | ||
), | ||
pd.Series(range(num_rows), name=target_col), | ||
], | ||
axis=1, | ||
) | ||
|
||
with mlflow.start_run() as run: | ||
mlflow_deepar_log_model(deepar_model, sample_input) | ||
|
||
run_id = run.info.run_id | ||
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") | ||
|
||
pred_df = loaded_model.predict(sample_input) | ||
|
||
assert pred_df.columns.tolist() == [time_col, "yhat"] | ||
assert len(pred_df) == self.prediction_length | ||
assert pred_df[time_col].min() > sample_input[time_col].max() | ||
|
||
def test_model_save_and_load_multi_series(self): | ||
target_col = "sales" | ||
time_col = "date" | ||
id_col = "store" | ||
|
||
deepar_model = DeepARModel( | ||
model=self.model, | ||
horizon=self.prediction_length, | ||
num_samples=1, | ||
target_col=target_col, | ||
time_col=time_col, | ||
id_cols=[id_col], | ||
) | ||
|
||
num_rows = 10 | ||
sample_input_base = pd.concat( | ||
[ | ||
pd.to_datetime( | ||
pd.Series(range(num_rows), name=time_col).apply( | ||
lambda i: f"2020-10-{3 * i + 1}" | ||
) | ||
), | ||
pd.Series(range(num_rows), name=target_col), | ||
], | ||
axis=1, | ||
) | ||
sample_input = pd.concat([sample_input_base.copy(), sample_input_base.copy()], ignore_index=True) | ||
sample_input[id_col] = [1] * num_rows + [2] * num_rows | ||
|
||
with mlflow.start_run() as run: | ||
mlflow_deepar_log_model(deepar_model, sample_input) | ||
|
||
run_id = run.info.run_id | ||
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") | ||
|
||
pred_df = loaded_model.predict(sample_input) | ||
|
||
assert pred_df.columns.tolist() == [time_col, "yhat", id_col] | ||
assert len(pred_df) == self.prediction_length * 2 | ||
assert pred_df[time_col].min() > sample_input[time_col].max() |