Take average to aggregate duplicate time_col for DeepAR (#155)
* Take average to aggregate duplicate time_col for DeepAR

* nit

* clean up unused import

* nit in test

* fix

* bump 0.2.20.5.dev1 -> 0.2.20.5
es94129 authored Nov 18, 2024
1 parent de71cd8 commit a65242b
Showing 3 changed files with 63 additions and 9 deletions.
7 changes: 7 additions & 0 deletions runtime/databricks/automl_runtime/forecast/deepar/model.py
@@ -119,6 +119,13 @@ def predict_samples(self,
         if num_samples is None:
             num_samples = self._num_samples
 
+        # Group by the time column in case there are multiple rows for each time column,
+        # for example, the user didn't provide all the identity columns for a multi-series dataset
+        group_cols = [self._time_col]
+        if self._id_cols:
+            group_cols += self._id_cols
+        model_input = model_input.groupby(group_cols).agg({self._target_col: "mean"}).reset_index()
+
         model_input_transformed = set_index_and_fill_missing_time_steps(model_input,
                                                                         self._time_col,
                                                                         self._frequency,
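
For context, here is a minimal sketch, not part of the commit, of what the new aggregation does to an input with duplicate timestamps (plain pandas; column names are illustrative, and the values mirror the new test below):

import pandas as pd

# Duplicate timestamps can appear when identity columns of a multi-series
# dataset are omitted, leaving several rows per timestamp.
model_input = pd.DataFrame({
    "date": pd.to_datetime(["2020-10-01", "2020-10-01",
                            "2020-10-04", "2020-10-04", "2020-10-04",
                            "2020-10-07"]),
    "sales": [10, 20, 30, 60, 90, 100],
})

# Same pattern as the added lines in predict_samples (no id columns here)
group_cols = ["date"]
deduped = model_input.groupby(group_cols).agg({"sales": "mean"}).reset_index()
print(deduped["sales"].tolist())  # [15.0, 60.0, 100.0] -- one averaged row per timestamp
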
2 changes: 1 addition & 1 deletion runtime/databricks/automl_runtime/version.py
@@ -14,4 +14,4 @@
 # limitations under the License.
 #
 
-__version__ = "0.2.20.5.dev0" # pragma: no cover
+__version__ = "0.2.20.5" # pragma: no cover
63 changes: 55 additions & 8 deletions runtime/tests/automl_runtime/forecast/deepar/model_test.py
@@ -77,6 +77,15 @@ def forward(self, past_target):
             device="cpu",
         )
 
+    def _check_requirements(self, run_id: str):
+        # read requirements.txt from the run
+        requirements_path = mlflow.artifacts.download_artifacts(f"runs:/{run_id}/model/requirements.txt")
+        with open(requirements_path, "r") as f:
+            requirements = f.read()
+        # check if all additional dependencies are logged
+        for dependency in DEEPAR_ADDITIONAL_PIP_DEPS:
+            self.assertIn(dependency, requirements, f"requirements.txt should contain {dependency} but got {requirements}")
+
     def test_model_save_and_load_single_series(self):
         target_col = "sales"
         time_col = "date"
@@ -210,11 +219,49 @@ def test_model_save_and_load_multi_series_multi_id_cols(self):
         self.assertEqual(len(pred_df), self.prediction_length * 4)
         self.assertGreater(pred_df[time_col].min(), sample_input[time_col].max())
 
-    def _check_requirements(self, run_id: str):
-        # read requirements.txt from the run
-        requirements_path = mlflow.artifacts.download_artifacts(f"runs:/{run_id}/model/requirements.txt")
-        with open(requirements_path, "r") as f:
-            requirements = f.read()
-        # check if all additional dependencies are logged
-        for dependency in DEEPAR_ADDITIONAL_PIP_DEPS:
-            self.assertIn(dependency, requirements, f"requirements.txt should contain {dependency} but got {requirements}")
+    def test_model_prediction_with_duplicate_timestamps(self):
+        """
+        Test that the model correctly handles and averages multiple rows with the same timestamp
+        when identity columns are not provided.
+        """
+        target_col = "sales"
+        time_col = "date"
+
+        deepar_model = DeepARModel(
+            model=self.model,
+            horizon=self.prediction_length,
+            frequency="d",
+            num_samples=1,
+            target_col=target_col,
+            time_col=time_col,
+        )
+
+        # Create sample input with duplicate timestamps
+        dates = pd.to_datetime([
+            "2020-10-01", "2020-10-01",  # duplicate date with different values
+            "2020-10-04", "2020-10-04", "2020-10-04",  # triple duplicate
+            "2020-10-07"  # single entry
+        ])
+
+        sales = [10, 20,  # should average to 15
+                 30, 60, 90,  # should average to 60
+                 100]  # single value stays 100
+
+        sample_input = pd.DataFrame({
+            time_col: dates,
+            target_col: sales
+        })
+
+        with mlflow.start_run() as run:
+            mlflow_deepar_log_model(deepar_model, sample_input)
+
+        run_id = run.info.run_id
+
+        # Load the model and predict
+        loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
+        pred_df = loaded_model.predict(sample_input)
+
+        # Verify the prediction output format
+        self.assertEqual(pred_df.columns.tolist(), [time_col, "yhat"])
+        self.assertEqual(len(pred_df), self.prediction_length)
+        self.assertGreater(pred_df[time_col].min(), sample_input[time_col].max())
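
As a usage note, the duplicates this test exercises typically arise when an id column is dropped from a multi-series frame. A small sketch, with a hypothetical "store" id column:

import pandas as pd

# Two series (stores A and B) observed on the same date
multi = pd.DataFrame({
    "date": pd.to_datetime(["2020-10-01", "2020-10-01"]),
    "store": ["A", "B"],
    "sales": [10, 20],
})

# Dropping the id column leaves two rows for 2020-10-01 ...
single = multi.drop(columns="store")

# ... which predict_samples now collapses to their mean before forecasting
deduped = single.groupby("date").agg({"sales": "mean"}).reset_index()
print(deduped)  # one row: 2020-10-01 with sales 15.0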
