diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 2dfe2b3448..7d5f17e948 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -258,25 +258,8 @@ def get_run_times(self, case_keys=None): return df def plot_run_times(self, case_keys=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - run_times = self.get_run_times(case_keys=case_keys) - - colors = self.get_colors() - import matplotlib.pyplot as plt - - fig, ax = plt.subplots() - labels = [] - for i, key in enumerate(case_keys): - labels.append(self.cases[key]["label"]) - rt = run_times.at[key, "run_times"] - ax.bar(i, rt, width=0.8, color=colors[key]) - ax.set_xticks(np.arange(len(case_keys))) - ax.set_xticklabels(labels, rotation=45.0) - return fig - - # ax = run_times.plot(kind="bar") - # return ax.figure + from .benchmark_plot_tools import plot_run_times + return plot_run_times(self, case_keys=case_keys) def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index ee9d2947d6..ae9009521f 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -1,4 +1,4 @@ - +import numpy as np @@ -7,3 +7,213 @@ def _simpleaxis(ax): ax.spines["right"].set_visible(False) ax.get_xaxis().tick_bottom() ax.get_yaxis().tick_left() + + +def plot_run_times(study, case_keys=None): + """ + Plot run times for a BenchmarkStudy. + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. 
+ """ + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + run_times = study.get_run_times(case_keys=case_keys) + + colors = study.get_colors() + + + fig, ax = plt.subplots() + labels = [] + for i, key in enumerate(case_keys): + labels.append(study.cases[key]["label"]) + rt = run_times.at[key, "run_times"] + ax.bar(i, rt, width=0.8, color=colors[key]) + ax.set_xticks(np.arange(len(case_keys))) + ax.set_xticklabels(labels, rotation=45.0) + return fig + + +def plot_unit_counts(study, case_keys=None): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + """ + import matplotlib.pyplot as plt + from spikeinterface.widgets.utils import get_some_colors + + if case_keys is None: + case_keys = list(study.cases.keys()) + + + count_units = study.get_count_units(case_keys=case_keys) + + fig, ax = plt.subplots() + + columns = count_units.columns.tolist() + columns.remove("num_gt") + columns.remove("num_sorter") + + ncol = len(columns) + + colors = get_some_colors(columns, color_engine="auto", map_name="hot") + colors["num_well_detected"] = "green" + + xticklabels = [] + for i, key in enumerate(case_keys): + for c, col in enumerate(columns): + x = i + 1 + c / (ncol + 1) + y = count_units.loc[key, col] + if not "well_detected" in col: + y = -y + + if i == 0: + label = col.replace("num_", "").replace("_", " ").title() + else: + label = None + + ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) + + xticklabels.append(study.cases[key]["label"]) + + ax.set_xticks(np.arange(len(case_keys)) + 1) + ax.set_xticklabels(xticklabels) + ax.legend() + + return fig + +def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None): + """ + Plot performances over case 
for a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + mode : "ordered" | "snr" | "swarm", default: "ordered" + Which plot mode to use: + + * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy + * "snr": plot performance metrics vs snr + * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) + performance_names : list or tuple, default: ("accuracy", "precision", "recall") + Which performances to plot ("accuracy", "precision", "recall") + case_keys : list or None + A selection of cases to plot, if None, then all. + """ + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns + + if case_keys is None: + case_keys = list(study.cases.keys()) + + perfs=study.get_performance_by_unit(case_keys=case_keys) + colors = study.get_colors() + + + if mode in ("ordered", "snr"): + num_axes = len(performance_names) + fig, axs = plt.subplots(ncols=num_axes) + else: + fig, ax = plt.subplots() + + if mode == "ordered": + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + for key in case_keys: + label = study.cases[key]["label"] + val = perfs.xs(key).loc[:, performance_name].values + val = np.sort(val)[::-1] + ax.plot(val, label=label, c=colors[key]) + ax.set_title(performance_name) + if count == len(performance_names) - 1: + ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) + + elif mode == "snr": + metric_name = mode + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + + max_metric = 0 + for key in case_keys: + x = study.get_metrics(key).loc[:, metric_name].values + y = perfs.xs(key).loc[:, performance_name].values + label = study.cases[key]["label"] + ax.scatter(x, y, s=10, label=label, color=colors[key]) + max_metric = max(max_metric, np.max(x)) + ax.set_title(performance_name) + ax.set_xlim(0, max_metric * 1.05) + ax.set_ylim(0, 1.05) + if count == 0: + 
ax.legend(loc="lower right") + + elif mode == "swarm": + levels = perfs.index.names + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=performance_names, + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True, ax=ax) + + +def plot_agreement_matrix(study, ordered=True, case_keys=None): + """ + Plot agreement matrices for cases in a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + ordered : bool + Order units with best agreement scores. + This enables seeing the agreement on a diagonal. + """ + + import matplotlib.pyplot as plt + from spikeinterface.widgets import AgreementMatrixWidget + + if case_keys is None: + case_keys = list(study.cases.keys()) + + + num_axes = len(case_keys) + fig, axs = plt.subplots(ncols=num_axes) + + for count, key in enumerate(case_keys): + ax = axs.flatten()[count] + comp = study.get_result(key)["gt_comparison"] + + unit_ticks = len(comp.sorting1.unit_ids) <= 16 + count_text = len(comp.sorting1.unit_ids) <= 16 + + AgreementMatrixWidget( + comp, ordered=ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax + ) + label = study.cases[key]["label"] + ax.set_xlabel(label) + + if count > 0: + ax.set_ylabel(None) + ax.set_yticks([]) + ax.set_xticks([]) + diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index 50308f8df7..03ac86d715 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -63,10 +63,10 @@ def test_SorterStudy(setup_module): print(study) # # this run the sorters - study.run() + # study.run() # # this run comparisons - study.compute_results() + # study.compute_results() 
print(study) # this is from the base class diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 85043d0d12..f32a15e429 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -1,127 +1,60 @@ -from __future__ import annotations +""" +This module will be deprecated and will be removed in 0.102.0 + +All plotting for the previous GTStudy is now centralized in spikeinterface.benchmark.benchmark_plot_tools +Please note that GTStudy is replaced by SorterStudy, which is based on the more generic BenchmarkStudy. +""" -import numpy as np +from __future__ import annotations -from .base import BaseWidget, to_attr +from .base import BaseWidget +import warnings class StudyRunTimesWidget(BaseWidget): """ - Plot sorter run times for a GroundTruthStudy - + Plot sorter run times for a SorterStudy. Parameters ---------- - study : GroundTruthStudy + study : SorterStudy A study object. case_keys : list or None A selection of cases to plot, if None, then all. """ - def __init__( - self, - study, - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors() - ) - + def __init__(self, study, case_keys=None, backend=None, **backend_kwargs): + warnings.warn("plot_study_run_times is to be deprecated. 
Use spikeinterface.benchmark.benchmark_plot_tools instead.") + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) + from spikeinterface.benchmark.benchmark_plot_tools import plot_run_times + plot_run_times(data_plot["study"], case_keys=data_plot["case_keys"]) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - for i, key in enumerate(dp.case_keys): - label = dp.study.cases[key]["label"] - rt = dp.run_times.loc[key] - self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key]) - self.ax.set_ylabel("run time (s)") - self.ax.legend() - - -# TODO : plot optionally average on some levels using group by class StudyUnitCountsWidget(BaseWidget): """ Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" - Parameters ---------- - study : GroundTruthStudy + study : SorterStudy A study object. case_keys : list or None A selection of cases to plot, if None, then all. """ - def __init__( - self, - study, - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, - count_units=study.get_count_units(case_keys=case_keys), - case_keys=case_keys, - ) - + def __init__(self, study, case_keys=None, backend=None, **backend_kwargs): + warnings.warn("plot_study_unit_counts is to be deprecated. 
Use spikeinterface.benchmark.benchmark_plot_tools instead.") + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors - - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - columns = dp.count_units.columns.tolist() - columns.remove("num_gt") - columns.remove("num_sorter") - - ncol = len(columns) - - colors = get_some_colors(columns, color_engine="auto", map_name="hot") - colors["num_well_detected"] = "green" - - xticklabels = [] - for i, key in enumerate(dp.case_keys): - for c, col in enumerate(columns): - x = i + 1 + c / (ncol + 1) - y = dp.count_units.loc[key, col] - if not "well_detected" in col: - y = -y - - if i == 0: - label = col.replace("num_", "").replace("_", " ").title() - else: - label = None - - self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) - - xticklabels.append(dp.study.cases[key]["label"]) - - self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1) - self.ax.set_xticklabels(xticklabels) - self.ax.legend() + from spikeinterface.benchmark.benchmark_plot_tools import plot_unit_counts + plot_unit_counts(data_plot["study"], case_keys=data_plot["case_keys"]) class StudyPerformances(BaseWidget): @@ -154,79 +87,23 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn("plot_study_performances is to be deprecated. 
Use spikeinterface.benchmark.benchmark_plot_tools instead.") plot_data = dict( study=study, - perfs=study.get_performance_by_unit(case_keys=case_keys), mode=mode, performance_names=performance_names, case_keys=case_keys, ) - - self.colors = study.get_colors() - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors - - import pandas as pd - import seaborn as sns - - dp = to_attr(data_plot) - perfs = dp.perfs - study = dp.study - - if dp.mode in ("ordered", "snr"): - backend_kwargs["num_axes"] = len(dp.performance_names) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - if dp.mode == "ordered": - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - for key in dp.case_keys: - label = study.cases[key]["label"] - val = perfs.xs(key).loc[:, performance_name].values - val = np.sort(val)[::-1] - ax.plot(val, label=label, c=self.colors[key]) - ax.set_title(performance_name) - if count == len(dp.performance_names) - 1: - ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) - - elif dp.mode == "snr": - metric_name = dp.mode - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - - max_metric = 0 - for key in dp.case_keys: - x = study.get_metrics(key).loc[:, metric_name].values - y = perfs.xs(key).loc[:, performance_name].values - label = study.cases[key]["label"] - ax.scatter(x, y, s=10, label=label, color=self.colors[key]) - max_metric = max(max_metric, np.max(x)) - ax.set_title(performance_name) - ax.set_xlim(0, max_metric * 1.05) - ax.set_ylim(0, 1.05) - if count == 0: - ax.legend(loc="lower right") - - elif dp.mode == "swarm": - levels = perfs.index.names - df = pd.melt( - perfs.reset_index(), - id_vars=levels, - var_name="Metric", - 
value_name="Score", - value_vars=dp.performance_names, - ) - df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) - sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) - + from spikeinterface.benchmark.benchmark_plot_tools import plot_performances + plot_performances( + data_plot["study"], + mode=data_plot["mode"], + performance_names=data_plot["performance_names"], + case_keys=data_plot["case_keys"] + ) class StudyAgreementMatrix(BaseWidget): """ @@ -251,9 +128,7 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn("plot_study_agreement_matrix is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.") plot_data = dict( study=study, case_keys=case_keys, @@ -263,36 +138,12 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .comparison import AgreementMatrixWidget - - dp = to_attr(data_plot) - study = dp.study - - backend_kwargs["num_axes"] = len(dp.case_keys) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - for count, key in enumerate(dp.case_keys): - ax = self.axes.flatten()[count] - comp = study.comparisons[key] - unit_ticks = len(comp.sorting1.unit_ids) <= 16 - count_text = len(comp.sorting1.unit_ids) <= 16 - - AgreementMatrixWidget( - comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax - ) - label = study.cases[key]["label"] - ax.set_xlabel(label) - - if count > 0: - ax.set_ylabel(None) - ax.set_yticks([]) - ax.set_xticks([]) - - # ax0 = self.axes.flatten()[0] - # for ax in self.axes.flatten()[1:]: - # ax.sharey(ax0) + from spikeinterface.benchmark.benchmark_plot_tools import plot_agreement_matrix + plot_agreement_matrix( + data_plot["study"], + 
ordered=data_plot["ordered"], + case_keys=data_plot["case_keys"] + ) class StudySummary(BaseWidget): @@ -320,25 +171,19 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, - case_keys=case_keys, - ) - + + warnings.warn("plot_study_summary is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.") + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - study = data_plot["study"] case_keys = data_plot["case_keys"] - StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs) - StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs) - StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + from spikeinterface.benchmark.benchmark_plot_tools import plot_agreement_matrix, plot_performances, plot_unit_counts, plot_run_times + + plot_performances(study=study, case_keys=case_keys, mode="ordered") + plot_performances(study=study, case_keys=case_keys, mode="snr") + plot_agreement_matrix(study=study, case_keys=case_keys) + plot_run_times(study=study, case_keys=case_keys) + plot_unit_counts(study=study, case_keys=case_keys)