From 84bf5b7300e54fc8a7735379869eae36fd942792 Mon Sep 17 00:00:00 2001
From: Matt Dancho
Date: Fri, 20 Oct 2023 22:23:30 -0400
Subject: [PATCH] Check data was anomalized

---
 src/pytimetk/plot/plot_anomalies.py |  5 ++++-
 src/pytimetk/utils/checks.py        | 42 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/src/pytimetk/plot/plot_anomalies.py b/src/pytimetk/plot/plot_anomalies.py
index c7e44e25..08e646eb 100644
--- a/src/pytimetk/plot/plot_anomalies.py
+++ b/src/pytimetk/plot/plot_anomalies.py
@@ -8,7 +8,7 @@ from pytimetk.plot.plot_timeseries import plot_timeseries
 
 from pytimetk.utils.plot_helpers import hex_to_rgba, rgba_to_hex, parse_rgba
 
-from pytimetk.utils.checks import check_dataframe_or_groupby, check_date_column
+from pytimetk.utils.checks import check_dataframe_or_groupby, check_date_column, check_anomalize_data
 
 from pytimetk.plot.theme import theme_timetk
 
@@ -247,6 +247,9 @@ def plot_anomalies(
     check_dataframe_or_groupby(data)
     check_date_column(data, date_column)
 
+    # Check data was anomalized first
+    check_anomalize_data(data)
+
     # Handle line_size
     if line_size is None:
         if engine == 'plotnine':
diff --git a/src/pytimetk/utils/checks.py b/src/pytimetk/utils/checks.py
index 9d8bd38a..45cc2cf7 100644
--- a/src/pytimetk/utils/checks.py
+++ b/src/pytimetk/utils/checks.py
@@ -6,6 +6,48 @@ from typing import Union, List
 
 
 
+def check_anomalize_data(data: Union[pd.DataFrame, pd.core.groupby.generic.DataFrameGroupBy]) -> None:
+    """
+    Check that `data` contains every column expected from `anomalize()`.
+
+    Parameters
+    ----------
+    data : pd.DataFrame or pd.core.groupby.generic.DataFrameGroupBy
+        Data to validate. For a grouped object, the columns of the
+        underlying DataFrame (`data.obj`) are checked.
+
+    Raises
+    ------
+    ValueError
+        If one or more of the expected columns are missing.
+    """
+
+    if isinstance(data, pd.core.groupby.generic.DataFrameGroupBy):
+        data = data.obj
+
+    expected_colnames = [
+        'observed',
+        'seasonal',
+        'seasadj',
+        'trend',
+        'remainder',
+        'anomaly',
+        'anomaly_score',
+        'anomaly_direction',
+        'recomposed_l1',
+        'recomposed_l2',
+        'observed_clean',
+    ]
+
+    # Collect the absent columns so the error can name them explicitly
+    missing_colnames = [column for column in expected_colnames if column not in data.columns]
+
+    if missing_colnames:
+        raise ValueError(
+            f"data does not have required colnames: {expected_colnames}. "
+            f"Missing: {missing_colnames}. \nDid you run `anomalize()`?"
+        )
+
 def check_dataframe_or_groupby(data: Union[pd.DataFrame, pd.core.groupby.generic.DataFrameGroupBy]) -> None:
 
     if not isinstance(data, pd.DataFrame):