Skip to content

Commit

Permalink
add per_context to analytics table + remove zebra striping
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoodEtedadi committed Dec 5, 2024
1 parent 9822843 commit 196cb87
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 94 deletions.
106 changes: 30 additions & 76 deletions src/seismometer/table/analytics_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from seismometer.controls.explore import ExplorationWidget
from seismometer.controls.selection import MultiselectDropdownWidget
from seismometer.controls.styles import BOX_GRID_LAYOUT, WIDE_LABEL_STYLE # , html_title
from seismometer.controls.styles import BOX_GRID_LAYOUT, WIDE_LABEL_STYLE
from seismometer.data import pandas_helpers as pdh
from seismometer.data.binary_performance import GENERATED_COLUMNS, Metric, calculate_stats, is_binary_array
from seismometer.data.performance import ( # MetricGenerator,
Expand Down Expand Up @@ -59,6 +59,7 @@ def __init__(
top_level: str = "Score",
table_config: AnalyticsTableConfig = AnalyticsTableConfig(),
statistics_data: Optional[pd.DataFrame] = None,
per_context: bool = False,
):
"""
Initializes the PerformanceMetrics object with the necessary data and parameters.
Expand Down Expand Up @@ -141,6 +142,7 @@ def __init__(
self._initializing = False
self.rows_group_length = len(self.target_columns) if self.top_level == "Score" else len(self.score_columns)
self.num_of_rows = len(self.score_columns) * len(self.target_columns)
self.per_context = per_context

# If polars package is not installed, overwrite is_na function in great_tables package to treat Agnostic
# as pandas dataframe.
Expand Down Expand Up @@ -307,42 +309,9 @@ def generate_color_bar(self, gt, columns):
data_bar_stroke_width=self.data_bar_stroke_width,
),
)
return gt

def add_coloring_parity(self, gt, columns=None, even_color="white", odd_color="#F2F2F2"):
"""
Adds alternating row colors to the specified columns in the table.
Parameters
----------
gt : GT
The table object to which the alternating row colors will be added.
columns : Optional[List[str]], optional
The list of columns to which the alternating row colors will be applied, by default None.
If None, all columns are considered.
even_color : str, optional
The color for even rows, by default "white".
odd_color : str, optional
The color for odd rows, by default "#F2F2F2" (light gray).
Returns
-------
gt : GT
The table object with alternating row colors.
"""
gt = gt.tab_style(
style=style.fill(color=even_color),
locations=loc.body(
columns=columns,
rows=[row for row in range(self.num_of_rows) if (row % self.rows_group_length) % 2 == 0],
),
).tab_style(
style=style.fill(color=odd_color),
locations=loc.body(
columns=columns,
rows=[row for row in range(self.num_of_rows) if (row % self.rows_group_length) % 2 == 1],
),
)
# If col and col_bar are not grouped under a metric value, group them together.
if data_col == f"{col}_bar":
gt = gt.tab_spanner(label=col, columns=[col, f"{col}_bar"])
return gt

def group_columns_by_metric_value(self, gt, columns, value):
Expand Down Expand Up @@ -413,7 +382,7 @@ def analytics_table(self):

gt = self.generate_color_bar(gt, columns=data.columns)

# Group columns of the form ***_value together
# Group columns of the form value_*** together
grouped_columns = []
for value in self.metric_values:
columns = [column for column in data.columns if column.startswith(f"{value}_")]
Expand All @@ -425,12 +394,9 @@ def analytics_table(self):
if col not in grouped_columns and f"{col}_bar" in data.columns:
gt = self.add_borders(gt, col, f"{col}_bar")

# Light gray/white alternating pattern needs to be corrected
gt = self.add_coloring_parity(gt)
gt = gt.opt_horizontal_padding(scale=3).tab_options(row_group_font_weight="bold")

gt = gt.cols_align(align="left").opt_horizontal_padding(scale=3).opt_stylize()

return gt
return HTML(gt.as_raw_html())

def _prepare_data(self, data):
"""
Expand Down Expand Up @@ -472,10 +438,21 @@ def _generate_table_data(self):
for first, second in product:
current_row = {self.top_level: first, self._get_second_level[self.top_level]: second}
(score, target) = (first, second) if self.top_level == "Score" else (second, first)
if self.per_context:
sg = Seismogram()
data = pdh.event_score(
sg.dataframe,
sg.entity_keys,
score=score,
ref_event=sg.predict_time,
aggregation_method=sg.event_aggregation_method(target),
)
else:
data = self.df
current_row.update(
calculate_stats(
self.df[target],
self.df[score],
data[target],
data[score],
self.metric,
self.metric_values,
self.metrics_to_display,
Expand Down Expand Up @@ -522,21 +499,7 @@ def binary_analytics_table(
Parameters
----------
metric_generator : The BinaryClassifierMetricGenerator that determines rho.
metric_list : list[str]
List of metrics to evaluate.
cohort_dict : dict[str, tuple[Any]]
Collection of cohort groups to loop over.
fairness_ratio : float
Ratio of acceptable difference between cohorts, 20% is 0.2, 200% is 2.0.
target : str
The target descriptor for the binary classifier.
score : str
The score descriptor for the binary classifier.
threshold : float
The threshold for the binary classifier.
per_context : bool, optional
Whether to group scores by context, by default False.
Returns
-------
Expand All @@ -551,21 +514,10 @@ def binary_analytics_table(
for target_col in target_cols
if is_binary_array(sg.dataframe[[pdh.event_value(target_col)]])
]
# data = (
# pdh.event_score(
# sg.dataframe,
# sg.entity_keys,
# score=score,
# ref_event=sg.predict_time,
# aggregation_method=sg.event_aggregation_method(target),
# )
# if per_context
# else sg.dataframe
# )
data = None

table_config = AnalyticsTableConfig(**COLORING_CONFIG_DEFAULT)
performance_metrics = PerformanceMetrics(
df=data,
df=None,
score_columns=score_cols,
target_columns=target_cols,
metric=metric,
Expand All @@ -574,16 +526,19 @@ def binary_analytics_table(
title=title if title else "Model Performance Statistics",
top_level=group_by,
table_config=table_config,
per_context=per_context,
)
return performance_metrics.analytics_table()


class ExploreAnalyticsTable(ExplorationWidget):
def __init__(self, title: Optional[str] = None):
def __init__(self, title: Optional[str] = None, *, per_context: bool = False):
from seismometer.seismogram import Seismogram

sg = Seismogram()
self.metric_generator = BinaryClassifierMetricGenerator()
self.title = title
self.per_context = per_context

super().__init__(
title="Model Performance Comparison",
Expand All @@ -609,7 +564,7 @@ def generate_plot_args(self) -> tuple[tuple, dict]:
list(self.option_widget.metrics_to_display), # Updated to use metrics_to_display
self.option_widget.group_by, # Updated to use group_by
)
kwargs = {}
kwargs = {"title": self.title, "per_context": self.per_context}
return args, kwargs


Expand Down Expand Up @@ -700,7 +655,6 @@ def __init__(
self._group_by.observe(self._on_value_changed, names="value")

v_children = [
# html_title("Analytics Table Options"),
self._target_cols,
self._score_cols,
self._metrics_to_display,
Expand Down
18 changes: 0 additions & 18 deletions tests/table/test_analytics_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,24 +302,6 @@ def test_generate_color_bar(self, fake_seismo):
gt = pm.generate_color_bar(gt, data.columns)
assert gt is not None

def test_add_coloring_parity(self, fake_seismo):
    """Smoke-test that alternating row colors can be applied to a generated table.

    Builds a two-score / two-target frame, generates the initial table from
    the computed statistics, applies ``add_coloring_parity`` with its default
    arguments, and checks that a table object is returned.
    """
    score_names = ["score1", "score2"]
    target_names = ["target1", "target2"]
    frame = pd.DataFrame(
        {
            "score1": [0.1, 0.4, 0.35, 0.8],
            "score2": [0.2, 0.5, 0.3, 0.7],
            "target1": [0, 1, 0, 1],
            "target2": [1, 0, 1, 0],
        }
    )

    metrics = PerformanceMetrics(
        df=frame,
        score_columns=score_names,
        target_columns=target_names,
        metric="sensitivity",
    )
    table_data = metrics._generate_table_data()
    initial_table = metrics.generate_initial_table(table_data)

    # Only the default striping arguments are exercised here.
    striped_table = metrics.add_coloring_parity(initial_table)
    assert striped_table is not None

def test_group_columns_by_metric_value(self, fake_seismo):
df = pd.DataFrame(
{
Expand Down

0 comments on commit 196cb87

Please sign in to comment.