diff --git a/src/seismometer/table/analytics_table.py b/src/seismometer/table/analytics_table.py
index 0ed8ed8..10d34eb 100644
--- a/src/seismometer/table/analytics_table.py
+++ b/src/seismometer/table/analytics_table.py
@@ -12,7 +12,7 @@
 from seismometer.controls.explore import ExplorationWidget
 from seismometer.controls.selection import MultiselectDropdownWidget
-from seismometer.controls.styles import BOX_GRID_LAYOUT, WIDE_LABEL_STYLE  # , html_title
+from seismometer.controls.styles import BOX_GRID_LAYOUT, WIDE_LABEL_STYLE
 from seismometer.data import pandas_helpers as pdh
 from seismometer.data.binary_performance import GENERATED_COLUMNS, Metric, calculate_stats, is_binary_array
 from seismometer.data.performance import (  # MetricGenerator,
@@ -59,6 +59,7 @@ def __init__(
         top_level: str = "Score",
         table_config: AnalyticsTableConfig = AnalyticsTableConfig(),
         statistics_data: Optional[pd.DataFrame] = None,
+        per_context: bool = False,
     ):
         """
         Initializes the PerformanceMetrics object with the necessary data and parameters.
@@ -141,6 +142,7 @@ def __init__(
         self._initializing = False
         self.rows_group_length = len(self.target_columns) if self.top_level == "Score" else len(self.score_columns)
         self.num_of_rows = len(self.score_columns) * len(self.target_columns)
+        self.per_context = per_context

         # If polars package is not installed, overwrite is_na function in great_tables package to treat Agnostic
         # as pandas dataframe.
@@ -307,42 +309,9 @@ def generate_color_bar(self, gt, columns):
                     data_bar_stroke_width=self.data_bar_stroke_width,
                 ),
             )
-        return gt
-
-    def add_coloring_parity(self, gt, columns=None, even_color="white", odd_color="#F2F2F2"):
-        """
-        Adds alternating row colors to the specified columns in the table.
-
-        Parameters
-        ----------
-        gt : GT
-            The table object to which the alternating row colors will be added.
-        columns : Optional[List[str]], optional
-            The list of columns to which the alternating row colors will be applied, by default None.
-            If None, all columns are considered.
-        even_color : str, optional
-            The color for even rows, by default "#F2F2F2" (light gray).
-        odd_color : str, optional
-            The color for odd rows, by default "white".
-
-        Returns
-        -------
-        gt : GT
-            The table object with alternating row colors.
-        """
-        gt = gt.tab_style(
-            style=style.fill(color=even_color),
-            locations=loc.body(
-                columns=columns,
-                rows=[row for row in range(self.num_of_rows) if (row % self.rows_group_length) % 2 == 0],
-            ),
-        ).tab_style(
-            style=style.fill(color=odd_color),
-            locations=loc.body(
-                columns=columns,
-                rows=[row for row in range(self.num_of_rows) if (row % self.rows_group_length) % 2 == 1],
-            ),
-        )
+            # If col and col_bar are not grouped under a metric value, group them together.
+            if data_col == f"{col}_bar":
+                gt = gt.tab_spanner(label=col, columns=[col, f"{col}_bar"])
         return gt

     def group_columns_by_metric_value(self, gt, columns, value):
@@ -413,7 +382,7 @@ def analytics_table(self):
         gt = self.generate_color_bar(gt, columns=data.columns)

-        # Group columns of the form ***_value together
+        # Group columns of the form value_*** together
         grouped_columns = []
         for value in self.metric_values:
             columns = [column for column in data.columns if column.startswith(f"{value}_")]
@@ -425,12 +394,9 @@
             if col not in grouped_columns and f"{col}_bar" in data.columns:
                 gt = self.add_borders(gt, col, f"{col}_bar")

-        # Light gray/white alternating pattern needs to be corrected
-        gt = self.add_coloring_parity(gt)
+        gt = gt.opt_horizontal_padding(scale=3).tab_options(row_group_font_weight="bold")

-        gt = gt.cols_align(align="left").opt_horizontal_padding(scale=3).opt_stylize()
-
-        return gt
+        return HTML(gt.as_raw_html())

     def _prepare_data(self, data):
         """
@@ -472,10 +438,21 @@ def _generate_table_data(self):
         for first, second in product:
             current_row = {self.top_level: first, self._get_second_level[self.top_level]: second}
             (score, target) = (first, second) if self.top_level == "Score" else (second, first)
+            if self.per_context:
+                sg = Seismogram()
+                data = pdh.event_score(
+                    sg.dataframe,
+                    sg.entity_keys,
+                    score=score,
+                    ref_event=sg.predict_time,
+                    aggregation_method=sg.event_aggregation_method(target),
+                )
+            else:
+                data = self.df
             current_row.update(
                 calculate_stats(
-                    self.df[target],
-                    self.df[score],
+                    data[target],
+                    data[score],
                     self.metric,
                     self.metric_values,
                     self.metrics_to_display,
@@ -522,21 +499,7 @@ def binary_analytics_table(

     Parameters
     ----------
-    metric_generator : The BinaryClassifierMetricGenerator that determines rho.
-    metric_list : list[str]
-        List of metrics to evaluate.
-    cohort_dict : dict[str, tuple[Any]]
-        Collection of cohort groups to loop over.
-    fairness_ratio : float
-        Ratio of acceptable difference between cohorts, 20% is 0.2, 200% is 2.0.
-    target : str
-        The target descriptor for the binary classifier.
-    score : str
-        The score descriptor for the binary classifier.
-    threshold : float
-        The threshold for the binary classifier.
-    per_context : bool, optional
-        Whether to group scores by context, by default False.
+

     Returns
     -------
@@ -551,21 +514,10 @@
         for target_col in target_cols
         if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]])
     ]
-    # data = (
-    #     pdh.event_score(
-    #         sg.dataframe,
-    #         sg.entity_keys,
-    #         score=score,
-    #         ref_event=sg.predict_time,
-    #         aggregation_method=sg.event_aggregation_method(target),
-    #     )
-    #     if per_context
-    #     else sg.dataframe
-    # )
-    data = None
+
     table_config = AnalyticsTableConfig(**COLORING_CONFIG_DEFAULT)
     performance_metrics = PerformanceMetrics(
-        df=data,
+        df=None,
         score_columns=score_cols,
         target_columns=target_cols,
         metric=metric,
@@ -574,16 +526,19 @@
         title=title if title else "Model Performance Statistics",
         top_level=group_by,
         table_config=table_config,
+        per_context=per_context,
     )
     return performance_metrics.analytics_table()


 class ExploreAnalyticsTable(ExplorationWidget):
-    def __init__(self, title: Optional[str] = None):
+    def __init__(self, title: Optional[str] = None, *, per_context: bool = False):
         from seismometer.seismogram import Seismogram

         sg = Seismogram()
         self.metric_generator = BinaryClassifierMetricGenerator()
+        self.title = title
+        self.per_context = per_context

         super().__init__(
             title="Model Performance Comparison",
@@ -609,7 +564,7 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
             list(self.option_widget.metrics_to_display),  # Updated to use metrics_to_display
             self.option_widget.group_by,  # Updated to use group_by
         )
-        kwargs = {}
+        kwargs = {"title": self.title, "per_context": self.per_context}
         return args, kwargs

@@ -700,7 +655,6 @@ def __init__(
         self._group_by.observe(self._on_value_changed, names="value")

         v_children = [
-            # html_title("Analytics Table Options"),
             self._target_cols,
             self._score_cols,
             self._metrics_to_display,
diff --git a/tests/table/test_analytics_table.py b/tests/table/test_analytics_table.py
index 0179fb1..67482e6 100644
--- a/tests/table/test_analytics_table.py
+++ b/tests/table/test_analytics_table.py
@@ -302,24 +302,6 @@ def test_generate_color_bar(self, fake_seismo):
         gt = pm.generate_color_bar(gt, data.columns)
         assert gt is not None

-    def test_add_coloring_parity(self, fake_seismo):
-        df = pd.DataFrame(
-            {
-                "score1": [0.1, 0.4, 0.35, 0.8],
-                "score2": [0.2, 0.5, 0.3, 0.7],
-                "target1": [0, 1, 0, 1],
-                "target2": [1, 0, 1, 0],
-            }
-        )
-        scores = ["score1", "score2"]
-        targets = ["target1", "target2"]
-
-        pm = PerformanceMetrics(df=df, score_columns=scores, target_columns=targets, metric="sensitivity")
-        data = pm._generate_table_data()
-        gt = pm.generate_initial_table(data)
-        gt = pm.add_coloring_parity(gt)
-        assert gt is not None
-
     def test_group_columns_by_metric_value(self, fake_seismo):
         df = pd.DataFrame(
             {
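A minimal usage sketch of the new per_context option (an illustration, not part of the patch): the keyword-only flag added to ExploreAnalyticsTable.__init__ is forwarded to the plot call through generate_plot_args, so the rendered table recomputes its statistics on context-grouped scores.

# Hypothetical notebook cell; assumes a Seismogram session has already been configured.
from seismometer.table.analytics_table import ExploreAnalyticsTable

# With per_context=True, _generate_table_data pulls scores via pdh.event_score
# (aggregated per context) instead of using self.df directly.
table = ExploreAnalyticsTable(title="Model Performance Statistics", per_context=True)
table  # display the widget; each re-render passes per_context=True to binary_analytics_table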