From a65242bf0f3936730c48c17d8f0ec9eff172b9c4 Mon Sep 17 00:00:00 2001 From: Ying Chen Date: Mon, 18 Nov 2024 09:57:56 -0800 Subject: [PATCH] Take average to aggregate duplicate time_col for DeepAR (#155) * Take average to aggregate duplicate time_col for DeepAR * nit * clean up unused import * nit in test * fix * bump 0.2.20.5.dev1 -> 0.2.20.5 --- .../automl_runtime/forecast/deepar/model.py | 7 +++ runtime/databricks/automl_runtime/version.py | 2 +- .../forecast/deepar/model_test.py | 63 ++++++++++++++++--- 3 files changed, 63 insertions(+), 9 deletions(-) diff --git a/runtime/databricks/automl_runtime/forecast/deepar/model.py b/runtime/databricks/automl_runtime/forecast/deepar/model.py index c92aa8b..4ce75c9 100644 --- a/runtime/databricks/automl_runtime/forecast/deepar/model.py +++ b/runtime/databricks/automl_runtime/forecast/deepar/model.py @@ -119,6 +119,13 @@ def predict_samples(self, if num_samples is None: num_samples = self._num_samples + # Group by the time column (plus any identity columns) in case several rows share the same timestamp, + # for example, when the user didn't provide all the identity columns for a multi-series dataset + group_cols = [self._time_col] + if self._id_cols: + group_cols += self._id_cols + model_input = model_input.groupby(group_cols).agg({self._target_col: "mean"}).reset_index() + model_input_transformed = set_index_and_fill_missing_time_steps(model_input, self._time_col, self._frequency, diff --git a/runtime/databricks/automl_runtime/version.py b/runtime/databricks/automl_runtime/version.py index 67b3bb3..d1e4e23 100644 --- a/runtime/databricks/automl_runtime/version.py +++ b/runtime/databricks/automl_runtime/version.py @@ -14,4 +14,4 @@ # limitations under the License.
# -__version__ = "0.2.20.5.dev0" # pragma: no cover +__version__ = "0.2.20.5" # pragma: no cover diff --git a/runtime/tests/automl_runtime/forecast/deepar/model_test.py b/runtime/tests/automl_runtime/forecast/deepar/model_test.py index 1d3ed3e..d534d9b 100644 --- a/runtime/tests/automl_runtime/forecast/deepar/model_test.py +++ b/runtime/tests/automl_runtime/forecast/deepar/model_test.py @@ -77,6 +77,15 @@ def forward(self, past_target): device="cpu", ) + def _check_requirements(self, run_id: str): + # read requirements.txt from the run + requirements_path = mlflow.artifacts.download_artifacts(f"runs:/{run_id}/model/requirements.txt") + with open(requirements_path, "r") as f: + requirements = f.read() + # check if all additional dependencies are logged + for dependency in DEEPAR_ADDITIONAL_PIP_DEPS: + self.assertIn(dependency, requirements, f"requirements.txt should contain {dependency} but got {requirements}") + def test_model_save_and_load_single_series(self): target_col = "sales" time_col = "date" @@ -210,11 +219,49 @@ def test_model_save_and_load_multi_series_multi_id_cols(self): self.assertEqual(len(pred_df), self.prediction_length * 4) self.assertGreater(pred_df[time_col].min(), sample_input[time_col].max()) - def _check_requirements(self, run_id: str): - # read requirements.txt from the run - requirements_path = mlflow.artifacts.download_artifacts(f"runs:/{run_id}/model/requirements.txt") - with open(requirements_path, "r") as f: - requirements = f.read() - # check if all additional dependencies are logged - for dependency in DEEPAR_ADDITIONAL_PIP_DEPS: - self.assertIn(dependency, requirements, f"requirements.txt should contain {dependency} but got {requirements}") + def test_model_prediction_with_duplicate_timestamps(self): + """ + Test that the model correctly handles and averages multiple rows with the same timestamp + when identity columns are not provided. 
+ """ + target_col = "sales" + time_col = "date" + + deepar_model = DeepARModel( + model=self.model, + horizon=self.prediction_length, + frequency="d", + num_samples=1, + target_col=target_col, + time_col=time_col, + ) + + # Create sample input with duplicate timestamps + dates = pd.to_datetime([ + "2020-10-01", "2020-10-01", # duplicate date with different values + "2020-10-04", "2020-10-04", "2020-10-04", # triple duplicate + "2020-10-07" # single entry + ]) + + sales = [10, 20, # should average to 15 + 30, 60, 90, # should average to 60 + 100] # single value stays 100 + + sample_input = pd.DataFrame({ + time_col: dates, + target_col: sales + }) + + with mlflow.start_run() as run: + mlflow_deepar_log_model(deepar_model, sample_input) + + run_id = run.info.run_id + + # Load the model and predict + loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") + pred_df = loaded_model.predict(sample_input) + + # Verify the prediction output format + self.assertEqual(pred_df.columns.tolist(), [time_col, "yhat"]) + self.assertEqual(len(pred_df), self.prediction_length) + self.assertGreater(pred_df[time_col].min(), sample_input[time_col].max())