diff --git a/doc/images/overview.png b/doc/images/overview.png index ea5ba49d08..e367c4b6e4 100644 Binary files a/doc/images/overview.png and b/doc/images/overview.png differ diff --git a/doc/modules/benchmark.rst b/doc/modules/benchmark.rst new file mode 100644 index 0000000000..faf53be790 --- /dev/null +++ b/doc/modules/benchmark.rst @@ -0,0 +1,142 @@ +Benchmark module +================ + +This module contains machinery to compare sorters against ground truth in many different situations. + + +.. note:: + + In 0.102.0 the previous :py:func:`~spikeinterface.comparison.GroundTruthStudy()` has been replaced by + :py:func:`~spikeinterface.benchmark.SorterStudy()`. + + +This module also aims to benchmark sorting components (detection, clustering, motion, template matching) using the +same base class :py:func:`~spikeinterface.benchmark.BenchmarkStudy()` but specialized for a given component. + +By design, the main class handles the concept of "levels": this allows comparing several complexities at the same time. +For instance, compare kilosort4 vs kilosort2.5 (level 0) for different noise amplitudes (level 1) combined with +several motion vectors (level 2). + +**Example: compare many sorters: a ground truth study** + +We have a high-level class to compare many sorters against ground truth: :py:func:`~spikeinterface.benchmark.SorterStudy()`. + + +A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases, +such as different parameter sets. + +The study class provides high-level tool functions to run many ground truth comparisons with many "cases" +on many recordings and then collect and aggregate results in an easy way. + +The whole mechanism is based on an intrinsic organization into a "study_folder" with several subfolders: + + * datasets: contains ground truth datasets + * sorters: contains the outputs of sorters + * sortings: contains a light copy of all sortings + * metrics: contains metrics + * ... + + +.. code-block:: python
+ + import matplotlib.pyplot as plt + import seaborn as sns + + import spikeinterface.extractors as se + import spikeinterface.widgets as sw + from spikeinterface.benchmark import SorterStudy + from spikeinterface.core import generate_ground_truth_recording + + + # generate 2 simulated datasets (could also be mearec files) + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) + + datasets = { + "toy0": (rec0, gt_sorting0), + "toy1": (rec1, gt_sorting1), + } + + # define some "cases": here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset + # so it is a two level study (sorter_name, dataset) + # this could be more complicated like (sorter_name, dataset, params) + cases = { + ("tdc2", "toy0"): { + "label": "tridesclous2 on tetrode0", + "dataset": "toy0", + "params": {"sorter_name": "tridesclous2"} + }, + ("tdc2", "toy1"): { + "label": "tridesclous2 on tetrode1", + "dataset": "toy1", + "params": {"sorter_name": "tridesclous2"} + }, + ("sc", "toy0"): { + "label": "spykingcircus2 on tetrode0", + "dataset": "toy0", + "params": { + "sorter_name": "spykingcircus", + "docker_image": True + }, + }, + } + # this initializes a folder + study = SorterStudy.create(study_folder=study_folder, datasets=datasets, cases=cases, + levels=["sorter_name", "dataset"]) + + + # This internally runs run_sorter() for all cases in one function + study.run() + + # Run the benchmark: this internally runs compare_sorter_to_ground_truth() for all cases + study.compute_results() + + # Collect comparisons one by one + for case_key in study.cases: + print('*' * 10) + print(case_key) + # the GroundTruthComparison object (raw counting of tp/fp/...) + comp = study.get_result(case_key)["gt_comparison"] + # summary + comp.print_summary() + perf_unit = comp.get_performance(method='by_unit') + perf_avg = comp.get_performance(method='pooled_with_average') + # some plots + m = comp.get_confusion_matrix() + w_comp = sw.plot_agreement_matrix(sorting_comparison=comp) + + # Collect synthetic dataframes and display + # As shown previously, the performance is returned as a pandas dataframe. + # The study.get_performance_by_unit() method + # gathers all the outputs in the study folder and merges them into a single dataframe. + # Same idea for study.get_count_units() + + # this is a dataframe + perfs = study.get_performance_by_unit() + + # this is a dataframe + unit_counts = study.get_count_units() + + # The study also has several methods for plotting the results + study.plot_agreement_matrix() + study.plot_unit_counts() + study.plot_performances(mode="ordered") + study.plot_performances(mode="snr") + + + + +Benchmark spike collisions +-------------------------- + +SpikeInterface also has a specific toolset to benchmark how well sorters recover spikes in "collision". + +We have two classes to handle collision-specific comparisons, and also to quantify the effects on correlogram +estimation: + + * :py:class:`~spikeinterface.comparison.CollisionGTComparison` + * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison` + +For more details, check out the following paper: + +`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?"
`_ diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index edee7f1fda..a02d76664d 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -5,6 +5,10 @@ Comparison module SpikeInterface has a :py:mod:`~spikeinterface.comparison` module, which contains functions and tools to compare spike trains and templates (useful for tracking units over multiple sessions). +.. note:: + + In version 0.102.0 the benchmark part of comparison has moved in the new :py:mod:`~spikeinterface.benchmark` + In addition, the :py:mod:`~spikeinterface.comparison` module contains advanced benchmarking tools to evaluate the effects of spike collisions on spike sorting results, and to construct hybrid recordings for comparison. @@ -242,135 +246,6 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for cmp_gt_HS.get_redundant_units(redundant_score=0.2) - -**Example: compare many sorters with a Ground Truth Study** - -We also have a high level class to compare many sorters against ground truth: -:py:func:`~spikeinterface.comparison.GroundTruthStudy()` - -A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases -like the different parameter sets. - -The study class proposes high-level tool functions to run many ground truth comparisons with many "cases" -on many recordings and then collect and aggregate results in an easy way. - -The all mechanism is based on an intrinsic organization into a "study_folder" with several subfolders: - - * datasets: contains ground truth datasets - * sorters : contains outputs of sorters - * sortings: contains light copy of all sorting - * metrics: contains metrics - * ... - - -.. code-block:: python - - import matplotlib.pyplot as plt - import seaborn as sns - - import spikeinterface.extractors as se - import spikeinterface.widgets as sw - from spikeinterface.comparison import GroundTruthStudy - - - # generate 2 simulated datasets (could be also mearec files) - rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) - rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) - - datasets = { - "toy0": (rec0, gt_sorting0), - "toy1": (rec1, gt_sorting1), - } - - # define some "cases" here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset - # so it is a two level study (sorter_name, dataset) - # this could be more complicated like (sorter_name, dataset, params) - cases = { - ("tdc2", "toy0"): { - "label": "tridesclous2 on tetrode0", - "dataset": "toy0", - "run_sorter_params": { - "sorter_name": "tridesclous2", - }, - }, - ("tdc2", "toy1"): { - "label": "tridesclous2 on tetrode1", - "dataset": "toy1", - "run_sorter_params": { - "sorter_name": "tridesclous2", - }, - }, - - ("sc", "toy0"): { - "label": "spykingcircus2 on tetrode0", - "dataset": "toy0", - "run_sorter_params": { - "sorter_name": "spykingcircus", - "docker_image": True - }, - }, - } - # this initilizes a folder - study = GroundTruthStudy.create(study_folder=study_folder, datasets=datasets, cases=cases, - levels=["sorter_name", "dataset"]) - - - # all cases in one function - study.run_sorters() - - # Collect comparisons - # - # You can collect in one shot all results and run the - # GroundTruthComparison on it. - # So you can have fine access to all individual results. 
- # - # Note: use exhaustive_gt=True when you know exactly how many - # units in the ground truth (for synthetic datasets) - - # run all comparisons and loop over the results - study.run_comparisons(exhaustive_gt=True) - for key, comp in study.comparisons.items(): - print('*' * 10) - print(key) - # raw counting of tp/fp/... - print(comp.count_score) - # summary - comp.print_summary() - perf_unit = comp.get_performance(method='by_unit') - perf_avg = comp.get_performance(method='pooled_with_average') - # some plots - m = comp.get_confusion_matrix() - w_comp = sw.plot_agreement_matrix(sorting_comparison=comp) - - # Collect synthetic dataframes and display - # As shown previously, the performance is returned as a pandas dataframe. - # The spikeinterface.comparison.get_performance_by_unit() function, - # gathers all the outputs in the study folder and merges them into a single dataframe. - # Same idea for spikeinterface.comparison.get_count_units() - - # this is a dataframe - perfs = study.get_performance_by_unit() - - # this is a dataframe - unit_counts = study.get_count_units() - - # we can also access run times - run_times = study.get_run_times() - print(run_times) - - # Easy plotting with seaborn - fig1, ax1 = plt.subplots() - sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1) - ax1.set_title('Run times') - - ############################################################################## - - fig2, ax2 = plt.subplots() - sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2) - ax2.set_title('Recall') - ax2.set_ylim(-0.1, 1.1) - - .. _symmetric: 2. Compare the output of two spike sorters (symmetric comparison) @@ -537,35 +412,3 @@ sorting analyzers from day 1 (:code:`analyzer_day1`) to day 5 (:code:`analyzer_d # match all m_tcmp = sc.compare_multiple_templates(waveform_list=analyzer_list, name_list=["D1", "D2", "D3", "D4", "D5"]) - - - -Benchmark spike collisions --------------------------- - -SpikeInterface also has a specific toolset to benchmark how well sorters are at recovering spikes in "collision". - -We have three classes to handle collision-specific comparisons, and also to quantify the effects on correlogram -estimation: - - * :py:class:`~spikeinterface.comparison.CollisionGTComparison` - * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison` - * :py:class:`~spikeinterface.comparison.CollisionGTStudy` - * :py:class:`~spikeinterface.comparison.CorrelogramGTStudy` - -For more details, checkout the following paper: - -`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?" `_ - - -Hybrid recording ----------------- - -To benchmark spike sorting results, we need ground-truth spiking activity. -This can be generated with artificial simulations, e.g., using `MEArec `_, or -alternatively by generating so-called "hybrid" recordings. 
- -The :py:mod:`~spikeinterface.comparison` module includes functions to generate such "hybrid" recordings: - - * :py:func:`~spikeinterface.comparison.create_hybrid_units_recording`: add new units to an existing recording - * :py:func:`~spikeinterface.comparison.create_hybrid_spikes_recording`: add new spikes to existing units in a recording diff --git a/src/spikeinterface/benchmark/__init__.py b/src/spikeinterface/benchmark/__init__.py new file mode 100644 index 0000000000..3cf0c6a6f6 --- /dev/null +++ b/src/spikeinterface/benchmark/__init__.py @@ -0,0 +1,7 @@ +""" +Module to benchmark: + * sorters + * some sorting components (clustering, motion, template matching) +""" + +from .benchmark_sorter import SorterStudy diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/benchmark/benchmark_base.py similarity index 95% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py rename to src/spikeinterface/benchmark/benchmark_base.py index 4d6dd43bce..b9cbf269c8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -131,7 +131,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): return cls(study_folder) - def create_benchmark(self): + def create_benchmark(self, key): raise NotImplementedError def scan_folder(self): @@ -258,25 +258,9 @@ def get_run_times(self, case_keys=None): return df def plot_run_times(self, case_keys=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - run_times = self.get_run_times(case_keys=case_keys) - - colors = self.get_colors() - import matplotlib.pyplot as plt + from .benchmark_plot_tools import plot_run_times - fig, ax = plt.subplots() - labels = [] - for i, key in enumerate(case_keys): - labels.append(self.cases[key]["label"]) - rt = run_times.at[key, "run_times"] - ax.bar(i, rt, width=0.8, color=colors[key]) - ax.set_xticks(np.arange(len(case_keys))) - ax.set_xticklabels(labels, rotation=45.0) - return fig - - # ax = run_times.plot(kind="bar") - # return ax.figure + return plot_run_times(self, case_keys=case_keys) def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: @@ -462,10 +446,3 @@ def run(self): def compute_result(self): # run becnhmark result raise NotImplementedError - - -def _simpleaxis(ax): - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.get_xaxis().tick_bottom() - ax.get_yaxis().tick_left() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py similarity index 92% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py rename to src/spikeinterface/benchmark/benchmark_clustering.py index 92fcda35d9..1c731ecb64 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -11,8 +11,7 @@ import numpy as np - -from .benchmark_tools import BenchmarkStudy, Benchmark +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @@ -161,49 +160,21 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc return count_units - def plot_unit_counts(self, case_keys=None, figsize=None, **extra_kwargs): - from spikeinterface.widgets.widget_list import 
plot_study_unit_counts + # plotting by methods + def plot_unit_counts(self, **kwargs): + from .benchmark_plot_tools import plot_unit_counts - plot_study_unit_counts(self, case_keys, figsize=figsize, **extra_kwargs) + return plot_unit_counts(self, **kwargs) - def plot_agreements(self, case_keys=None, figsize=(15, 15)): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + return plot_agreement_matrix(self, **kwargs) - for count, key in enumerate(case_keys): - ax = axs[0, count] - ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + def plot_performances_vs_snr(self, **kwargs): + from .benchmark_plot_tools import plot_performances_vs_snr - return fig - - def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt - - fig, axes = plt.subplots(ncols=1, nrows=3, figsize=figsize) - - for count, k in enumerate(("accuracy", "recall", "precision")): - - ax = axes[count] - for key in case_keys: - label = self.cases[key]["label"] - - analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension("quality_metrics").get_data() - x = metrics["snr"].values - y = self.get_result(key)["gt_comparison"].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label) - ax.set_title(k) - - if count == 2: - ax.legend() - - return fig + return plot_performances_vs_snr(self, **kwargs) def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py similarity index 84% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py rename to src/spikeinterface/benchmark/benchmark_matching.py index ab1523d13a..c53567f460 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -9,11 +9,8 @@ ) import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.sortingcomponents.tools import remove_empty_templates -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.sparsity import compute_sparsity class MatchingBenchmark(Benchmark): @@ -64,41 +61,15 @@ def create_benchmark(self, key): benchmark = MatchingBenchmark(recording, gt_sorting, params) return benchmark - def plot_agreements(self, case_keys=None, figsize=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + return plot_agreement_matrix(self, **kwargs) - for count, key in enumerate(case_keys): - ax = axs[0, count] - ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + def plot_performances_vs_snr(self, **kwargs): + from .benchmark_plot_tools import plot_performances_vs_snr - def plot_performances_vs_snr(self, 
case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): - if case_keys is None: - case_keys = list(self.cases.keys()) - - fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) - - for count, k in enumerate(metrics): - - ax = axs[count, 0] - for key in case_keys: - label = self.cases[key]["label"] - - analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension("quality_metrics").get_data() - x = metrics["snr"].values - y = self.get_result(key)["gt_comparison"].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label) - ax.set_title(k) - - if count == 2: - ax.legend() - - return fig + return plot_performances_vs_snr(self, **kwargs) def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py similarity index 99% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py rename to src/spikeinterface/benchmark/benchmark_motion_estimation.py index ec7e1e24a8..abb2a51bae 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -8,7 +8,8 @@ import numpy as np from spikeinterface.core import get_noise_levels -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from .benchmark_base import Benchmark, BenchmarkStudy +from .benchmark_plot_tools import _simpleaxis from spikeinterface.sortingcomponents.motion import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py rename to src/spikeinterface/benchmark/benchmark_motion_interpolation.py index 38365adfd1..ab72a1f9bd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py @@ -10,7 +10,7 @@ from spikeinterface.curation import MergeUnitsSorting -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from .benchmark_base import Benchmark, BenchmarkStudy class MotionInterpolationBenchmark(Benchmark): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/benchmark/benchmark_peak_detection.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py rename to src/spikeinterface/benchmark/benchmark_peak_detection.py index 7d862343d2..77b5e0025c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_detection.py @@ -12,10 +12,9 @@ import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import 
get_template_extremum_channel class PeakDetectionBenchmark(Benchmark): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py similarity index 99% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py rename to src/spikeinterface/benchmark/benchmark_peak_localization.py index 05d142113b..399729fa29 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -6,7 +6,7 @@ compute_grid_convolution, ) import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/benchmark/benchmark_peak_selection.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py rename to src/spikeinterface/benchmark/benchmark_peak_selection.py index 008de2d931..41edea156f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_selection.py @@ -6,15 +6,9 @@ from spikeinterface.comparison.comparisontools import make_matching_events from spikeinterface.core import get_noise_levels -import time -import string, random -import pylab as plt -import os import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy -from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from .benchmark_base import Benchmark, BenchmarkStudy class PeakSelectionBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py new file mode 100644 index 0000000000..a6e9b6dacc --- /dev/null +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -0,0 +1,243 @@ +import numpy as np + + +def _simpleaxis(ax): + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.get_xaxis().tick_bottom() + ax.get_yaxis().tick_left() + + +def plot_run_times(study, case_keys=None): + """ + Plot run times for a BenchmarkStudy. + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + """ + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + run_times = study.get_run_times(case_keys=case_keys) + + colors = study.get_colors() + + fig, ax = plt.subplots() + labels = [] + for i, key in enumerate(case_keys): + labels.append(study.cases[key]["label"]) + rt = run_times.at[key, "run_times"] + ax.bar(i, rt, width=0.8, color=colors[key]) + ax.set_xticks(np.arange(len(case_keys))) + ax.set_xticklabels(labels, rotation=45.0) + return fig + + +def plot_unit_counts(study, case_keys=None): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. 
+ """ + import matplotlib.pyplot as plt + from spikeinterface.widgets.utils import get_some_colors + + if case_keys is None: + case_keys = list(study.cases.keys()) + + count_units = study.get_count_units(case_keys=case_keys) + + fig, ax = plt.subplots() + + columns = count_units.columns.tolist() + columns.remove("num_gt") + columns.remove("num_sorter") + + ncol = len(columns) + + colors = get_some_colors(columns, color_engine="auto", map_name="hot") + colors["num_well_detected"] = "green" + + xticklabels = [] + for i, key in enumerate(case_keys): + for c, col in enumerate(columns): + x = i + 1 + c / (ncol + 1) + y = count_units.loc[key, col] + if not "well_detected" in col: + y = -y + + if i == 0: + label = col.replace("num_", "").replace("_", " ").title() + else: + label = None + + ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) + + xticklabels.append(study.cases[key]["label"]) + + ax.set_xticks(np.arange(len(case_keys)) + 1) + ax.set_xticklabels(xticklabels) + ax.legend() + + return fig + + +def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None): + """ + Plot performances over case for a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + mode : "ordered" | "snr" | "swarm", default: "ordered" + Which plot mode to use: + + * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy + * "snr": plot performance metrics vs snr + * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) + performance_names : list or tuple, default: ("accuracy", "precision", "recall") + Which performances to plot ("accuracy", "precision", "recall") + case_keys : list or None + A selection of cases to plot, if None, then all. 
+ """ + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns + + if case_keys is None: + case_keys = list(study.cases.keys()) + + perfs = study.get_performance_by_unit(case_keys=case_keys) + colors = study.get_colors() + + if mode in ("ordered", "snr"): + num_axes = len(performance_names) + fig, axs = plt.subplots(ncols=num_axes) + else: + fig, ax = plt.subplots() + + if mode == "ordered": + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + for key in case_keys: + label = study.cases[key]["label"] + val = perfs.xs(key).loc[:, performance_name].values + val = np.sort(val)[::-1] + ax.plot(val, label=label, c=colors[key]) + ax.set_title(performance_name) + if count == len(performance_names) - 1: + ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) + + elif mode == "snr": + metric_name = mode + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + + max_metric = 0 + for key in case_keys: + x = study.get_metrics(key).loc[:, metric_name].values + y = perfs.xs(key).loc[:, performance_name].values + label = study.cases[key]["label"] + ax.scatter(x, y, s=10, label=label, color=colors[key]) + max_metric = max(max_metric, np.max(x)) + ax.set_title(performance_name) + ax.set_xlim(0, max_metric * 1.05) + ax.set_ylim(0, 1.05) + if count == 0: + ax.legend(loc="lower right") + + elif mode == "swarm": + levels = perfs.index.names + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=performance_names, + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True, ax=ax) + + +def plot_agreement_matrix(study, ordered=True, case_keys=None): + """ + Plot agreement matrices for cases in a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + ordered : bool + Order units with best agreement scores. + This enables seeing the agreement along the diagonal.
+ """ + + import matplotlib.pyplot as plt + from spikeinterface.widgets import AgreementMatrixWidget + + if case_keys is None: + case_keys = list(study.cases.keys()) + + num_axes = len(case_keys) + fig, axs = plt.subplots(ncols=num_axes) + + for count, key in enumerate(case_keys): + ax = axs.flatten()[count] + comp = study.get_result(key)["gt_comparison"] + + unit_ticks = len(comp.sorting1.unit_ids) <= 16 + count_text = len(comp.sorting1.unit_ids) <= 16 + + AgreementMatrixWidget( + comp, ordered=ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax + ) + label = study.cases[key]["label"] + ax.set_xlabel(label) + + if count > 0: + ax.set_ylabel(None) + ax.set_yticks([]) + ax.set_xticks([]) + + +def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) + + for count, k in enumerate(metrics): + + ax = axs[count, 0] + for key in case_keys: + label = study.cases[key]["label"] + + analyzer = study.get_sorting_analyzer(key) + quality_metrics = analyzer.get_extension("quality_metrics").get_data() + x = quality_metrics["snr"].values + y = study.get_result(key)["gt_comparison"].get_performance()[k].values + ax.scatter(x, y, marker=".", label=label) + ax.set_title(k) + + ax.set_ylim(0, 1.05) + + if count == 2: + ax.legend() + + return fig diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py new file mode 100644 index 0000000000..f9267c785a --- /dev/null +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -0,0 +1,135 @@ +""" +This replaces the previous `GroundTruthStudy`. +""" + +import numpy as np +from ..core import NumpySorting +from .benchmark_base import Benchmark, BenchmarkStudy +from ..sorters import run_sorter +from spikeinterface.comparison import compare_sorter_to_ground_truth + + +# TODO later integrate CollisionGTComparison optionally in this class. + + +class SorterBenchmark(Benchmark): + def __init__(self, recording, gt_sorting, params, sorter_folder): + self.recording = recording + self.gt_sorting = gt_sorting + self.params = params + self.sorter_folder = sorter_folder + self.result = {} + + def run(self): + # run one sorter; sorter_name must be in params + raw_sorting = run_sorter(recording=self.recording, folder=self.sorter_folder, **self.params) + sorting = NumpySorting.from_sorting(raw_sorting) + self.result = {"sorting": sorting} + + def compute_result(self): + # compute the benchmark result (ground-truth comparison) + sorting = self.result["sorting"] + comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + self.result["gt_comparison"] = comp + + _run_key_saved = [ + ("sorting", "sorting"), + ] + _result_key_saved = [ + ("gt_comparison", "pickle"), + ] + + +class SorterStudy(BenchmarkStudy): + """ + This class is used to test several sorters in several situations. + It replaces the previous GroundTruthStudy with more flexibility.
+ """ + + benchmark_class = SorterBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder) + return benchmark + + def get_performance_by_unit(self, case_keys=None): + import pandas as pd + + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + + perf = comp.get_performance(method="by_unit", output="pandas") + + if isinstance(key, str): + perf[self.levels] = key + elif isinstance(key, tuple): + for col, k in zip(self.levels, key): + perf[col] = k + + perf = perf.reset_index() + perf_by_unit.append(perf) + + perf_by_unit = pd.concat(perf_by_unit) + perf_by_unit = perf_by_unit.set_index(self.levels) + perf_by_unit = perf_by_unit.sort_index() + return perf_by_unit + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + columns = ["num_gt", "num_sorter", "num_well_detected"] + key0 = case_keys[0] + comp = self.get_result(key0)["gt_comparison"] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units + + # plotting as methods + def plot_unit_counts(self, **kwargs): + from .benchmark_plot_tools import plot_unit_counts + + return plot_unit_counts(self, **kwargs) + + def plot_performances(self, **kwargs): + from .benchmark_plot_tools import plot_performances + + return plot_performances(self, **kwargs) + + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix + + return plot_agreement_matrix(self, **kwargs) diff --git a/src/spikeinterface/benchmark/benchmark_tools.py b/src/spikeinterface/benchmark/benchmark_tools.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py similarity index 100% rename from src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py rename to src/spikeinterface/benchmark/tests/common_benchmark_testing.py diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py 
b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py similarity index 88% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py rename to src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index bc36fb607c..3f574fd058 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -3,11 +3,13 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.benchmark_clustering import ClusteringStudy +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset +from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel +from pathlib import Path + @pytest.mark.skip() def test_benchmark_clustering(create_cache_folder): @@ -78,4 +80,5 @@ def test_benchmark_clustering(create_cache_folder): if __name__ == "__main__": - test_benchmark_clustering() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_clustering(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py similarity index 86% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py rename to src/spikeinterface/benchmark/tests/test_benchmark_matching.py index 71a5f282a8..000a00faf5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py @@ -1,6 +1,7 @@ import pytest import shutil +from pathlib import Path from spikeinterface.core import ( @@ -8,11 +9,11 @@ compute_sparsity, ) -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_dataset, compute_gt_templates, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_matching import MatchingStudy +from spikeinterface.benchmark.benchmark_matching import MatchingStudy @pytest.mark.skip() @@ -72,4 +73,5 @@ def test_benchmark_matching(create_cache_folder): if __name__ == "__main__": - test_benchmark_matching() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_matching(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py similarity index 86% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py rename to src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py index 78a9eb7dbc..65cacfc8a0 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py @@ -2,12 +2,13 @@ import shutil +from pathlib import Path -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import 
MotionEstimationStudy +from spikeinterface.benchmark.benchmark_motion_estimation import MotionEstimationStudy @pytest.mark.skip() @@ -75,4 +76,5 @@ def test_benchmark_motion_estimaton(create_cache_folder): if __name__ == "__main__": - test_benchmark_motion_estimaton() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_motion_estimaton(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py similarity index 90% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py rename to src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py index 18def37d54..f7afd7a8bc 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py @@ -4,14 +4,14 @@ import numpy as np import shutil +from pathlib import Path - -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import ( +from spikeinterface.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy +from spikeinterface.benchmark.benchmark_motion_estimation import ( # get_unit_displacement, get_gt_motion_from_unit_displacement, ) @@ -139,4 +139,5 @@ def test_benchmark_motion_interpolation(create_cache_folder): if __name__ == "__main__": - test_benchmark_motion_interpolation() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_motion_interpolation(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py similarity index 87% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py rename to src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index dffe1529b7..d45ac0b4ce 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -1,10 +1,10 @@ import pytest import shutil +from pathlib import Path - -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_detection import PeakDetectionStudy +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset +from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @@ -69,5 +69,5 @@ def test_benchmark_peak_detection(create_cache_folder): if __name__ == "__main__": - # test_benchmark_peak_localization() - test_benchmark_peak_detection() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_peak_detection(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py 
b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py similarity index 79% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py rename to src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py index 23060c4ddb..3b6240cb10 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py @@ -1,12 +1,12 @@ import pytest import shutil +from pathlib import Path +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset - -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import PeakLocalizationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import UnitLocalizationStudy +from spikeinterface.benchmark.benchmark_peak_localization import PeakLocalizationStudy +from spikeinterface.benchmark.benchmark_peak_localization import UnitLocalizationStudy @pytest.mark.skip() @@ -28,7 +28,8 @@ def test_benchmark_peak_localization(create_cache_folder): "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, "params": { "method": method, - "method_kwargs": {"ms_before": 2}, + "ms_before": 2.0, + "method_kwargs": {}, }, } @@ -60,7 +61,7 @@ def test_benchmark_unit_locations(create_cache_folder): cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") - recording, gt_sorting = make_dataset() + recording, gt_sorting, gt_analyzer = make_dataset() # create study study_folder = cache_folder / "study_unit_locations" @@ -71,7 +72,7 @@ def test_benchmark_unit_locations(create_cache_folder): "label": f"{method} on toy", "dataset": "toy", "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, - "params": {"method": method, "method_kwargs": {"ms_before": 2}}, + "params": {"method": method, "ms_before": 2.0, "method_kwargs": {}}, } if study_folder.exists(): @@ -99,5 +100,6 @@ def test_benchmark_unit_locations(create_cache_folder): if __name__ == "__main__": - # test_benchmark_peak_localization() - test_benchmark_unit_locations() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + # test_benchmark_peak_localization(cache_folder) + test_benchmark_unit_locations(cache_folder) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py new file mode 100644 index 0000000000..a6eb090a9d --- /dev/null +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py @@ -0,0 +1,13 @@ +import pytest + +from pathlib import Path + + +@pytest.mark.skip() +def test_benchmark_peak_selection(create_cache_folder): + cache_folder = create_cache_folder + + +if __name__ == "__main__": + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_peak_selection(cache_folder) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py similarity index 57% rename from src/spikeinterface/comparison/tests/test_groundtruthstudy.py rename to src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index a92d6e9f77..2564d58d52 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ 
b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -4,12 +4,12 @@ from spikeinterface import generate_ground_truth_recording from spikeinterface.preprocessing import bandpass_filter -from spikeinterface.comparison import GroundTruthStudy +from spikeinterface.benchmark import SorterStudy @pytest.fixture(scope="module") def setup_module(tmp_path_factory): - study_folder = tmp_path_factory.mktemp("study_folder") + study_folder = tmp_path_factory.mktemp("sorter_study_folder") if study_folder.is_dir(): shutil.rmtree(study_folder) create_a_study(study_folder) @@ -36,63 +36,53 @@ def create_a_study(study_folder): ("tdc2", "no-preprocess", "tetrode"): { "label": "tridesclous2 without preprocessing and standard params", "dataset": "toy_tetrode", - "run_sorter_params": { + "params": { "sorter_name": "tridesclous2", }, - "comparison_params": {}, }, # ("tdc2", "with-preprocess", "probe32"): { "label": "tridesclous2 with preprocessing standar params", "dataset": "toy_probe32_preprocess", - "run_sorter_params": { + "params": { "sorter_name": "tridesclous2", }, - "comparison_params": {}, }, - # we comment this at the moement because SC2 is quite slow for testing - # ("sc2", "no-preprocess", "tetrode"): { - # "label": "spykingcircus2 without preprocessing standar params", - # "dataset": "toy_tetrode", - # "run_sorter_params": { - # "sorter_name": "spykingcircus2", - # }, - # "comparison_params": { - # }, - # }, } - study = GroundTruthStudy.create( + study = SorterStudy.create( study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] ) # print(study) -def test_GroundTruthStudy(setup_module): +def test_SorterStudy(setup_module): + # job_kwargs = dict(n_jobs=2, chunk_duration="1s") + study_folder = setup_module - study = GroundTruthStudy(study_folder) + study = SorterStudy(study_folder) print(study) - study.run_sorters(verbose=True) - - print(study.sortings) - - print(study.comparisons) - study.run_comparisons() - print(study.comparisons) + # # this run the sorters + # study.run() - study.create_sorting_analyzer_gt(n_jobs=-1) - - study.compute_metrics() + # # this run comparisons + # study.compute_results() + print(study) - for key in study.cases: - metrics = study.get_metrics(key) - print(metrics) + # this is from the base class + rt = study.get_run_times() + # rt = study.plot_run_times() + # import matplotlib.pyplot as plt + # plt.show() - study.get_performance_by_unit() - study.get_count_units() + perf_by_unit = study.get_performance_by_unit() + # print(perf_by_unit) + count_units = study.get_count_units() + # print(count_units) if __name__ == "__main__": - setup_module() - test_GroundTruthStudy() + study_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy" + create_a_study(study_folder) + test_SorterStudy(study_folder) diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py index 648ef4ed70..f4ada19f73 100644 --- a/src/spikeinterface/comparison/__init__.py +++ b/src/spikeinterface/comparison/__init__.py @@ -30,8 +30,8 @@ ) from .groundtruthstudy import GroundTruthStudy -from .collision import CollisionGTComparison, CollisionGTStudy -from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy +from .collision import CollisionGTComparison +from .correlogram import CorrelogramGTComparison from .hybrid import ( HybridSpikesRecording, diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 
574bd16093..12bfab84ed 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -172,71 +172,75 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good return similarities, recall_scores, pair_names -class CollisionGTStudy(GroundTruthStudy): - def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): - _kwargs = dict() - _kwargs.update(kwargs) - _kwargs["exhaustive_gt"] = exhaustive_gt - _kwargs["collision_lag"] = collision_lag - _kwargs["nbins"] = nbins - GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) - self.exhaustive_gt = exhaustive_gt - self.collision_lag = collision_lag - - def get_lags(self, key): - comp = self.comparisons[key] - fs = comp.sorting1.get_sampling_frequency() - lags = comp.bins / fs * 1000.0 - return lags - - def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): - import sklearn - - if case_keys is None: - case_keys = self.cases.keys() - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for key in case_keys: - templates = self.get_templates(key) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - comp = self.comparisons[key] - similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity, good_only=good_only, min_accuracy=min_accuracy - ) - self.all_similarities[key] = similarities - self.all_recall_scores[key] = recall_scores - - def get_mean_over_similarity_range(self, similarity_range, key): - idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) - all_similarities = self.all_similarities[key][idx] - all_recall_scores = self.all_recall_scores[key][idx] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = np.nanmean(all_recall_scores, axis=0) - - return mean_recall_scores - - def get_lag_profile_over_similarity_bins(self, similarity_bins, key): - all_similarities = self.all_similarities[key] - all_recall_scores = self.all_recall_scores[key] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) - result[(cmin, cmax)] = mean_recall_scores - - return result +# This is removed at the moment. +# We need to move this maybe one day in benchmark. 
+# please do not delete this + +# class CollisionGTStudy(GroundTruthStudy): +# def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): +# _kwargs = dict() +# _kwargs.update(kwargs) +# _kwargs["exhaustive_gt"] = exhaustive_gt +# _kwargs["collision_lag"] = collision_lag +# _kwargs["nbins"] = nbins +# GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) +# self.exhaustive_gt = exhaustive_gt +# self.collision_lag = collision_lag + +# def get_lags(self, key): +# comp = self.comparisons[key] +# fs = comp.sorting1.get_sampling_frequency() +# lags = comp.bins / fs * 1000.0 +# return lags + +# def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): +# import sklearn + +# if case_keys is None: +# case_keys = self.cases.keys() + +# self.all_similarities = {} +# self.all_recall_scores = {} +# self.good_only = good_only + +# for key in case_keys: +# templates = self.get_templates(key) +# flat_templates = templates.reshape(templates.shape[0], -1) +# similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) +# comp = self.comparisons[key] +# similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( +# similarity, good_only=good_only, min_accuracy=min_accuracy +# ) +# self.all_similarities[key] = similarities +# self.all_recall_scores[key] = recall_scores + +# def get_mean_over_similarity_range(self, similarity_range, key): +# idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) +# all_similarities = self.all_similarities[key][idx] +# all_recall_scores = self.all_recall_scores[key][idx] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# mean_recall_scores = np.nanmean(all_recall_scores, axis=0) + +# return mean_recall_scores + +# def get_lag_profile_over_similarity_bins(self, similarity_bins, key): +# all_similarities = self.all_similarities[key] +# all_recall_scores = self.all_recall_scores[key] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# result = {} + +# for i in range(similarity_bins.size - 1): +# cmin, cmax = similarity_bins[i], similarity_bins[i + 1] +# amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) +# mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) +# result[(cmin, cmax)] = mean_recall_scores + +# return result diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 0cafef2c12..717d11a3fa 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -128,57 +128,60 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): return similarities, errors -class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons( - self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs - ): - _kwargs = dict() - _kwargs.update(kwargs) - _kwargs["exhaustive_gt"] = exhaustive_gt - _kwargs["window_ms"] = window_ms - _kwargs["bin_ms"] = bin_ms - _kwargs["well_detected_score"] = well_detected_score - GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) - self.exhaustive_gt = exhaustive_gt - - @property - def time_bins(self): - for 
key, value in self.comparisons.items(): - return value.time_bins - - def precompute_scores_by_similarities(self, case_keys=None, good_only=True): - import sklearn.metrics - - if case_keys is None: - case_keys = self.cases.keys() - - self.all_similarities = {} - self.all_errors = {} - - for key in case_keys: - templates = self.get_templates(key) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - comp = self.comparisons[key] - similarities, errors = comp.compute_correlogram_by_similarity(similarity) - - self.all_similarities[key] = similarities - self.all_errors[key] = errors - - def get_error_profile_over_similarity_bins(self, similarity_bins, key): - all_similarities = self.all_similarities[key] - all_errors = self.all_errors[key] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_errors = np.nanmean(all_errors[amin:amax], axis=0) - result[(cmin, cmax)] = mean_errors - - return result +# This is removed at the moment. +# We need to move this maybe one day in benchmark + +# class CorrelogramGTStudy(GroundTruthStudy): +# def run_comparisons( +# self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs +# ): +# _kwargs = dict() +# _kwargs.update(kwargs) +# _kwargs["exhaustive_gt"] = exhaustive_gt +# _kwargs["window_ms"] = window_ms +# _kwargs["bin_ms"] = bin_ms +# _kwargs["well_detected_score"] = well_detected_score +# GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) +# self.exhaustive_gt = exhaustive_gt + +# @property +# def time_bins(self): +# for key, value in self.comparisons.items(): +# return value.time_bins + +# def precompute_scores_by_similarities(self, case_keys=None, good_only=True): +# import sklearn.metrics + +# if case_keys is None: +# case_keys = self.cases.keys() + +# self.all_similarities = {} +# self.all_errors = {} + +# for key in case_keys: +# templates = self.get_templates(key) +# flat_templates = templates.reshape(templates.shape[0], -1) +# similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) +# comp = self.comparisons[key] +# similarities, errors = comp.compute_correlogram_by_similarity(similarity) + +# self.all_similarities[key] = similarities +# self.all_errors[key] = errors + +# def get_error_profile_over_similarity_bins(self, similarity_bins, key): +# all_similarities = self.all_similarities[key] +# all_errors = self.all_errors[key] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_errors = all_errors[order, :] + +# result = {} + +# for i in range(similarity_bins.size - 1): +# cmin, cmax = similarity_bins[i], similarity_bins[i + 1] +# amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) +# mean_errors = np.nanmean(all_errors[amin:amax], axis=0) +# result[(cmin, cmax)] = mean_errors + +# return result diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 8929d6983c..df9e1420cb 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -1,441 +1,21 @@ -from __future__ import annotations - -from pathlib import Path -import 
shutil
-import os
-import json
-import pickle
-
-import numpy as np
-
-from spikeinterface.core import load_extractor, create_sorting_analyzer, load_sorting_analyzer
-from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder
-
-from spikeinterface.qualitymetrics import compute_quality_metrics
-
-from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison
-
-
-# TODO later : save comparison in folders when comparison object will be able to serialize
-
-
-# This is to separate names when the key are tuples when saving folders
-# _key_separator = "_##_"
-_key_separator = "_-°°-_"
+_txt_error_message = """
+GroundTruthStudy has been replaced by SorterStudy. The API is similar, but folder loading is not backward compatible.
+You can do:
+from spikeinterface.benchmark import SorterStudy
+study = SorterStudy.create(study_folder, datasets=..., cases=..., levels=...)
+study.run()  # this runs the sorters
+study.compute_results()  # this runs the comparisons
+# and then some plotting
+study.plot_agreements()
+study.plot_performances_vs_snr()
+...
+"""
 class GroundTruthStudy:
-    """
-    This class is an helper function to run any comparison on several "cases" for many ground-truth dataset.
-
-    "cases" refer to:
-      * several sorters for comparisons
-      * same sorter with differents parameters
-      * any combination of these (and more)
-
-    For increased flexibility, cases keys can be a tuple so that we can vary complexity along several
-    "levels" or "axis" (paremeters or sorters).
-    In this case, the result dataframes will have `MultiIndex` to handle the different levels.
-
-    A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see
-    :py:func:`~spikeinterface.core.generate.generate_ground_truth_recording()`).
-
-    This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions.
-    Note that the underlying folder structure is not backward compatible!
- - Parameters - ---------- - study_folder : str | Path - Path to folder containing `GroundTruthStudy` - """ - def __init__(self, study_folder): - self.folder = Path(study_folder) - - self.datasets = {} - self.cases = {} - self.sortings = {} - self.comparisons = {} - self.colors = None - - self.scan_folder() + raise RuntimeError(_txt_error_message) @classmethod def create(cls, study_folder, datasets={}, cases={}, levels=None): - # check that cases keys are homogeneous - key0 = list(cases.keys())[0] - if isinstance(key0, str): - assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" - if levels is None: - levels = "level0" - else: - assert isinstance(levels, str) - elif isinstance(key0, tuple): - assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" - num_levels = len(key0) - assert all( - len(key) == num_levels for key in cases.keys() - ), "Keys for cases are not homogeneous, tuple negth differ" - if levels is None: - levels = [f"level{i}" for i in range(num_levels)] - else: - levels = list(levels) - assert len(levels) == num_levels - else: - raise ValueError("Keys for cases must str or tuple") - - study_folder = Path(study_folder) - study_folder.mkdir(exist_ok=False, parents=True) - - (study_folder / "datasets").mkdir() - (study_folder / "datasets" / "recordings").mkdir() - (study_folder / "datasets" / "gt_sortings").mkdir() - (study_folder / "sorters").mkdir() - (study_folder / "sortings").mkdir() - (study_folder / "sortings" / "run_logs").mkdir() - (study_folder / "metrics").mkdir() - (study_folder / "comparisons").mkdir() - - for key, (rec, gt_sorting) in datasets.items(): - assert "/" not in key, "'/' cannot be in the key name!" - assert "\\" not in key, "'\\' cannot be in the key name!" 
- - # recordings are pickled - rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") - - # sortings are pickled + saved as NumpyFolderSorting - gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") - gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") - - info = {} - info["levels"] = levels - (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - - # cases is dumped to a pickle file, json is not possible because of the tuple key - (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) - - return cls(study_folder) - - def scan_folder(self): - if not (self.folder / "datasets").exists(): - raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") - - with open(self.folder / "info.json", "r") as f: - self.info = json.load(f) - - self.levels = self.info["levels"] - - for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): - key = rec_file.stem - rec = load_extractor(rec_file) - gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) - self.datasets[key] = (rec, gt_sorting) - - with open(self.folder / "cases.pickle", "rb") as f: - self.cases = pickle.load(f) - - self.sortings = {k: None for k in self.cases} - self.comparisons = {k: None for k in self.cases} - for key in self.cases: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - if sorting_folder.exists(): - self.sortings[key] = load_extractor(sorting_folder) - - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") - if comparison_file.exists(): - with open(comparison_file, mode="rb") as f: - try: - self.comparisons[key] = pickle.load(f) - except Exception: - pass - - def __repr__(self): - t = f"{self.__class__.__name__} {self.folder.stem} \n" - t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" - t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" - num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None]) - t += f" computed: {num_computed}\n" - - return t - - def key_to_str(self, key): - if isinstance(key, str): - return key - elif isinstance(key, tuple): - return _key_separator.join(key) - else: - raise ValueError("Keys for cases must str or tuple") - - def remove_sorting(self, key): - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - comparison_file = self.folder / "comparisons" / self.key_to_str(key) - self.sortings[key] = None - self.comparisons[key] = None - if sorting_folder.exists(): - shutil.rmtree(sorting_folder) - for f in (log_file, comparison_file): - if f.exists(): - f.unlink() - - def set_colors(self, colors=None, map_name="tab20"): - from spikeinterface.widgets import get_some_colors - - if colors is None: - case_keys = list(self.cases.keys()) - self.colors = get_some_colors( - case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0 - ) - else: - self.colors = colors - - def get_colors(self): - if self.colors is None: - self.set_colors() - return self.colors - - def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): - if case_keys is None: - case_keys = self.cases.keys() - - job_list = [] - for key in case_keys: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - sorting_exists = sorting_folder.exists() - - sorter_folder = self.folder / "sorters" / 
self.key_to_str(key) - sorter_folder_exists = sorter_folder.exists() - - if keep: - if sorting_exists: - continue - if sorter_folder_exists: - # the sorter folder exists but havent been copied to sortings folder - sorting = read_sorter_folder(sorter_folder, raise_error=False) - if sorting is not None: - # save and skip - self.copy_sortings(case_keys=[key]) - continue - - self.remove_sorting(key) - - if sorter_folder_exists: - shutil.rmtree(sorter_folder) - - params = self.cases[key]["run_sorter_params"].copy() - # this ensure that sorter_name is given - recording, _ = self.datasets[self.cases[key]["dataset"]] - sorter_name = params.pop("sorter_name") - job = dict( - sorter_name=sorter_name, - recording=recording, - output_folder=sorter_folder, - ) - job.update(params) - # the verbose is overwritten and global to all run_sorters - job["verbose"] = verbose - job["with_output"] = False - job_list.append(job) - - run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) - - # TODO later create a list in laucher for engine blocking and non-blocking - if engine not in ("slurm",): - self.copy_sortings(case_keys) - - def copy_sortings(self, case_keys=None, force=True): - if case_keys is None: - case_keys = self.cases.keys() - - for key in case_keys: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - sorter_folder = self.folder / "sorters" / self.key_to_str(key) - log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - - if (sorter_folder / "spikeinterface_log.json").exists(): - sorting = read_sorter_folder( - sorter_folder, raise_error=False, register_recording=False, sorting_info=False - ) - else: - sorting = None - - if sorting is not None: - if sorting_folder.exists(): - if force: - self.remove_sorting(key) - else: - continue - - sorting = sorting.save(format="numpy_folder", folder=sorting_folder) - self.sortings[key] = sorting - - # copy logs - shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file) - - def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): - if case_keys is None: - case_keys = self.cases.keys() - - for key in case_keys: - dataset_key = self.cases[key]["dataset"] - _, gt_sorting = self.datasets[dataset_key] - sorting = self.sortings[key] - if sorting is None: - self.comparisons[key] = None - continue - comp = comparison_class(gt_sorting, sorting, **kwargs) - self.comparisons[key] = comp - - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") - with open(comparison_file, mode="wb") as f: - pickle.dump(comp, f) - - def get_run_times(self, case_keys=None): - import pandas as pd - - if case_keys is None: - case_keys = self.cases.keys() - - log_folder = self.folder / "sortings" / "run_logs" - - run_times = {} - for key in case_keys: - log_file = log_folder / f"{self.key_to_str(key)}.json" - with open(log_file, mode="r") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - run_times[key] = run_time - - return pd.Series(run_times, name="run_time") - - def create_sorting_analyzer_gt(self, case_keys=None, random_params={}, waveforms_params={}, **job_kwargs): - if case_keys is None: - case_keys = self.cases.keys() - - base_folder = self.folder / "sorting_analyzer" - base_folder.mkdir(exist_ok=True) - - dataset_keys = [self.cases[key]["dataset"] for key in case_keys] - dataset_keys = set(dataset_keys) - for dataset_key in dataset_keys: - # the waveforms depend on the dataset key - folder = base_folder / 
self.key_to_str(dataset_key) - recording, gt_sorting = self.datasets[dataset_key] - sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) - sorting_analyzer.compute("random_spikes", **random_params) - sorting_analyzer.compute("templates", **job_kwargs) - sorting_analyzer.compute("noise_levels") - - def get_sorting_analyzer(self, case_key=None, dataset_key=None): - if case_key is not None: - dataset_key = self.cases[case_key]["dataset"] - - folder = self.folder / "sorting_analyzer" / self.key_to_str(dataset_key) - sorting_analyzer = load_sorting_analyzer(folder) - return sorting_analyzer - - # def get_templates(self, key, mode="average"): - # analyzer = self.get_sorting_analyzer(case_key=key) - # templates = sorting_analyzer.get_all_templates(mode=mode) - # return templates - - def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): - if case_keys is None: - case_keys = self.cases.keys() - - done = [] - for key in case_keys: - dataset_key = self.cases[key]["dataset"] - if dataset_key in done: - # some case can share the same waveform extractor - continue - done.append(dataset_key) - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if filename.exists(): - if force: - os.remove(filename) - else: - continue - analyzer = self.get_sorting_analyzer(key) - metrics = compute_quality_metrics(analyzer, metric_names=metric_names) - metrics.to_csv(filename, sep="\t", index=True) - - def get_metrics(self, key): - import pandas as pd - - dataset_key = self.cases[key]["dataset"] - - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if not filename.exists(): - return - metrics = pd.read_csv(filename, sep="\t", index_col=0) - dataset_key = self.cases[key]["dataset"] - recording, gt_sorting = self.datasets[dataset_key] - metrics.index = gt_sorting.unit_ids - return metrics - - def get_units_snr(self, key): - return self.get_metrics(key)["snr"] - - def get_performance_by_unit(self, case_keys=None): - import pandas as pd - - if case_keys is None: - case_keys = self.cases.keys() - - perf_by_unit = [] - for key in case_keys: - comp = self.comparisons.get(key, None) - assert comp is not None, "You need to do study.run_comparisons() first" - - perf = comp.get_performance(method="by_unit", output="pandas") - - if isinstance(key, str): - perf[self.levels] = key - elif isinstance(key, tuple): - for col, k in zip(self.levels, key): - perf[col] = k - - perf = perf.reset_index() - perf_by_unit.append(perf) - - perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(self.levels) - perf_by_unit = perf_by_unit.sort_index() - return perf_by_unit - - def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): - import pandas as pd - - if case_keys is None: - case_keys = list(self.cases.keys()) - - if isinstance(case_keys[0], str): - index = pd.Index(case_keys, name=self.levels) - else: - index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - - columns = ["num_gt", "num_sorter", "num_well_detected"] - comp = self.comparisons[case_keys[0]] - if comp.exhaustive_gt: - columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) - count_units = pd.DataFrame(index=index, columns=columns, dtype=int) - - for key in case_keys: - comp = self.comparisons.get(key, None) - assert comp is not None, "You need to do study.run_comparisons() first" - - gt_sorting = comp.sorting1 - sorting = 
comp.sorting2
-
-            count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
-            count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
-            count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)
-
-            if comp.exhaustive_gt:
-                count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
-                count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
-                count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
-                count_units.loc[key, "num_bad"] = comp.count_bad_units()
-
-        return count_units
+        raise RuntimeError(_txt_error_message)
diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py
index 0cd0fb0fb5..b9410bc021 100644
--- a/src/spikeinterface/full.py
+++ b/src/spikeinterface/full.py
@@ -25,3 +25,4 @@
 from .widgets import *
 from .exporters import *
 from .generation import *
+from .benchmark import *
diff --git a/src/spikeinterface/sortingcomponents/benchmark/__init__.py b/src/spikeinterface/sortingcomponents/benchmark/__init__.py
deleted file mode 100644
index ad6d444bdb..0000000000
--- a/src/spikeinterface/sortingcomponents/benchmark/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Module to benchmark some sorting components:
-  * clustering
-  * motion
-  * template matching
-"""
diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py
deleted file mode 100644
index a9e404292d..0000000000
--- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import pytest
-
-
-@pytest.mark.skip()
-def test_benchmark_peak_selection(create_cache_folder):
-    cache_folder = create_cache_folder
-    pass
-
-
-if __name__ == "__main__":
-    test_benchmark_peak_selection()
diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index 85043d0d12..5e160a6a5a 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -1,127 +1,67 @@
+"""
+This module will be deprecated and removed in 0.102.0.
+
+All plotting for the previous GTStudy is now centralized in spikeinterface.benchmark.benchmark_plot_tools.
+Please note that GTStudy has been replaced by SorterStudy, which is based on the more generic BenchmarkStudy.
+"""
+
 from __future__ import annotations

-import numpy as np
+from .base import BaseWidget

-from .base import BaseWidget, to_attr
+import warnings


 class StudyRunTimesWidget(BaseWidget):
     """
-    Plot sorter run times for a GroundTruthStudy
-
+    Plot sorter run times for a SorterStudy.

     Parameters
     ----------
-    study : GroundTruthStudy
+    study : SorterStudy
         A study object.
     case_keys : list or None
         A selection of cases to plot, if None, then all.
     """

-    def __init__(
-        self,
-        study,
-        case_keys=None,
-        backend=None,
-        **backend_kwargs,
-    ):
-        if case_keys is None:
-            case_keys = list(study.cases.keys())
-
-        plot_data = dict(
-            study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors()
+    def __init__(self, study, case_keys=None, backend=None, **backend_kwargs):
+        warnings.warn(
+            "plot_study_run_times is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) + from spikeinterface.benchmark.benchmark_plot_tools import plot_run_times - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + plot_run_times(data_plot["study"], case_keys=data_plot["case_keys"]) - for i, key in enumerate(dp.case_keys): - label = dp.study.cases[key]["label"] - rt = dp.run_times.loc[key] - self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key]) - self.ax.set_ylabel("run time (s)") - self.ax.legend() - -# TODO : plot optionally average on some levels using group by class StudyUnitCountsWidget(BaseWidget): """ Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" - Parameters ---------- - study : GroundTruthStudy + study : SorterStudy A study object. case_keys : list or None A selection of cases to plot, if None, then all. """ - def __init__( - self, - study, - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, - count_units=study.get_count_units(case_keys=case_keys), - case_keys=case_keys, + def __init__(self, study, case_keys=None, backend=None, **backend_kwargs): + warnings.warn( + "plot_study_unit_counts is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." ) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors + from spikeinterface.benchmark.benchmark_plot_tools import plot_unit_counts - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - columns = dp.count_units.columns.tolist() - columns.remove("num_gt") - columns.remove("num_sorter") - - ncol = len(columns) - - colors = get_some_colors(columns, color_engine="auto", map_name="hot") - colors["num_well_detected"] = "green" - - xticklabels = [] - for i, key in enumerate(dp.case_keys): - for c, col in enumerate(columns): - x = i + 1 + c / (ncol + 1) - y = dp.count_units.loc[key, col] - if not "well_detected" in col: - y = -y - - if i == 0: - label = col.replace("num_", "").replace("_", " ").title() - else: - label = None - - self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) - - xticklabels.append(dp.study.cases[key]["label"]) - - self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1) - self.ax.set_xticklabels(xticklabels) - self.ax.legend() + plot_unit_counts(data_plot["study"], case_keys=data_plot["case_keys"]) class StudyPerformances(BaseWidget): @@ -154,78 +94,26 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn( + "plot_study_performances is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." 
+ ) plot_data = dict( study=study, - perfs=study.get_performance_by_unit(case_keys=case_keys), mode=mode, performance_names=performance_names, case_keys=case_keys, ) - - self.colors = study.get_colors() - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors - - import pandas as pd - import seaborn as sns - - dp = to_attr(data_plot) - perfs = dp.perfs - study = dp.study - - if dp.mode in ("ordered", "snr"): - backend_kwargs["num_axes"] = len(dp.performance_names) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - if dp.mode == "ordered": - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - for key in dp.case_keys: - label = study.cases[key]["label"] - val = perfs.xs(key).loc[:, performance_name].values - val = np.sort(val)[::-1] - ax.plot(val, label=label, c=self.colors[key]) - ax.set_title(performance_name) - if count == len(dp.performance_names) - 1: - ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) - - elif dp.mode == "snr": - metric_name = dp.mode - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - - max_metric = 0 - for key in dp.case_keys: - x = study.get_metrics(key).loc[:, metric_name].values - y = perfs.xs(key).loc[:, performance_name].values - label = study.cases[key]["label"] - ax.scatter(x, y, s=10, label=label, color=self.colors[key]) - max_metric = max(max_metric, np.max(x)) - ax.set_title(performance_name) - ax.set_xlim(0, max_metric * 1.05) - ax.set_ylim(0, 1.05) - if count == 0: - ax.legend(loc="lower right") - - elif dp.mode == "swarm": - levels = perfs.index.names - df = pd.melt( - perfs.reset_index(), - id_vars=levels, - var_name="Metric", - value_name="Score", - value_vars=dp.performance_names, - ) - df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) - sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) + from spikeinterface.benchmark.benchmark_plot_tools import plot_performances + + plot_performances( + data_plot["study"], + mode=data_plot["mode"], + performance_names=data_plot["performance_names"], + case_keys=data_plot["case_keys"], + ) class StudyAgreementMatrix(BaseWidget): @@ -251,9 +139,9 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn( + "plot_study_agreement_matrix is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." 
+ ) plot_data = dict( study=study, case_keys=case_keys, @@ -263,36 +151,9 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .comparison import AgreementMatrixWidget - - dp = to_attr(data_plot) - study = dp.study - - backend_kwargs["num_axes"] = len(dp.case_keys) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - for count, key in enumerate(dp.case_keys): - ax = self.axes.flatten()[count] - comp = study.comparisons[key] - unit_ticks = len(comp.sorting1.unit_ids) <= 16 - count_text = len(comp.sorting1.unit_ids) <= 16 - - AgreementMatrixWidget( - comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax - ) - label = study.cases[key]["label"] - ax.set_xlabel(label) - - if count > 0: - ax.set_ylabel(None) - ax.set_yticks([]) - ax.set_xticks([]) + from spikeinterface.benchmark.benchmark_plot_tools import plot_agreement_matrix - # ax0 = self.axes.flatten()[0] - # for ax in self.axes.flatten()[1:]: - # ax.sharey(ax0) + plot_agreement_matrix(data_plot["study"], ordered=data_plot["ordered"], case_keys=data_plot["case_keys"]) class StudySummary(BaseWidget): @@ -320,25 +181,26 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - plot_data = dict( - study=study, - case_keys=case_keys, + warnings.warn( + "plot_study_summary is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." ) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - study = data_plot["study"] case_keys = data_plot["case_keys"] - StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs) - StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs) - StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + from spikeinterface.benchmark.benchmark_plot_tools import ( + plot_agreement_matrix, + plot_performances, + plot_unit_counts, + plot_run_times, + ) + + plot_performances(study=study, case_keys=case_keys, mode="ordered") + plot_performances(study=study, case_keys=case_keys, mode="snr") + plot_agreement_matrix(study=study, case_keys=case_keys) + plot_run_times(study=study, case_keys=case_keys) + plot_unit_counts(study=study, case_keys=case_keys)
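A minimal migration sketch for the workflow above, assuming only the calls visible in this patch (SorterStudy.create / run / compute_results, the study plotting methods named in the new error message, and the benchmark_plot_tools helpers the deprecated widgets delegate to). The dataset key, case definition, sorter name, and study folder are illustrative placeholders, not values taken from the repository.

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.benchmark import SorterStudy, benchmark_plot_tools

# a small simulated ground-truth dataset (any Recording / GT Sorting pair works)
rec, gt_sorting = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42)
datasets = {"toy": (rec, gt_sorting)}

# one case; the key, label, dataset name and sorter are placeholders
cases = {
    ("tdc2", "toy"): {
        "label": "tridesclous2 on toy",
        "dataset": "toy",
        "params": {"sorter_name": "tridesclous2"},
    },
}

study = SorterStudy.create(
    study_folder="sorter_study_folder",  # placeholder path
    datasets=datasets,
    cases=cases,
    levels=["sorter_name", "dataset"],
)
study.run()              # runs the sorters
study.compute_results()  # runs the ground-truth comparisons

# high-level plots exposed by the study itself
study.plot_agreements()
study.plot_performances_vs_snr()

# the deprecated gtstudy widgets now simply delegate to these helpers
benchmark_plot_tools.plot_run_times(study, case_keys=None)
benchmark_plot_tools.plot_unit_counts(study, case_keys=None)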