Skip to content

Commit

Permalink
update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoodEtedadi committed Dec 5, 2024
1 parent 29cd288 commit d40d9fd
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/seismometer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,5 @@ def run_startup(
config = config_provider or ConfigProvider(config_path, output_path=output_path, definitions=definitions)
loader = loader_factory(config)
sg = Seismogram(config, loader)

sg.load_data(predictions=predictions_frame, events=events_frame)
2 changes: 1 addition & 1 deletion src/seismometer/data/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def calculate_bin_stats(
fpr = fps / total_negatives

ppv = tps / (tps + fps)
ppv[np.isnan(ppv)] = 0
ppv[np.isnan(ppv)] = 1

# TN / TN + FN
npv = np.divide(tns, tns + fns)
Expand Down
15 changes: 15 additions & 0 deletions src/seismometer/table/analytics_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,21 @@ def __init__(
self.rows_group_length = len(self.target_columns) if self.top_level == "Score" else len(self.score_columns)
self.num_of_rows = len(self.score_columns) * len(self.target_columns)
self.per_context = per_context
# If polars package is not installed, overwrite is_na function in great_tables package to treat Agnostic
# as pandas dataframe.
try:
import polars as pl

# Use 'pl' to avoid the F401 error
_ = pl.DataFrame()
except ImportError:
from typing import Any

from great_tables._tbl_data import Agnostic, PdDataFrame, is_na

@is_na.register(Agnostic)
def _(df: PdDataFrame, x: Any) -> bool:
return pd.isna(x)

def _validate_df_statistics_data(self):
if not self._initializing: # Skip validation during initial setup
Expand Down
File renamed without changes.
36 changes: 18 additions & 18 deletions tests/data/test_cohorts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,24 @@ def input_df():
def expected_df(cohorts):
data_rows = np.vstack(
(
# TP,FP,TN,FN, Acc,Sens,Spec,PPV,NPV, Flag, LR+, NNE, NNT1/2, cohort,ct,tgtct,
[[0, 0, 1, 1, 0.5, 0.0, 1.0, 0.0, 0.5, 0.0, np.nan, np.inf, np.inf, "<1.0", 2, 1]] * 70,
[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0.5, 0, np.inf, np.inf, "<1.0", 2, 1]] * 10,
[[1, 1, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 2, 4, "<1.0", 2, 1]] * 21,
# TP,FP,TN,FN, Acc,Sens,Spec, PPV, NPV, Flag, LR+, NNE, NNT1/2, cohort,ct,tgtct
[[0, 0, 2, 2, 0.5, 0.0, 1.0, 0.0, 0.5, 0.0, np.nan, np.inf, np.inf, ">=1.0", 4, 2]] * 30,
[[1, 0, 2, 1, 0.75, 0.5, 1, 1, 2 / 3, 0.25, np.inf, 1, 2, ">=1.0", 4, 2]] * 10,
[[1, 1, 1, 1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0, 2, 4, ">=1.0", 4, 2]] * 10,
[[2, 1, 1, 0, 0.75, 1, 0.5, 2 / 3, 1, 0.75, 2, 1.5, 3, ">=1.0", 4, 2]] * 10,
[[2, 2, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 2, 4, ">=1.0", 4, 2]] * 41,
# TP,FP,TN,FN, Acc,Sens,Spec,PPV,NPV, Flag, LR+, NNE, NNT1/2, cohort,ct,tgtct
[[0, 0, 1, 1, 0.5, 0.0, 1.0, 0.0, 0.5, 0.0, np.nan, np.inf, np.inf, "1.0-2.0", 2, 1]] * 50,
[[1, 0, 1, 0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, np.inf, 1, 2, "1.0-2.0", 2, 1]] * 10,
[[1, 1, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 2, 4, "1.0-2.0", 2, 1]] * 41,
# TP,FP,TN,FN, Acc,Sens,Spec,PPV,NPV, Flag, LR+, NNE, NNT1/2, cohort,ct,tgtct
[[0, 0, 1, 1, 0.5, 0.0, 1.0, 0.0, 0.5, 0.0, np.nan, np.inf, np.inf, ">=2.0", 2, 1]] * 30,
[[1, 0, 1, 0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, np.inf, 1, 2, ">=2.0", 2, 1]] * 10,
[[1, 1, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 2, 4, ">=2.0", 2, 1]] * 61,
# TP,FP,TN,FN, Acc,Sens,Spec,PPV,NPV, Flag, LR+, NNT1/2, cohort,ct,tgtct,
[[0, 0, 1, 1, 0.5, 0.0, 1.0, 1.0, 0.5, 0.0, np.nan, 2.0, "<1.0", 2, 1]] * 70,
[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0.5, 0, np.inf, "<1.0", 2, 1]] * 10,
[[1, 1, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 4, "<1.0", 2, 1]] * 21,
# TP,FP,TN,FN, Acc,Sens,Spec, PPV, NPV, Flag, LR+, cohort,ct,tgtct
[[0, 0, 2, 2, 0.5, 0.0, 1.0, 1.0, 0.5, 0.0, np.nan, 2.0, ">=1.0", 4, 2]] * 30,
[[1, 0, 2, 1, 0.75, 0.5, 1, 1, 2 / 3, 0.25, np.inf, 2, ">=1.0", 4, 2]] * 10,
[[1, 1, 1, 1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0, 4, ">=1.0", 4, 2]] * 10,
[[2, 1, 1, 0, 0.75, 1, 0.5, 2 / 3, 1, 0.75, 2, 3, ">=1.0", 4, 2]] * 10,
[[2, 2, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 4, ">=1.0", 4, 2]] * 41,
# TP,FP,TN,FN, Acc,Sens,Spec,PPV,NPV, Flag, LR+, cohort,ct,tgtct
[[0, 0, 1, 1, 0.5, 0.0, 1.0, 1.0, 0.5, 0.0, np.nan, 2.0, "1.0-2.0", 2, 1]] * 50,
[[1, 0, 1, 0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, np.inf, 2, "1.0-2.0", 2, 1]] * 10,
[[1, 1, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 4, "1.0-2.0", 2, 1]] * 41,
# TP,FP,TN,FN, Acc,Sens,Spec,PPV,NPV, Flag, LR+, cohort,ct,tgtct
[[0, 0, 1, 1, 0.5, 0.0, 1.0, 1.0, 0.5, 0.0, np.nan, 2.0, ">=2.0", 2, 1]] * 30,
[[1, 0, 1, 0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, np.inf, 2, ">=2.0", 2, 1]] * 10,
[[1, 1, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 4, ">=2.0", 2, 1]] * 61,
)
)

Expand Down
6 changes: 3 additions & 3 deletions tests/data/test_perf_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def stats_case_base():

expected = []
# Threshold, TP, FP, TN, FN, Acc, Sens, Spec, PPV, NPV, Flag Rate, LR+, NNE, NBS, NNT1/3 | Threshold
expected.append([100, 0, 0, 1, 2, 1 / 3, 0, 1, 0, 1 / 3, 0, np.nan, np.inf, np.nan, np.inf]) # 1
expected.append([100, 0, 0, 1, 2, 1 / 3, 0, 1, 1, 1 / 3, 0, np.nan, np.inf, np.nan, 3]) # 1
expected.append([50, 1, 0, 1, 1, 2 / 3, 0.5, 1, 1, 0.5, 1 / 3, np.inf, 1, 1 / 3, 3]) # .5
expected.append([10, 2, 1, 0, 0, 2 / 3, 1, 0, 2 / 3, 1, 1, 1, 1.5, 17 / 27, 4.5]) # .1
expected.append([0, 2, 1, 0, 0, 2 / 3, 1, 0, 2 / 3, 1, 1, 1, 1.5, 2 / 3, 4.5]) # 0
Expand All @@ -46,7 +46,7 @@ def stats_case_0():

expected = []
# Threshold, TP, FP, TN, FN, Acc, Sens, Spec, PPV, NPV, Flag Rate, LR+, NNE, NBS, NNT1/3 | Threshold
expected.append([100, 0, 0, 2, 2, 0.5, 0, 1, 0, 0.5, 0, np.nan, np.inf, np.nan, np.inf]) # 1
expected.append([100, 0, 0, 2, 2, 0.5, 0, 1, 1, 0.5, 0, np.nan, np.inf, np.nan, 3]) # 1
expected.append([50, 1, 0, 2, 1, 0.75, 0.5, 1, 1, 2 / 3, 0.25, np.inf, 1, 1 / 4, 3]) # .5
expected.append([10, 2, 1, 1, 0, 0.75, 1, 0.5, 2 / 3, 1, 0.75, 2, 1.5, 17 / 36, 4.5]) # .1
expected.append([0, 2, 2, 0, 0, 0.5, 1, 0, 0.5, 1, 1, 1, 2, 1 / 2, 6]) # 0
Expand Down Expand Up @@ -90,7 +90,7 @@ def stats_case_0_4():

expected = []
# Threshold, TP, FP, TN, FN, Acc, Sens, Spec, PPV, NPV, Flag Rate, LR+, NNE, NBS, NNT1/3 | Threshold
expected.append([100, 0, 0, 2, 2, 0.5, 0, 1, 0, 0.5, 0, np.nan, np.inf, np.nan, np.inf]) # 1
expected.append([100, 0, 0, 2, 2, 0.5, 0, 1, 1, 0.5, 0, np.nan, np.inf, np.nan, 3]) # 1
expected.append([75, 1, 0, 2, 1, 0.75, 0.5, 1, 1, 2 / 3, 0.25, np.inf, 1, 1 / 4, 3]) # .75
expected.append([50, 2, 0, 2, 0, 1, 1, 1, 1, 1, 0.5, np.inf, 1, 1 / 2, 3]) # .5
expected.append([25, 2, 1, 1, 0, 0.75, 1, 0.5, 2 / 3, 1, 0.75, 2, 1.5, 5 / 12, 4.5]) # .25
Expand Down

0 comments on commit d40d9fd

Please sign in to comment.