From 29cd288a461f405ca6b7976435ea1f33c9d2ac39 Mon Sep 17 00:00:00 2001 From: metedadi Date: Wed, 4 Dec 2024 14:07:12 +0000 Subject: [PATCH] changes according to reviewer comments --- src/seismometer/data/binary_performance.py | 23 ++++------- src/seismometer/seismogram.py | 14 +++++++ src/seismometer/table/analytics_table.py | 45 +++------------------- 3 files changed, 28 insertions(+), 54 deletions(-) diff --git a/src/seismometer/data/binary_performance.py b/src/seismometer/data/binary_performance.py index 80a6d50..708bfb9 100644 --- a/src/seismometer/data/binary_performance.py +++ b/src/seismometer/data/binary_performance.py @@ -3,7 +3,7 @@ import numpy as np from numpy.typing import ArrayLike -from sklearn.metrics import average_precision_score, roc_auc_score +from sklearn.metrics import auc from . import calculate_bin_stats @@ -69,7 +69,7 @@ def calculate_stats( - 'Prevalence': Prevalence of positive samples. - 'AUROC': Area under the receiver operating characteristic curve. - 'AUPRC': Area under the precision-recall curve. - - Additional metrics (PPV, Flag Rate, Sensitivity, Specificity, Threshold). + - Additional metrics (PPV, Flag Rate, Sensitivity, Specificity, Threshold, etc.). """ # Check if metric is a valid name. try: @@ -85,22 +85,22 @@ def calculate_stats( metrics_to_display = metrics_to_display if metrics_to_display else list(GENERATED_COLUMNS.keys()) _metrics_to_display_lower = [metric_to_display.lower() for metric_to_display in metrics_to_display] + stats = calculate_bin_stats(y_true, y_pred) + # Calculate overall statistics if "positives" in _metrics_to_display_lower: - row_data["Positives"] = sum(y_true) + row_data["Positives"] = stats["TP"].iloc[-1] if "prevalence" in _metrics_to_display_lower: - row_data["Prevalence"] = sum(y_true) / len(y_true) + row_data["Prevalence"] = stats["TP"].iloc[-1] / len(y_true) if "auroc" in _metrics_to_display_lower: - row_data["AUROC"] = roc_auc_score(y_true, y_pred) + row_data["AUROC"] = auc(1 - stats["Specificity"], stats["Sensitivity"]) if "auprc" in _metrics_to_display_lower: - row_data["AUPRC"] = average_precision_score(y_true, y_pred) + row_data["AUPRC"] = auc(stats["Sensitivity"], stats["PPV"]) # Order/round metric values metric_values = sorted([round(num, decimals) for num in metric_values]) metric_values = [0 if val == 0.0 else val for val in metric_values] - stats = calculate_bin_stats(y_true, y_pred) - metric_data = stats[GENERATED_COLUMNS[metric]].to_numpy() thresholds = stats["Threshold"].to_numpy() @@ -123,10 +123,3 @@ def calculate_stats( ) return row_data - - -def is_binary_array(arr): - # Convert the input to a NumPy array if it isn't already - arr = np.asarray(arr) - # Check if all elements are either 0 or 1 - return np.all((arr == 0) | (arr == 1)) diff --git a/src/seismometer/seismogram.py b/src/seismometer/seismogram.py index 621dbd7..98d99c3 100644 --- a/src/seismometer/seismogram.py +++ b/src/seismometer/seismogram.py @@ -2,6 +2,7 @@ import logging from typing import Optional +import numpy as np import pandas as pd from seismometer.configuration import AggregationStrategies, ConfigProvider, MergeStrategies @@ -376,4 +377,17 @@ def create_cohorts(self) -> None: self.cohort_cols.append(disp_attr) logger.debug(f"Created cohorts: {', '.join(self.cohort_cols)}") + def _is_binary_array(self, arr): + # Convert the input to a NumPy array if it isn't already + arr = np.asarray(arr) + # Check if all elements are either 0 or 1 + return np.all((arr == 0) | (arr == 1)) + + def get_binary_targets(self): + return [ + pdh.event_value(target_col) + for target_col in self.target_cols + if self._is_binary_array(self.dataframe[[pdh.event_value(target_col)]]) + ] + # endregion diff --git a/src/seismometer/table/analytics_table.py b/src/seismometer/table/analytics_table.py index 10d34eb..de0abd4 100644 --- a/src/seismometer/table/analytics_table.py +++ b/src/seismometer/table/analytics_table.py @@ -1,6 +1,6 @@ import itertools from enum import Enum -from typing import Any, List, Optional +from typing import List, Optional import pandas as pd @@ -14,7 +14,7 @@ from seismometer.controls.selection import MultiselectDropdownWidget from seismometer.controls.styles import BOX_GRID_LAYOUT, WIDE_LABEL_STYLE from seismometer.data import pandas_helpers as pdh -from seismometer.data.binary_performance import GENERATED_COLUMNS, Metric, calculate_stats, is_binary_array +from seismometer.data.binary_performance import GENERATED_COLUMNS, Metric, calculate_stats from seismometer.data.performance import ( # MetricGenerator, OVERALL_PERFORMANCE, STATNAMES, @@ -109,15 +109,7 @@ def __init__( self.score_columns = score_columns if score_columns else sg.output_list self.target_columns = target_columns if sg.dataframe is not None: - self.target_columns = ( - self.target_columns - if self.target_columns - else [ - pdh.event_value(target_col) - for target_col in sg.target_cols - if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]]) - ] - ) + self.target_columns = self.target_columns if self.target_columns else sg.get_binary_targets() self.statistics_data = statistics_data if self.df is None and self.statistics_data is None: raise ValueError("At least one of 'df' or 'statistics_data' needs to be provided.") @@ -144,20 +136,6 @@ def __init__( self.num_of_rows = len(self.score_columns) * len(self.target_columns) self.per_context = per_context - # If polars package is not installed, overwrite is_na function in great_tables package to treat Agnostic - # as pandas dataframe. - try: - import polars as pl - - # Use 'pl' to avoid the F401 error - _ = pl.DataFrame() - except ImportError: - from great_tables._tbl_data import Agnostic, PdDataFrame, is_na - - @is_na.register(Agnostic) - def _(df: PdDataFrame, x: Any) -> bool: - return pd.isna(x) - def _validate_df_statistics_data(self): if not self._initializing: # Skip validation during initial setup if self.df is None and self.statistics_data is None: @@ -506,15 +484,6 @@ def binary_analytics_table( HTML The HTML table for the fairness evaluation. """ - from seismometer.seismogram import Seismogram - - sg = Seismogram() - target_cols = [ - pdh.event_value(target_col) - for target_col in target_cols - if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]]) - ] - table_config = AnalyticsTableConfig(**COLORING_CONFIG_DEFAULT) performance_metrics = PerformanceMetrics( df=None, @@ -607,13 +576,11 @@ def __init__( sg = Seismogram() self.model_options_widget = model_options_widget self.title = title - self.binary_targets = [ - target_col for target_col in target_cols if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]]) - ] + # Multiple select dropdowns for targets and scores self._target_cols = MultiselectDropdownWidget( - self.binary_targets, - value=self.binary_targets, + sg.get_binary_targets(), + value=sg.get_binary_targets(), title="Targets", ) self._score_cols = MultiselectDropdownWidget(