Skip to content

Commit

Permalink
Fixed python38 support and empty DataFrame test
Browse files Browse the repository at this point in the history
  • Loading branch information
ladsmund committed Nov 6, 2023
1 parent 475805f commit d3ab8a7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
14 changes: 6 additions & 8 deletions src/pypromice/qc/percentiles/outlier_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import xarray as xr


__all__ = [
"ThresholdBasedOutlierDetector"
]
__all__ = ["ThresholdBasedOutlierDetector"]

season_month_map = {
"winter": {12, 1, 2},
Expand All @@ -18,7 +16,7 @@
}


def get_season_index_mask(data_set: pd.DataFrame, season: str) -> np.ndarray[bool]:
def get_season_index_mask(data_set: pd.DataFrame, season: str) -> np.ndarray:
season_months = season_month_map.get(
season, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
)
Expand Down Expand Up @@ -67,7 +65,7 @@ def filter_data(self, ds: xr.Dataset) -> xr.Dataset:
stid = ds.station_id

stid_thresholds = self.thresholds.query(f"stid == '{stid}'")
if not stid_thresholds.any():
if stid_thresholds.empty:
return ds

data_df = ds.to_dataframe() # Switch to pandas
Expand All @@ -84,7 +82,7 @@ def filter_data(self, ds: xr.Dataset) -> xr.Dataset:
return ds_out

@classmethod
def from_csv_config(cls, config_file: Path) -> 'ThresholdBasedOutlierDetector':
def from_csv_config(cls, config_file: Path) -> "ThresholdBasedOutlierDetector":
"""
Instantiate using explicit csv file with explicit thresholds
Expand All @@ -103,12 +101,12 @@ def from_csv_config(cls, config_file: Path) -> 'ThresholdBasedOutlierDetector':
return cls(thresholds=pd.read_csv(config_file))

@classmethod
def default(cls) -> 'ThresholdBasedOutlierDetector':
def default(cls) -> "ThresholdBasedOutlierDetector":
"""
Instantiate using aws thresholds stored in the python package.
Returns
-------
"""
default_thresholds_path = Path(__file__).parent.joinpath('thresholds.csv')
default_thresholds_path = Path(__file__).parent.joinpath("thresholds.csv")
return cls.from_csv_config(default_thresholds_path)
45 changes: 44 additions & 1 deletion src/pypromice/test/test_percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@

import numpy as np
import pandas as pd
import xarray as xr

from pypromice.qc.percentiles.outlier_detector import detect_outliers, filter_data
from pypromice.qc.percentiles.outlier_detector import (
detect_outliers,
filter_data,
ThresholdBasedOutlierDetector,
)


class PercentileQCTestCase(unittest.TestCase):
Expand Down Expand Up @@ -184,3 +189,41 @@ def test_remove_outliers(self):
output_data = filter_data(input_data, thresholds)
self.assertIsNot(output_data, input_data)
pd.testing.assert_frame_equal(output_data, expected_output_data)


class ThresholdBasedOutlierDetectorTestCase(unittest.TestCase):
def test_default_init(self):
outlier_detector = ThresholdBasedOutlierDetector.default()
self.assertIsInstance(outlier_detector, ThresholdBasedOutlierDetector)

def test_filter_data_aws_with_threshold(self):
stid = "NUK_K"
index = pd.period_range("2023-10-01", "2023-11-01", freq="1h")
columns = ["p_i", "t_i", "p_l", "wpsd_u", "foo"]
dataset: xr.Dataset = pd.DataFrame(
index=index,
columns=columns,
data=np.random.random((len(index), len(columns))),
).to_xarray()
dataset = dataset.assign_attrs(dict(station_id=stid))
outlier_detector = ThresholdBasedOutlierDetector.default()

dataset_output = outlier_detector.filter_data(dataset)

self.assertIsInstance(dataset_output, xr.Dataset)
self.assertSetEqual(
set(dict(dataset.items())),
set(dict(dataset_output.items())),
)

pass

def test_filter_data_aws_without_threshold(self):
stid = "non_exsiting"
dataset = xr.Dataset(attrs=dict(station_id=stid))
outlier_detector = ThresholdBasedOutlierDetector.default()
self.assertNotIn(stid, outlier_detector.thresholds.stid)

output_dataset = outlier_detector.filter_data(dataset)

xr.testing.assert_equal(output_dataset, dataset)

0 comments on commit d3ab8a7

Please sign in to comment.