diff --git a/doc/images/overview.png b/doc/images/overview.png
index ea5ba49d08..e367c4b6e4 100644
Binary files a/doc/images/overview.png and b/doc/images/overview.png differ
diff --git a/doc/modules/benchmark.rst b/doc/modules/benchmark.rst
new file mode 100644
index 0000000000..71b84adde9
--- /dev/null
+++ b/doc/modules/benchmark.rst
@@ -0,0 +1,141 @@
+Benchmark module
+================
+
+This module contains machinery to compare several sorters against ground truth in many different situations.
+
+
+.. note::
+
+    In version 0.102.0, the previous :py:func:`~spikeinterface.comparison.GroundTruthStudy()` has been replaced by
+    :py:func:`~spikeinterface.benchmark.SorterStudy()`.
+
+
+This module also aims to benchmark sorting components (detection, clustering, motion, template matching) using the
+same base class :py:func:`~spikeinterface.benchmark.BenchmarkStudy()`, specialized for each targeted component.
+
+By design, the main class handles the concept of "levels": this allows comparing several levels of complexity at the same time.
+For instance, compare kilosort4 vs kilosort2.5 (level 0) for different noise amplitudes (level 1) combined with
+several motion vectors (level 2).
+
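+As a rough illustration of this "levels" idea, each case key of a study is a tuple with one entry per level.
+The sketch below is hypothetical (dataset names and parameters are placeholders, not an existing study):
+
+.. code-block:: python
+
+    # hypothetical three-level case keys: (sorter_name, noise, motion)
+    cases = {
+        ("kilosort4", "low_noise", "no_drift"): {
+            "label": "KS4 / low noise / no drift",
+            "dataset": "rec_low_noise_static",
+            "params": {"sorter_name": "kilosort4"},
+        },
+        ("kilosort2_5", "high_noise", "slow_drift"): {
+            "label": "KS2.5 / high noise / slow drift",
+            "dataset": "rec_high_noise_slow_drift",
+            "params": {"sorter_name": "kilosort2_5"},
+        },
+    }
+    # such a study would be created with levels=["sorter_name", "noise", "motion"]
+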
+**Example: compare many sorters: a ground truth study**
+
+We have a high-level class to compare many sorters against ground truth: :py:func:`~spikeinterface.benchmark.SorterStudy()`.
+
+
+A study is a systematic performance comparison of several ground truth recordings with several sorters, or several cases
+such as different parameter sets.
+
+The study class proposes high-level tool functions to run many ground truth comparisons with many "cases"
+on many recordings, and then to collect and aggregate the results in an easy way.
+
+The whole mechanism is based on an intrinsic organization into a "study_folder" with several subfolders:
+
+  * datasets: contains the ground truth datasets
+  * sorters: contains the outputs of the sorters
+  * sortings: contains a light copy of all sortings
+  * metrics: contains the metrics
+  * ...
+
+
+.. code-block:: python
+
+    import matplotlib.pyplot as plt
+    import seaborn as sns
+
+    import spikeinterface.extractors as se
+    import spikeinterface.widgets as sw
+    from spikeinterface.core import generate_ground_truth_recording
+    from spikeinterface.benchmark import SorterStudy
+
+
+    # generate 2 simulated datasets (could also be MEArec files)
+    rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42)
+    rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91)
+
+    datasets = {
+        "toy0": (rec0, gt_sorting0),
+        "toy1": (rec1, gt_sorting1),
+    }
+
+    # define some "cases": here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset,
+    # so it is a two-level study (sorter_name, dataset)
+    # this could be more complicated like (sorter_name, dataset, params)
+    cases = {
+        ("tdc2", "toy0"): {
+            "label": "tridesclous2 on tetrode0",
+            "dataset": "toy0",
+            "params": {"sorter_name": "tridesclous2"}
+        },
+        ("tdc2", "toy1"): {
+            "label": "tridesclous2 on tetrode1",
+            "dataset": "toy1",
+            "params": {"sorter_name": "tridesclous2"}
+        },
+        ("sc", "toy0"): {
+            "label": "spykingcircus2 on tetrode0",
+            "dataset": "toy0",
+            "params": {
+                "sorter_name": "spykingcircus2",
+                "docker_image": True
+            },
+        },
+    }
+
+    # folder where the study will be created (any local path)
+    study_folder = "my_study_folder"
+
+    # this initializes the folder
+    study = SorterStudy.create(study_folder=study_folder, datasets=datasets, cases=cases,
+                               levels=["sorter_name", "dataset"])
+
+
+    # This internally runs run_sorter() for all cases in one call
+    study.run()
+
+    # Run the benchmark: this internally runs compare_sorter_to_ground_truth() for all cases
+    study.compute_results()
+
+    # Collect comparisons one by one
+    for case_key in study.cases:
+        print('*' * 10)
+        print(case_key)
+        # raw counting of tp/fp/...
+        comp = study.get_result(case_key)["gt_comparison"]
+        # summary
+        comp.print_summary()
+        perf_unit = comp.get_performance(method='by_unit')
+        perf_avg = comp.get_performance(method='pooled_with_average')
+        # some plots
+        m = comp.get_confusion_matrix()
+        w_comp = sw.plot_agreement_matrix(sorting_comparison=comp)
+
+    # Collect synthetic dataframes and display
+    # As shown previously, the performance is returned as a pandas dataframe.
+    # study.get_performance_by_unit() gathers all the outputs in the study folder
+    # and merges them into a single dataframe. Same idea for study.get_count_units()
+
+    # this is a dataframe
+    perfs = study.get_performance_by_unit()
+
+    # this is a dataframe
+    unit_counts = study.get_count_units()
+
+    # The study class also has several methods for plotting the results
+    study.plot_agreement_matrix()
+    study.plot_unit_counts()
+    study.plot_performances(mode="ordered")
+    study.plot_performances(mode="snr")
+
+
+
+
+Benchmark spike collisions
+--------------------------
+
+SpikeInterface also has a specific toolset to benchmark how well sorters recover spikes that are in "collision".
+
+We have two classes to handle collision-specific comparisons, and also to quantify the effects on correlogram
+estimation:
+
+  * :py:class:`~spikeinterface.comparison.CollisionGTComparison`
+  * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison`
+
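+These comparison classes can be used directly on a single pair of sortings. Below is a minimal sketch
+(assuming ``gt_sorting`` and ``tested_sorting`` are a ground-truth sorting and a sorter output, as in the
+study above; the parameter values are only indicative):
+
+.. code-block:: python
+
+    from spikeinterface.comparison import CollisionGTComparison
+
+    # collision_lag is the lag window (in ms) used to flag spikes that are in collision
+    comp = CollisionGTComparison(gt_sorting, tested_sorting, collision_lag=2.0, exhaustive_gt=True)
+
+    # the usual ground-truth metrics remain available
+    comp.print_summary()
+    perf = comp.get_performance(method="pooled_with_average")
+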
+For more details, check out the following paper:
+
+`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?"
`_
diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst
index edee7f1fda..66451c0e87 100644
--- a/doc/modules/comparison.rst
+++ b/doc/modules/comparison.rst
@@ -5,6 +5,10 @@ Comparison module
 SpikeInterface has a :py:mod:`~spikeinterface.comparison` module, which contains functions and tools to compare
 spike trains and templates (useful for tracking units over multiple sessions).
 
+.. note::
+
+    In version 0.102.0, the benchmark part of the comparison module has moved to the new :py:mod:`~spikeinterface.benchmark` module.
+
 In addition, the :py:mod:`~spikeinterface.comparison` module contains advanced benchmarking tools to evaluate
 the effects of spike collisions on spike sorting results, and to construct hybrid recordings for comparison.
 
@@ -242,135 +246,6 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for
 
     cmp_gt_HS.get_redundant_units(redundant_score=0.2)
 
-
-**Example: compare many sorters with a Ground Truth Study**
-
-We also have a high level class to compare many sorters against ground truth:
-:py:func:`~spikeinterface.comparison.GroundTruthStudy()`
-
-A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases
-like the different parameter sets.
-
-The study class proposes high-level tool functions to run many ground truth comparisons with many "cases"
-on many recordings and then collect and aggregate results in an easy way.
-
-The all mechanism is based on an intrinsic organization into a "study_folder" with several subfolders:
-
-  * datasets: contains ground truth datasets
-  * sorters : contains outputs of sorters
-  * sortings: contains light copy of all sorting
-  * metrics: contains metrics
-  * ...
-
-
-.. code-block:: python
-
-    import matplotlib.pyplot as plt
-    import seaborn as sns
-
-    import spikeinterface.extractors as se
-    import spikeinterface.widgets as sw
-    from spikeinterface.comparison import GroundTruthStudy
-
-
-    # generate 2 simulated datasets (could be also mearec files)
-    rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42)
-    rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91)
-
-    datasets = {
-        "toy0": (rec0, gt_sorting0),
-        "toy1": (rec1, gt_sorting1),
-    }
-
-    # define some "cases" here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset
-    # so it is a two level study (sorter_name, dataset)
-    # this could be more complicated like (sorter_name, dataset, params)
-    cases = {
-        ("tdc2", "toy0"): {
-            "label": "tridesclous2 on tetrode0",
-            "dataset": "toy0",
-            "run_sorter_params": {
-                "sorter_name": "tridesclous2",
-            },
-        },
-        ("tdc2", "toy1"): {
-            "label": "tridesclous2 on tetrode1",
-            "dataset": "toy1",
-            "run_sorter_params": {
-                "sorter_name": "tridesclous2",
-            },
-        },
-
-        ("sc", "toy0"): {
-            "label": "spykingcircus2 on tetrode0",
-            "dataset": "toy0",
-            "run_sorter_params": {
-                "sorter_name": "spykingcircus",
-                "docker_image": True
-            },
-        },
-    }
-    # this initilizes a folder
-    study = GroundTruthStudy.create(study_folder=study_folder, datasets=datasets, cases=cases,
-                                    levels=["sorter_name", "dataset"])
-
-
-    # all cases in one function
-    study.run_sorters()
-
-    # Collect comparisons
-    #
-    # You can collect in one shot all results and run the
-    # GroundTruthComparison on it.
-    # So you can have fine access to all individual results.
- # - # Note: use exhaustive_gt=True when you know exactly how many - # units in the ground truth (for synthetic datasets) - - # run all comparisons and loop over the results - study.run_comparisons(exhaustive_gt=True) - for key, comp in study.comparisons.items(): - print('*' * 10) - print(key) - # raw counting of tp/fp/... - print(comp.count_score) - # summary - comp.print_summary() - perf_unit = comp.get_performance(method='by_unit') - perf_avg = comp.get_performance(method='pooled_with_average') - # some plots - m = comp.get_confusion_matrix() - w_comp = sw.plot_agreement_matrix(sorting_comparison=comp) - - # Collect synthetic dataframes and display - # As shown previously, the performance is returned as a pandas dataframe. - # The spikeinterface.comparison.get_performance_by_unit() function, - # gathers all the outputs in the study folder and merges them into a single dataframe. - # Same idea for spikeinterface.comparison.get_count_units() - - # this is a dataframe - perfs = study.get_performance_by_unit() - - # this is a dataframe - unit_counts = study.get_count_units() - - # we can also access run times - run_times = study.get_run_times() - print(run_times) - - # Easy plotting with seaborn - fig1, ax1 = plt.subplots() - sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1) - ax1.set_title('Run times') - - ############################################################################## - - fig2, ax2 = plt.subplots() - sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2) - ax2.set_title('Recall') - ax2.set_ylim(-0.1, 1.1) - - .. _symmetric: 2. Compare the output of two spike sorters (symmetric comparison) @@ -540,32 +415,4 @@ sorting analyzers from day 1 (:code:`analyzer_day1`) to day 5 (:code:`analyzer_d -Benchmark spike collisions --------------------------- - -SpikeInterface also has a specific toolset to benchmark how well sorters are at recovering spikes in "collision". - -We have three classes to handle collision-specific comparisons, and also to quantify the effects on correlogram -estimation: - - * :py:class:`~spikeinterface.comparison.CollisionGTComparison` - * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison` - * :py:class:`~spikeinterface.comparison.CollisionGTStudy` - * :py:class:`~spikeinterface.comparison.CorrelogramGTStudy` - -For more details, checkout the following paper: - -`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?" `_ - - -Hybrid recording ----------------- - -To benchmark spike sorting results, we need ground-truth spiking activity. -This can be generated with artificial simulations, e.g., using `MEArec `_, or -alternatively by generating so-called "hybrid" recordings. - -The :py:mod:`~spikeinterface.comparison` module includes functions to generate such "hybrid" recordings: - * :py:func:`~spikeinterface.comparison.create_hybrid_units_recording`: add new units to an existing recording - * :py:func:`~spikeinterface.comparison.create_hybrid_spikes_recording`: add new spikes to existing units in a recording diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index 5f3e584b20..1f2dde3b32 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -15,7 +15,7 @@ # ) - +# TODO later integrate CollisionGTComparison optionally in this class. 
class SorterBenchmark(Benchmark): diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 574bd16093..cff87e7a57 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -171,72 +171,75 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good return similarities, recall_scores, pair_names - -class CollisionGTStudy(GroundTruthStudy): - def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): - _kwargs = dict() - _kwargs.update(kwargs) - _kwargs["exhaustive_gt"] = exhaustive_gt - _kwargs["collision_lag"] = collision_lag - _kwargs["nbins"] = nbins - GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) - self.exhaustive_gt = exhaustive_gt - self.collision_lag = collision_lag - - def get_lags(self, key): - comp = self.comparisons[key] - fs = comp.sorting1.get_sampling_frequency() - lags = comp.bins / fs * 1000.0 - return lags - - def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): - import sklearn - - if case_keys is None: - case_keys = self.cases.keys() - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for key in case_keys: - templates = self.get_templates(key) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - comp = self.comparisons[key] - similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity, good_only=good_only, min_accuracy=min_accuracy - ) - self.all_similarities[key] = similarities - self.all_recall_scores[key] = recall_scores - - def get_mean_over_similarity_range(self, similarity_range, key): - idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) - all_similarities = self.all_similarities[key][idx] - all_recall_scores = self.all_recall_scores[key][idx] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = np.nanmean(all_recall_scores, axis=0) - - return mean_recall_scores - - def get_lag_profile_over_similarity_bins(self, similarity_bins, key): - all_similarities = self.all_similarities[key] - all_recall_scores = self.all_recall_scores[key] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) - result[(cmin, cmax)] = mean_recall_scores - - return result +# This is removed at the moment. +# We need to move this maybe one day in benchmark. 
+# please do not delete this + +# class CollisionGTStudy(GroundTruthStudy): +# def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): +# _kwargs = dict() +# _kwargs.update(kwargs) +# _kwargs["exhaustive_gt"] = exhaustive_gt +# _kwargs["collision_lag"] = collision_lag +# _kwargs["nbins"] = nbins +# GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) +# self.exhaustive_gt = exhaustive_gt +# self.collision_lag = collision_lag + +# def get_lags(self, key): +# comp = self.comparisons[key] +# fs = comp.sorting1.get_sampling_frequency() +# lags = comp.bins / fs * 1000.0 +# return lags + +# def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): +# import sklearn + +# if case_keys is None: +# case_keys = self.cases.keys() + +# self.all_similarities = {} +# self.all_recall_scores = {} +# self.good_only = good_only + +# for key in case_keys: +# templates = self.get_templates(key) +# flat_templates = templates.reshape(templates.shape[0], -1) +# similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) +# comp = self.comparisons[key] +# similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( +# similarity, good_only=good_only, min_accuracy=min_accuracy +# ) +# self.all_similarities[key] = similarities +# self.all_recall_scores[key] = recall_scores + +# def get_mean_over_similarity_range(self, similarity_range, key): +# idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) +# all_similarities = self.all_similarities[key][idx] +# all_recall_scores = self.all_recall_scores[key][idx] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# mean_recall_scores = np.nanmean(all_recall_scores, axis=0) + +# return mean_recall_scores + +# def get_lag_profile_over_similarity_bins(self, similarity_bins, key): +# all_similarities = self.all_similarities[key] +# all_recall_scores = self.all_recall_scores[key] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# result = {} + +# for i in range(similarity_bins.size - 1): +# cmin, cmax = similarity_bins[i], similarity_bins[i + 1] +# amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) +# mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) +# result[(cmin, cmax)] = mean_recall_scores + +# return result diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 0cafef2c12..717d11a3fa 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -128,57 +128,60 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): return similarities, errors -class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons( - self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs - ): - _kwargs = dict() - _kwargs.update(kwargs) - _kwargs["exhaustive_gt"] = exhaustive_gt - _kwargs["window_ms"] = window_ms - _kwargs["bin_ms"] = bin_ms - _kwargs["well_detected_score"] = well_detected_score - GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) - self.exhaustive_gt = exhaustive_gt - - @property - def time_bins(self): - for 
key, value in self.comparisons.items(): - return value.time_bins - - def precompute_scores_by_similarities(self, case_keys=None, good_only=True): - import sklearn.metrics - - if case_keys is None: - case_keys = self.cases.keys() - - self.all_similarities = {} - self.all_errors = {} - - for key in case_keys: - templates = self.get_templates(key) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - comp = self.comparisons[key] - similarities, errors = comp.compute_correlogram_by_similarity(similarity) - - self.all_similarities[key] = similarities - self.all_errors[key] = errors - - def get_error_profile_over_similarity_bins(self, similarity_bins, key): - all_similarities = self.all_similarities[key] - all_errors = self.all_errors[key] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_errors = np.nanmean(all_errors[amin:amax], axis=0) - result[(cmin, cmax)] = mean_errors - - return result +# This is removed at the moment. +# We need to move this maybe one day in benchmark + +# class CorrelogramGTStudy(GroundTruthStudy): +# def run_comparisons( +# self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs +# ): +# _kwargs = dict() +# _kwargs.update(kwargs) +# _kwargs["exhaustive_gt"] = exhaustive_gt +# _kwargs["window_ms"] = window_ms +# _kwargs["bin_ms"] = bin_ms +# _kwargs["well_detected_score"] = well_detected_score +# GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) +# self.exhaustive_gt = exhaustive_gt + +# @property +# def time_bins(self): +# for key, value in self.comparisons.items(): +# return value.time_bins + +# def precompute_scores_by_similarities(self, case_keys=None, good_only=True): +# import sklearn.metrics + +# if case_keys is None: +# case_keys = self.cases.keys() + +# self.all_similarities = {} +# self.all_errors = {} + +# for key in case_keys: +# templates = self.get_templates(key) +# flat_templates = templates.reshape(templates.shape[0], -1) +# similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) +# comp = self.comparisons[key] +# similarities, errors = comp.compute_correlogram_by_similarity(similarity) + +# self.all_similarities[key] = similarities +# self.all_errors[key] = errors + +# def get_error_profile_over_similarity_bins(self, similarity_bins, key): +# all_similarities = self.all_similarities[key] +# all_errors = self.all_errors[key] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_errors = all_errors[order, :] + +# result = {} + +# for i in range(similarity_bins.size - 1): +# cmin, cmax = similarity_bins[i], similarity_bins[i + 1] +# amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) +# mean_errors = np.nanmean(all_errors[amin:amax], axis=0) +# result[(cmin, cmax)] = mean_errors + +# return result