-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented outlier detector class for handling configuration state
* Implemented core functionality for the season-based threshold filter * Implemented a script to compute thresholds from historical L1 (inferred) data * Added default thresholds * Updated setup.py to include the thresholds data file
- Loading branch information
Showing
7 changed files
with
839 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
import sys | ||
from datetime import datetime | ||
|
||
import pandas as pd | ||
|
||
from pypromice.process import AWS | ||
from pathlib import Path | ||
import logging | ||
from pypromice.qc.github_data_issues import adjustTime, flagNAN, adjustData | ||
|
||
|
||
# %% | ||
logger = logging.getLogger("ComputeThreshold") | ||
|
||
|
||
# %% | ||
def compute_all_thresholds(
    station_thresholds_root: Path,
    thresholds_output_path: Path,
    aws_l0_repo_path: Path,
    start_time: datetime,
    end_time: datetime,
):
    """
    Compute and merge outlier thresholds for every station in an L0 repository.

    Iterates over all station config files (``raw/config/*.toml``) in the L0
    repository, computes a per-station threshold CSV for each station that does
    not already have one, and concatenates all per-station CSVs into a single
    merged output file. Stations that fail are logged and skipped.

    Parameters
    ----------
    station_thresholds_root
        Directory for the per-station threshold CSVs (created if absent).
    thresholds_output_path
        Path of the merged thresholds CSV written at the end.
    aws_l0_repo_path
        Root of the L0 data repository.
    start_time, end_time
        Time window of the historical data used to derive the thresholds.
    """
    logger.info("Computing all thresholds for stations available in the L0 repository")
    logger.info(f"station_thresholds_root: {station_thresholds_root}")
    logger.info(f"thresholds_output_path: {thresholds_output_path}")
    logger.info(f"aws_l0_repo_path: {aws_l0_repo_path}")
    logger.info(f"start_time: {start_time}")
    logger.info(f"end_time: {end_time}")

    station_thresholds_root.mkdir(parents=True, exist_ok=True)

    output_paths = []
    for config_path in aws_l0_repo_path.glob("raw/config/*.toml"):
        stid = config_path.stem

        logger.info(f"Processing {stid}")
        data_path = aws_l0_repo_path.joinpath("raw", stid)
        output_path = station_thresholds_root.joinpath(f"{stid}.csv")
        try:
            if not output_path.exists():
                threshold = find_thresholds(
                    stid,
                    config_path,
                    data_path,
                    start_time,
                    end_time,
                )
                threshold.to_csv(
                    path_or_buf=output_path, index=False, float_format="{:.2f}".format
                )
            # BUG FIX: the append was previously inside the `if not exists`
            # branch, so stations whose CSV already existed were silently
            # excluded from the merged output file.
            output_paths.append(output_path)
        except Exception:
            logger.exception(f"Failed processing {stid}")
            continue

    # BUG FIX: pd.concat raises ValueError on an empty sequence; bail out
    # cleanly when no station produced (or had) a threshold file.
    if not output_paths:
        logger.warning("No station threshold files available; skipping merge")
        return

    logger.info("Merge threshold files")
    pd.concat(pd.read_csv(p) for p in output_paths).to_csv(
        thresholds_output_path, index=False, float_format="{:.2f}".format
    )
|
||
|
||
def find_thresholds(
    stid: str,
    config_path: Path,
    data_path: Path,
    start_time: datetime,
    end_time: datetime,
) -> pd.DataFrame:
    """
    Compute variable thresholds for a station using historical distribution quantiles.

    The station's L0 data is processed to L1, QC'ed with the user-defined
    time-adjustment/flagging files, and the 0.5%/99.5% quantiles (padded by a
    fixed margin) become the lo/hi thresholds. Pressure and wind thresholds
    are season-independent; temperature thresholds are computed per season.

    Parameters
    ----------
    stid
        Station id, written into the output rows and used for logging.
    config_path
        Station L0 config file (toml).
    data_path
        Directory containing the station's raw L0 data.
    start_time, end_time
        Time window of the historical data used to derive the thresholds.

    Returns
    -------
    pd.DataFrame
        Upper and lower thresholds for a set of variables and seasons, with
        columns: stid, variable_pattern, season (temperature rows only), lo, hi.
    """
    stid_logger = logger.getChild(stid)

    stid_logger.info("Read AWS data and get L1")
    aws = AWS(config_file=config_path.as_posix(), inpath=data_path.as_posix())
    aws.getL1()

    stid_logger.info("Apply QC filters on data")
    ds = aws.L1A.copy(deep=True)  # Reassign dataset
    ds = adjustTime(ds)  # Adjust time after a user-defined csv files
    ds = flagNAN(ds)  # Flag NaNs after a user-defined csv files
    ds = adjustData(ds)

    stid_logger.info("Determine thresholds")
    # Season index 0..3: 0 = winter (Dec-Feb), 1 = spring, 2 = summer, 3 = fall.
    df = (
        ds[["rh_u", "wspd_u", "p_u", "t_u"]]
        .to_pandas()
        .loc[start_time:end_time]
        .assign(season=lambda df: (df.index.month // 3) % 4)
    )

    threshold_rows = []

    # Pressure: one all-season threshold, shared by p_u and p_l. The p_i
    # thresholds are shifted by -1000 relative to p_u — presumably an offset
    # in how the instantaneous sensor reports pressure; TODO confirm.
    p_lo, p_hi = df["p_u"].quantile([0.005, 0.995]) + [-12, 12]
    threshold_rows.append(
        dict(
            stid=stid,
            variable_pattern="p_[ul]",
            lo=p_lo,
            hi=p_hi,
        )
    )
    threshold_rows.append(
        dict(
            stid=stid,
            variable_pattern="p_i",
            lo=p_lo - 1000,
            hi=p_hi - 1000,
        )
    )

    # Wind speed: one all-season threshold covering all wspd variants.
    lo, hi = df["wspd_u"].quantile([0.005, 0.995]) + [-12, 12]
    threshold_rows.append(
        dict(
            stid=stid,
            variable_pattern="wspd_[uli]",
            lo=lo,
            hi=hi,
        )
    )

    # Temperature: per-season thresholds. Group on the precomputed "season"
    # column instead of recomputing (df.index.month // 3) % 4 — the grouping
    # is identical by construction of the column above.
    season_map = ["winter", "spring", "summer", "fall"]
    for season_index, season_t in df.groupby("season")["t_u"]:
        lo, hi = season_t.quantile([0.005, 0.995]) + [-9, 9]

        threshold_rows.append(
            dict(
                stid=stid,
                variable_pattern="t_[uli]",
                season=season_map[season_index],
                lo=lo,
                hi=hi,
            )
        )

    threshold = pd.DataFrame(threshold_rows)
    stid_logger.info(threshold)
    return threshold
# %% | ||
|
||
|
||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--l0",
        required=True,
        type=Path,
        help="L0 repository root path",
    )
    parser.add_argument(
        "--thresholds_output_path",
        "-o",
        default=Path(__file__).parent.joinpath("thresholds.csv"),
        type=Path,
        help="Output csv file with thresholds for all stations",
    )
    parser.add_argument(
        "--station_thresholds_root",
        "--str",
        default=Path(__file__).parent.joinpath("station_thresholds"),
        type=Path,
        help="Directory containing threshold files for the individual stations",
    )
    # BUG FIX: argparse applies %-formatting to help strings, so the bare
    # "%Y-%m-%d" previously raised ValueError when --help was rendered;
    # percent signs in help text must be escaped as "%%".
    parser.add_argument(
        "--start_time",
        default="2000-01-01",
        help="Start time for data series. Format: %%Y-%%m-%%d",
    )
    parser.add_argument(
        "--end_time",
        default="2023-10-01",
        help="End time for data series. Format: %%Y-%%m-%%d",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s; %(levelname)s; %(name)s; %(message)s",
        level=logging.INFO,
        stream=sys.stdout,
    )

    compute_all_thresholds(
        station_thresholds_root=args.station_thresholds_root,
        thresholds_output_path=args.thresholds_output_path,
        aws_l0_repo_path=args.l0,
        start_time=datetime.strptime(args.start_time, "%Y-%m-%d"),
        end_time=datetime.strptime(args.end_time, "%Y-%m-%d"),
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from pathlib import Path | ||
|
||
import attrs | ||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
|
||
|
||
# Public API of this module.
__all__ = [
    "ThresholdBasedOutlierDetector"
]

# Meteorological seasons: mapping from season name to the set of calendar
# months it covers. Season labels not present here (including NaN from
# threshold rows without a season) fall back to all twelve months in
# get_season_index_mask below.
season_month_map = {
    "winter": {12, 1, 2},
    "spring": {3, 4, 5},
    "summer": {6, 7, 8},
    "fall": {9, 10, 11},
}
|
||
|
||
def get_season_index_mask(data_set: pd.DataFrame, season: str) -> np.ndarray[bool]:
    """
    Return a boolean column vector of shape (len(data_set), 1) flagging the
    rows of *data_set* whose index month belongs to *season*.

    A season label not found in ``season_month_map`` (including NaN) selects
    every month, i.e. the mask is all-True.
    """
    all_year = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
    months = season_month_map.get(season, all_year)
    row_mask = data_set.index.month.isin(months)
    # Append a trailing axis so the mask broadcasts across DataFrame columns.
    return row_mask[:, None]
|
||
|
||
def detect_outliers(data_set: pd.DataFrame, thresholds: pd.DataFrame) -> pd.DataFrame:
    """
    Build a boolean mask marking samples that violate their thresholds.

    For each distinct ``variable_pattern`` in *thresholds*, the matching
    columns of *data_set* are compared against every (season, lo, hi) row for
    that pattern; a sample is flagged when it is outside [lo, hi] during the
    row's season. The result has one column per matched data column; columns
    matching no pattern are absent from the result.
    """
    # Evaluate each distinct season's row mask once up front rather than
    # once per threshold row.
    season_index_mask = {
        season: get_season_index_mask(data_set, season)
        for season in thresholds["season"].unique()
    }

    masks = []
    for variable_pattern, pattern_configs in thresholds.groupby("variable_pattern"):
        columns = data_set.filter(regex=f"^{variable_pattern}$")
        pattern_mask = None
        for _, season_config in pattern_configs.iterrows():
            out_of_bounds = (columns < season_config.lo) | (columns > season_config.hi)
            seasonal = out_of_bounds & season_index_mask[season_config.season]
            pattern_mask = seasonal if pattern_mask is None else pattern_mask | seasonal
        masks.append(pattern_mask)

    return pd.concat(masks, axis=1)
|
||
|
||
def filter_data(data_set: pd.DataFrame, thresholds: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *data_set* with threshold-violating samples set to NaN.

    Columns with no matching threshold pattern are left untouched.
    """
    outlier_mask = detect_outliers(data_set, thresholds)
    filtered = data_set.copy()
    filtered[outlier_mask] = np.nan
    return filtered
|
||
|
||
@attrs.define
class ThresholdBasedOutlierDetector:
    """
    Removes outlier samples from a dataset using per-station, per-season
    lo/hi thresholds loaded from a CSV configuration table.
    """

    # Threshold table with columns: stid, variable_pattern, season, lo, hi.
    thresholds: pd.DataFrame = attrs.field()

    def filter_data(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Filter samples across all variables by assigning them to NaN.

        Returns *ds* unchanged when no thresholds are configured for its
        station id; otherwise returns a new dataset with dataset- and
        variable-level attributes copied over from the input.
        """
        stid = ds.station_id

        # Select this station's threshold rows. Boolean indexing instead of
        # query() avoids interpolating the station id into a query string.
        stid_thresholds = self.thresholds[self.thresholds["stid"] == stid]
        # BUG FIX: the previous guard `if not stid_thresholds.any():` raised
        # ValueError — DataFrame.any() returns a per-column Series whose
        # truth value is ambiguous. `.empty` is the intended no-rows check.
        if stid_thresholds.empty:
            return ds

        data_df = ds.to_dataframe()  # Switch to pandas
        data_df = filter_data(
            data_set=data_df,
            thresholds=stid_thresholds,
        )

        ds_out: xr.Dataset = data_df.to_xarray()
        ds_out = ds_out.assign_attrs(ds.attrs)  # Dataset attrs
        for x in ds_out.data_vars:  # variable-specific attrs
            ds_out[x].attrs = ds[x].attrs

        return ds_out

    @classmethod
    def from_csv_config(cls, config_file: Path) -> 'ThresholdBasedOutlierDetector':
        """
        Instantiate using an explicit CSV file with explicit thresholds.

        The CSV file shall have the format:
        * Comma separated
        * First row is header
        * Columns
            * stid: Station id
            * variable_pattern: regular expression filtering the variable name
            * lo: Low threshold
            * hi: High threshold
            * season: The season of the filter: [, winter, spring, summer, fall].
              The empty string means all seasons
        """
        return cls(thresholds=pd.read_csv(config_file))

    @classmethod
    def default(cls) -> 'ThresholdBasedOutlierDetector':
        """
        Instantiate using the default AWS thresholds shipped with the package
        (the ``thresholds.csv`` file located next to this module).
        """
        default_thresholds_path = Path(__file__).parent.joinpath('thresholds.csv')
        return cls.from_csv_config(default_thresholds_path)
Oops, something went wrong.