Skip to content

Commit

Permalink
Move value decay plot to curve plotting, adapt decimal places in tabl…
Browse files Browse the repository at this point in the history
…e and look for x-axis.
  • Loading branch information
kosmitive committed May 1, 2024
1 parent 33e1e97 commit 752d9ac
Show file tree
Hide file tree
Showing 7 changed files with 383 additions and 412 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

Code for the submission to TMLR 2024. The original paper can be found [here](https://arxiv.org/abs/2211.06800).

| :warning: WARNING |
|:------------------------------------------------------------|
| It is it not recommended to use this library in production. |

# Getting started

## Installation
Expand Down
42 changes: 26 additions & 16 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,6 @@ experiments:
point_removal:
sampler: default
curves:
top_fraction:
fn: top_fraction
alpha_range:
from: 0.01
to: 0.5
step: 0.01
plots:
- rank_stability

accuracy_logistic_regression:
fn: metric
Expand Down Expand Up @@ -120,6 +112,19 @@ experiments:
plots:
- accuracy

top_fraction:
fn: top_fraction
alpha_range:
from: 0.01
to: 0.5
step: 0.01
plots:
- rank_stability

value_decay:
fn: value_decay
plots:
- value_decay

metrics:
weighted_relative_accuracy_difference_random:
Expand Down Expand Up @@ -183,14 +188,6 @@ experiments:

plots:

density:

rank_stability:
type: line
mean_agg: intersect
x_label: "%"
y_label: "a"

accuracy:
type: line
mean_agg: mean
Expand All @@ -208,6 +205,7 @@ plots:

table:
type: table
format: "%.3f"

box_wrad:
type: boxplot
Expand All @@ -221,6 +219,18 @@ plots:
type: boxplot
x_label: "AUC"

rank_stability:
type: line
mean_agg: intersect
x_label: "%"
y_label: "a"

value_decay:
type: line
mean_agg: mean
std_agg: bootstrap
x_label: "n"
y_label: "%"

samplers:
default:
Expand Down
653 changes: 324 additions & 329 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "re_classwise_shapley"
version = "0.1.0"
version = "1.0.0"
description = "Reproduction of the paper 'CS-Shapley: Class-wise Shapley Values for Data Valuation in Classification'"
authors = ["Markus Semmler"]
license = "LGPL-3.0"
Expand All @@ -23,8 +23,8 @@ click = "^8.1.3"
mlflow = "^2.9.2"
boto3 = "^1.28.36"
plotly = "^5.16.1"
dataframe_image = "*"
python-dotenv = "*"
dataframe_image = "^0.2.3"
python-dotenv = "^1.0.1"
pyDVL = {version="0.9.1", extras=["memcached"]}
python-memcached = "1.62"

Expand Down
10 changes: 4 additions & 6 deletions scripts/render_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
plot_metric_table,
plot_threshold_characteristics,
plot_time,
plot_value_decay,
)
from re_classwise_shapley.utils import (
flatten_dict,
Expand Down Expand Up @@ -113,10 +112,6 @@ def _render_plots(experiment_name: str, model_name: str):
repetitions,
method_names,
)
logger.info(f"Plotting value decay for all methods.")
with plot_value_decay(valuation_results, method_names) as fig:
log_figure(fig, output_folder, f"decay.{plot_format}", "values")

for method_name in method_names:
logger.info(f"Plot histogram for values of method `{method_name}`.")
with plot_histogram(valuation_results, [method_name]) as fig:
Expand Down Expand Up @@ -243,7 +238,9 @@ def _render_plots(experiment_name: str, model_name: str):
)

logger.info(f"Plotting table for metric '{metric_name}'.")
with plot_metric_table(metric_table) as fig:
with plot_metric_table(
metric_table, format_x=plot_settings.get("format", None)
) as fig:
log_figure(
fig,
output_folder,
Expand All @@ -252,6 +249,7 @@ def _render_plots(experiment_name: str, model_name: str):
)
case "boxplot":
x_label = plot_settings.get("x_label", None)
x_format = plot_settings.get("format", None)
logger.info(f"Plotting boxplot for metric '{metric_name}'.")
with plot_metric_boxplot(
selected_loaded_metrics, x_label=x_label
Expand Down
17 changes: 17 additions & 0 deletions src/re_classwise_shapley/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ def curve_metric(
return curve


def curve_value_decay(values: ValuationResult, fraction: float = 0.05) -> pd.Series:
"""
Computes the value decay curve for a given valuation result. The value decay curve
shows the average value of the valuation result for each prefix of the ranking.
:param values: Valuation result to compute the value decay curve for.
:param fraction: Fraction of the ranking to consider.
:return: A pd.Series object containing the value decay curve.
"""
method_values = np.flip(np.sort(values.values))
reduced_length = int(len(method_values) * fraction)
method_values = method_values[:reduced_length] / np.max(method_values)
return pd.Series(
method_values, index=np.arange(1, len(method_values) + 1), name="value_decay"
)


def _curve_precision_recall_ranking(
target_list: NDArray[np.int_], ranked_list: NDArray[np.int_]
) -> pd.Series:
Expand Down Expand Up @@ -280,4 +296,5 @@ def evaluate_at_point(
"metric": curve_metric,
"precision_recall": curve_roc,
"top_fraction": curve_top_fraction,
"value_decay": curve_value_decay,
}
63 changes: 5 additions & 58 deletions src/re_classwise_shapley/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.patches import Patch
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import FormatStrFormatter, FuncFormatter

from re_classwise_shapley.log import setup_logger
from re_classwise_shapley.types import OneOrMany, ensure_list
Expand Down Expand Up @@ -254,7 +254,8 @@ def plot_grid_over_datasets(
shadow=False,
**legend_kwargs,
)
fig.subplots_adjust(bottom=0.1)

fig.subplots_adjust(bottom=0.1)
yield fig
plt.close(fig)

Expand Down Expand Up @@ -326,62 +327,6 @@ def plot_histogram_func(
yield fig


@contextmanager
def plot_value_decay(
data: pd.DataFrame,
method_names: OneOrMany[str],
patch_size: Tuple[float, float] = (3, 2.5),
n_cols: int = 5,
fraction: float = 0.05,
) -> plt.Figure:
def plot_value_decay_func(
data: pd.DataFrame, ax: plt.Axes, method_names: List[str], **kwargs
):
data.loc[:, "method_name"] = data["method_name"].apply(lambda m: LABELS[m])
for method_name in method_names:
method_dataset_valuation_results = data.loc[
data["method_name"] == LABELS[method_name]
]
method_values = np.stack(
method_dataset_valuation_results["valuation"].apply(
lambda v: np.flip(np.sort(v.values))
)
)
reduced_length = int(method_values.shape[1] * fraction)
method_values = method_values[:, :reduced_length] / np.max(
method_values, axis=1, keepdims=True
)
color_name = COLOR_ENCODING[LABELS[method_name]]
mean_color, shade_color = COLORS[color_name]
method_values = pd.DataFrame(
method_values.T, index=np.arange(method_values.shape[1])
)
shaded_interval_line_plot(
method_values,
mean_agg="mean",
std_agg="bootstrap",
abscissa=method_values.index,
mean_color=mean_color,
shade_color=shade_color,
label=method_name,
ax=ax,
)

with plot_grid_over_datasets(
data,
plot_value_decay_func,
patch_size=patch_size,
n_cols=n_cols,
legend=True,
method_names=ensure_list(method_names),
xlabel="n",
ylabel="Value",
format_x_ticks="%.3f",
grid=True,
) as fig:
yield fig


@contextmanager
def plot_time(
data: pd.DataFrame,
Expand Down Expand Up @@ -489,13 +434,15 @@ def plot_curves_func(data: pd.DataFrame, ax: plt.Axes, **kwargs):
@contextmanager
def plot_metric_table(
data: pd.DataFrame,
format_x: str = "%.3f",
) -> plt.Figure:
"""
Takes a pd.DataFrame and plots it as a table.
"""
data.columns = [LABELS[c] for c in data.columns]
fig, ax = plt.subplots()
sns.heatmap(data, annot=True, cmap=plt.cm.get_cmap("viridis"), ax=ax)
ax.xaxis.set_major_formatter(FormatStrFormatter(format_x))
plt.ylabel("")
plt.xlabel("")
plt.tight_layout()
Expand Down

0 comments on commit 752d9ac

Please sign in to comment.