Skip to content

Commit

Permalink
changes according to reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoodEtedadi committed Dec 5, 2024
1 parent 196cb87 commit 29cd288
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 54 deletions.
23 changes: 8 additions & 15 deletions src/seismometer/data/binary_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
from numpy.typing import ArrayLike
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.metrics import auc

from . import calculate_bin_stats

Expand Down Expand Up @@ -69,7 +69,7 @@ def calculate_stats(
- 'Prevalence': Prevalence of positive samples.
- 'AUROC': Area under the receiver operating characteristic curve.
- 'AUPRC': Area under the precision-recall curve.
- Additional metrics (PPV, Flag Rate, Sensitivity, Specificity, Threshold).
- Additional metrics (PPV, Flag Rate, Sensitivity, Specificity, Threshold, etc.).
"""
# Check if metric is a valid name.
try:
Expand All @@ -85,22 +85,22 @@ def calculate_stats(
metrics_to_display = metrics_to_display if metrics_to_display else list(GENERATED_COLUMNS.keys())
_metrics_to_display_lower = [metric_to_display.lower() for metric_to_display in metrics_to_display]

stats = calculate_bin_stats(y_true, y_pred)

# Calculate overall statistics
if "positives" in _metrics_to_display_lower:
row_data["Positives"] = sum(y_true)
row_data["Positives"] = stats["TP"].iloc[-1]
if "prevalence" in _metrics_to_display_lower:
row_data["Prevalence"] = sum(y_true) / len(y_true)
row_data["Prevalence"] = stats["TP"].iloc[-1] / len(y_true)
if "auroc" in _metrics_to_display_lower:
row_data["AUROC"] = roc_auc_score(y_true, y_pred)
row_data["AUROC"] = auc(1 - stats["Specificity"], stats["Sensitivity"])
if "auprc" in _metrics_to_display_lower:
row_data["AUPRC"] = average_precision_score(y_true, y_pred)
row_data["AUPRC"] = auc(stats["Sensitivity"], stats["PPV"])

# Order/round metric values
metric_values = sorted([round(num, decimals) for num in metric_values])
metric_values = [0 if val == 0.0 else val for val in metric_values]

stats = calculate_bin_stats(y_true, y_pred)

metric_data = stats[GENERATED_COLUMNS[metric]].to_numpy()
thresholds = stats["Threshold"].to_numpy()

Expand All @@ -123,10 +123,3 @@ def calculate_stats(
)

return row_data


def is_binary_array(arr):
# Convert the input to a NumPy array if it isn't already
arr = np.asarray(arr)
# Check if all elements are either 0 or 1
return np.all((arr == 0) | (arr == 1))
14 changes: 14 additions & 0 deletions src/seismometer/seismogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from typing import Optional

import numpy as np
import pandas as pd

from seismometer.configuration import AggregationStrategies, ConfigProvider, MergeStrategies
Expand Down Expand Up @@ -376,4 +377,17 @@ def create_cohorts(self) -> None:
self.cohort_cols.append(disp_attr)
logger.debug(f"Created cohorts: {', '.join(self.cohort_cols)}")

def _is_binary_array(self, arr):
# Convert the input to a NumPy array if it isn't already
arr = np.asarray(arr)
# Check if all elements are either 0 or 1
return np.all((arr == 0) | (arr == 1))

def get_binary_targets(self):
return [
pdh.event_value(target_col)
for target_col in self.target_cols
if self._is_binary_array(self.dataframe[[pdh.event_value(target_col)]])
]

# endregion
45 changes: 6 additions & 39 deletions src/seismometer/table/analytics_table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from enum import Enum
from typing import Any, List, Optional
from typing import List, Optional

import pandas as pd

Expand All @@ -14,7 +14,7 @@
from seismometer.controls.selection import MultiselectDropdownWidget
from seismometer.controls.styles import BOX_GRID_LAYOUT, WIDE_LABEL_STYLE
from seismometer.data import pandas_helpers as pdh
from seismometer.data.binary_performance import GENERATED_COLUMNS, Metric, calculate_stats, is_binary_array
from seismometer.data.binary_performance import GENERATED_COLUMNS, Metric, calculate_stats
from seismometer.data.performance import ( # MetricGenerator,
OVERALL_PERFORMANCE,
STATNAMES,
Expand Down Expand Up @@ -109,15 +109,7 @@ def __init__(
self.score_columns = score_columns if score_columns else sg.output_list
self.target_columns = target_columns
if sg.dataframe is not None:
self.target_columns = (
self.target_columns
if self.target_columns
else [
pdh.event_value(target_col)
for target_col in sg.target_cols
if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]])
]
)
self.target_columns = self.target_columns if self.target_columns else sg.get_binary_targets()
self.statistics_data = statistics_data
if self.df is None and self.statistics_data is None:
raise ValueError("At least one of 'df' or 'statistics_data' needs to be provided.")
Expand All @@ -144,20 +136,6 @@ def __init__(
self.num_of_rows = len(self.score_columns) * len(self.target_columns)
self.per_context = per_context

# If polars package is not installed, overwrite is_na function in great_tables package to treat Agnostic
# as pandas dataframe.
try:
import polars as pl

# Use 'pl' to avoid the F401 error
_ = pl.DataFrame()
except ImportError:
from great_tables._tbl_data import Agnostic, PdDataFrame, is_na

@is_na.register(Agnostic)
def _(df: PdDataFrame, x: Any) -> bool:
return pd.isna(x)

def _validate_df_statistics_data(self):
if not self._initializing: # Skip validation during initial setup
if self.df is None and self.statistics_data is None:
Expand Down Expand Up @@ -506,15 +484,6 @@ def binary_analytics_table(
HTML
The HTML table for the fairness evaluation.
"""
from seismometer.seismogram import Seismogram

sg = Seismogram()
target_cols = [
pdh.event_value(target_col)
for target_col in target_cols
if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]])
]

table_config = AnalyticsTableConfig(**COLORING_CONFIG_DEFAULT)
performance_metrics = PerformanceMetrics(
df=None,
Expand Down Expand Up @@ -607,13 +576,11 @@ def __init__(
sg = Seismogram()
self.model_options_widget = model_options_widget
self.title = title
self.binary_targets = [
target_col for target_col in target_cols if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]])
]

# Multiple select dropdowns for targets and scores
self._target_cols = MultiselectDropdownWidget(
self.binary_targets,
value=self.binary_targets,
sg.get_binary_targets(),
value=sg.get_binary_targets(),
title="Targets",
)
self._score_cols = MultiselectDropdownWidget(
Expand Down

0 comments on commit 29cd288

Please sign in to comment.