From d3ab8a70e7f3d93de600d682b18d31e934c5608a Mon Sep 17 00:00:00 2001 From: Mads Christian Lund Date: Mon, 6 Nov 2023 13:52:16 +0100 Subject: [PATCH] Fixed python38 support and empty DataFrame test --- .../qc/percentiles/outlier_detector.py | 14 +++--- src/pypromice/test/test_percentile.py | 45 ++++++++++++++++++- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/pypromice/qc/percentiles/outlier_detector.py b/src/pypromice/qc/percentiles/outlier_detector.py index dada58b6..2eba53dd 100644 --- a/src/pypromice/qc/percentiles/outlier_detector.py +++ b/src/pypromice/qc/percentiles/outlier_detector.py @@ -6,9 +6,7 @@ import xarray as xr -__all__ = [ - "ThresholdBasedOutlierDetector" -] +__all__ = ["ThresholdBasedOutlierDetector"] season_month_map = { "winter": {12, 1, 2}, @@ -18,7 +16,7 @@ } -def get_season_index_mask(data_set: pd.DataFrame, season: str) -> np.ndarray[bool]: +def get_season_index_mask(data_set: pd.DataFrame, season: str) -> np.ndarray: season_months = season_month_map.get( season, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} ) @@ -67,7 +65,7 @@ def filter_data(self, ds: xr.Dataset) -> xr.Dataset: stid = ds.station_id stid_thresholds = self.thresholds.query(f"stid == '{stid}'") - if not stid_thresholds.any(): + if stid_thresholds.empty: return ds data_df = ds.to_dataframe() # Switch to pandas @@ -84,7 +82,7 @@ def filter_data(self, ds: xr.Dataset) -> xr.Dataset: return ds_out @classmethod - def from_csv_config(cls, config_file: Path) -> 'ThresholdBasedOutlierDetector': + def from_csv_config(cls, config_file: Path) -> "ThresholdBasedOutlierDetector": """ Instantiate using explicit csv file with explicit thresholds @@ -103,12 +101,12 @@ def from_csv_config(cls, config_file: Path) -> 'ThresholdBasedOutlierDetector': return cls(thresholds=pd.read_csv(config_file)) @classmethod - def default(cls) -> 'ThresholdBasedOutlierDetector': + def default(cls) -> "ThresholdBasedOutlierDetector": """ Instantiate using aws thresholds stored in the python package. Returns ------- """ - default_thresholds_path = Path(__file__).parent.joinpath('thresholds.csv') + default_thresholds_path = Path(__file__).parent.joinpath("thresholds.csv") return cls.from_csv_config(default_thresholds_path) diff --git a/src/pypromice/test/test_percentile.py b/src/pypromice/test/test_percentile.py index 1f95cc11..dfb84560 100644 --- a/src/pypromice/test/test_percentile.py +++ b/src/pypromice/test/test_percentile.py @@ -4,8 +4,13 @@ import numpy as np import pandas as pd +import xarray as xr -from pypromice.qc.percentiles.outlier_detector import detect_outliers, filter_data +from pypromice.qc.percentiles.outlier_detector import ( + detect_outliers, + filter_data, + ThresholdBasedOutlierDetector, +) class PercentileQCTestCase(unittest.TestCase): @@ -184,3 +189,41 @@ def test_remove_outliers(self): output_data = filter_data(input_data, thresholds) self.assertIsNot(output_data, input_data) pd.testing.assert_frame_equal(output_data, expected_output_data) + + +class ThresholdBasedOutlierDetectorTestCase(unittest.TestCase): + def test_default_init(self): + outlier_detector = ThresholdBasedOutlierDetector.default() + self.assertIsInstance(outlier_detector, ThresholdBasedOutlierDetector) + + def test_filter_data_aws_with_threshold(self): + stid = "NUK_K" + index = pd.period_range("2023-10-01", "2023-11-01", freq="1h") + columns = ["p_i", "t_i", "p_l", "wpsd_u", "foo"] + dataset: xr.Dataset = pd.DataFrame( + index=index, + columns=columns, + data=np.random.random((len(index), len(columns))), + ).to_xarray() + dataset = dataset.assign_attrs(dict(station_id=stid)) + outlier_detector = ThresholdBasedOutlierDetector.default() + + dataset_output = outlier_detector.filter_data(dataset) + + self.assertIsInstance(dataset_output, xr.Dataset) + self.assertSetEqual( + set(dict(dataset.items())), + set(dict(dataset_output.items())), + ) + + pass + + def test_filter_data_aws_without_threshold(self): + stid = "non_exsiting" + dataset = xr.Dataset(attrs=dict(station_id=stid)) + outlier_detector = ThresholdBasedOutlierDetector.default() + self.assertNotIn(stid, outlier_detector.thresholds.stid) + + output_dataset = outlier_detector.filter_data(dataset) + + xr.testing.assert_equal(output_dataset, dataset)