diff --git a/doc/development/development.rst b/doc/development/development.rst
index 246a2bcb9a..a91818a271 100644
--- a/doc/development/development.rst
+++ b/doc/development/development.rst
@@ -192,6 +192,7 @@ Miscelleaneous Stylistic Conventions
 #. Avoid using abbreviations in variable names (e.g. use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables.
 #. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids)
 #. We use file_path and folder_path (instead of file_name and folder_name) for clarity.
+#. For the titles of documentation pages, only capitalize the first letter of the first word and the names of classes or software packages. For example, "How to use a SortingAnalyzer in SpikeInterface".
 #. For creating headers to divide sections of code we use the following convention (see issue `#3019 `_):
diff --git a/doc/how_to/combine_recordings.rst b/doc/how_to/combine_recordings.rst
index db37e28382..4a088f01b1 100644
--- a/doc/how_to/combine_recordings.rst
+++ b/doc/how_to/combine_recordings.rst
@@ -1,4 +1,4 @@
-Combine Recordings in SpikeInterface
+Combine recordings in SpikeInterface
 ====================================
 
 In this tutorial we will walk through combining multiple recording objects. Sometimes this occurs due to hardware
diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst
index 1f24fb66d3..eab1e0a300 100644
--- a/doc/how_to/load_matlab_data.rst
+++ b/doc/how_to/load_matlab_data.rst
@@ -1,4 +1,4 @@
-Export MATLAB Data to Binary & Load in SpikeInterface
+Export MATLAB data to binary & load in SpikeInterface
 ========================================================
 
 In this tutorial, we will walk through the process of exporting data from MATLAB in a binary format and subsequently loading it using SpikeInterface in Python.
diff --git a/doc/how_to/load_your_data_into_sorting.rst b/doc/how_to/load_your_data_into_sorting.rst
index 4e434ecb7a..e250cfa6e9 100644
--- a/doc/how_to/load_your_data_into_sorting.rst
+++ b/doc/how_to/load_your_data_into_sorting.rst
@@ -1,5 +1,5 @@
-Load Your Own Data into a Sorting
-=================================
+Load your own data into a Sorting object
+========================================
 
 Why make a :code:`Sorting`?
diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst
index bac0de4d0c..08a87ab738 100644
--- a/doc/how_to/process_by_channel_group.rst
+++ b/doc/how_to/process_by_channel_group.rst
@@ -1,4 +1,4 @@
-Process a Recording by Channel Group
+Process a recording by channel group
 ====================================
 
 In this tutorial, we will walk through how to preprocess and sort a recording
diff --git a/doc/how_to/viewers.rst b/doc/how_to/viewers.rst
index c7574961bd..7bb41cadb6 100644
--- a/doc/how_to/viewers.rst
+++ b/doc/how_to/viewers.rst
@@ -1,4 +1,4 @@
-Visualize Data
+Visualize data
 ==============
 
 There are several ways to plot signals (raw, preprocessed) and spikes.
diff --git a/doc/images/overview.png b/doc/images/overview.png
index ea5ba49d08..e367c4b6e4 100644
Binary files a/doc/images/overview.png and b/doc/images/overview.png differ
diff --git a/doc/modules/benchmark.rst b/doc/modules/benchmark.rst
new file mode 100644
index 0000000000..faf53be790
--- /dev/null
+++ b/doc/modules/benchmark.rst
@@ -0,0 +1,141 @@
+Benchmark module
+================
+
+This module contains machinery to compare sorters against ground truth in many different situations.
+
+
+.. note::
+
+    In version 0.102.0 the previous :py:func:`~spikeinterface.comparison.GroundTruthStudy()` was replaced by
+    :py:func:`~spikeinterface.benchmark.SorterStudy()`.
+
+
+This module also aims to benchmark sorting components (detection, clustering, motion, template matching) using the
+same base class :py:func:`~spikeinterface.benchmark.BenchmarkStudy()`, specialized for each targeted component.
+
+By design, the main class handles the concept of "levels": this allows several complexities to be compared at the
+same time. For instance, one can compare kilosort4 vs kilosort2.5 (level 0) for different noise amplitudes (level 1)
+combined with several motion vectors (level 2); a short sketch of such a multi-level case declaration is given after
+the folder layout below.
+
+**Example: compare many sorters with a ground truth study**
+
+We have a high level class to compare many sorters against ground truth: :py:func:`~spikeinterface.benchmark.SorterStudy()`
+
+
+A study is a systematic performance comparison of several ground truth recordings with several sorters, or with
+several cases such as different parameter sets.
+
+The study class provides high-level tool functions to run many ground truth comparisons with many "cases"
+on many recordings and then collect and aggregate results in an easy way.
+
+The whole mechanism is based on an intrinsic organization into a "study_folder" with several subfolders:
+
+ * datasets: contains ground truth datasets
+ * sorters: contains the outputs of sorters
+ * sortings: contains light copies of all sortings
+ * metrics: contains metrics
+ * ...
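+
+For instance, a three-level study could be declared as below. This is only a minimal sketch: the case keys, dataset
+names and parameter values here are illustrative placeholders, not a tested configuration.
+
+.. code-block:: python
+
+    # hypothetical three-level cases: (sorter_name, noise_level, motion)
+    cases = {
+        ("ks4", "low-noise", "rigid"): {
+            "label": "kilosort4 / low noise / rigid drift",
+            "dataset": "rec_low_noise_rigid",
+            "params": {"sorter_name": "kilosort4"},
+        },
+        ("ks25", "high-noise", "nonrigid"): {
+            "label": "kilosort2.5 / high noise / non-rigid drift",
+            "dataset": "rec_high_noise_nonrigid",
+            "params": {"sorter_name": "kilosort2_5"},
+        },
+    }
+    # passed to SorterStudy.create(..., levels=["sorter_name", "noise_level", "motion"])
+
+The full example below creates and runs a complete two-level study:
+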
+.. code-block:: python
+
+    import matplotlib.pyplot as plt
+    import seaborn as sns
+
+    import spikeinterface.extractors as se
+    import spikeinterface.widgets as sw
+    from spikeinterface import generate_ground_truth_recording
+    from spikeinterface.benchmark import SorterStudy
+
+
+    # generate 2 simulated datasets (could also be mearec files)
+    rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42)
+    rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91)
+
+    datasets = {
+        "toy0": (rec0, gt_sorting0),
+        "toy1": (rec1, gt_sorting1),
+    }
+
+    # define some "cases": here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset,
+    # so it is a two-level study (sorter_name, dataset)
+    # this could be more complicated like (sorter_name, dataset, params)
+    cases = {
+        ("tdc2", "toy0"): {
+            "label": "tridesclous2 on tetrode0",
+            "dataset": "toy0",
+            "params": {"sorter_name": "tridesclous2"}
+        },
+        ("tdc2", "toy1"): {
+            "label": "tridesclous2 on tetrode1",
+            "dataset": "toy1",
+            "params": {"sorter_name": "tridesclous2"}
+        },
+        ("sc", "toy0"): {
+            "label": "spykingcircus2 on tetrode0",
+            "dataset": "toy0",
+            "params": {
+                "sorter_name": "spykingcircus",
+                "docker_image": True
+            },
+        },
+    }
+    # this initializes a folder (study_folder is any path where the study will live)
+    study = SorterStudy.create(study_folder=study_folder, datasets=datasets, cases=cases,
+                               levels=["sorter_name", "dataset"])
+
+
+    # This internally does run_sorter() for all cases in one function
+    study.run()
+
+    # Run the benchmark: this internally does compare_sorter_to_ground_truth() for all cases
+    study.compute_results()
+
+    # Collect comparisons one by one
+    for case_key in study.cases:
+        print('*' * 10)
+        print(case_key)
+        # raw counting of tp/fp/...
+        comp = study.get_result(case_key)["gt_comparison"]
+        # summary
+        comp.print_summary()
+        perf_unit = comp.get_performance(method='by_unit')
+        perf_avg = comp.get_performance(method='pooled_with_average')
+        # some plots
+        m = comp.get_confusion_matrix()
+        w_comp = sw.plot_agreement_matrix(sorting_comparison=comp)
+
+    # Collect synthetic dataframes and display
+    # As shown previously, the performance is returned as a pandas dataframe.
+    # The study.get_performance_by_unit() method gathers all the outputs in the study folder
+    # and merges them into a single dataframe. Same idea for study.get_count_units()
+
+    # this is a dataframe
+    perfs = study.get_performance_by_unit()
+
+    # this is a dataframe
+    unit_counts = study.get_count_units()
+
+    # The study also has several plotting methods for plotting the results
+    study.plot_agreement_matrix()
+    study.plot_unit_counts()
+    study.plot_performances(mode="ordered")
+    study.plot_performances(mode="snr")
+
+
+
+
+Benchmark spike collisions
+--------------------------
+
+SpikeInterface also has a specific toolset to benchmark how well sorters recover spikes in "collision".
+
+We have two classes to handle collision-specific comparisons, and also to quantify the effects on correlogram
+estimation:
+
+ * :py:class:`~spikeinterface.comparison.CollisionGTComparison`
+ * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison`
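+
+As a minimal sketch (here :code:`gt_sorting` and :code:`tested_sorting` stand for a ground-truth sorting and a sorter
+output you already have; the variable names are hypothetical), a collision-aware ground-truth comparison can be run
+directly:
+
+.. code-block:: python
+
+    from spikeinterface.comparison import CollisionGTComparison
+
+    # same spirit as compare_sorter_to_ground_truth(), with extra collision-specific metrics
+    comp = CollisionGTComparison(gt_sorting, tested_sorting, collision_lag=2.0, exhaustive_gt=True)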
+
+For more details, check out the following paper:
+
+`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?"
+`_
diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst
index edee7f1fda..a02d76664d 100644
--- a/doc/modules/comparison.rst
+++ b/doc/modules/comparison.rst
@@ -5,6 +5,10 @@ Comparison module
 SpikeInterface has a :py:mod:`~spikeinterface.comparison` module, which contains functions and tools to compare
 spike trains and templates (useful for tracking units over multiple sessions).
 
+.. note::
+
+    In version 0.102.0 the benchmark part of comparison has moved to the new :py:mod:`~spikeinterface.benchmark` module.
+
 In addition, the :py:mod:`~spikeinterface.comparison` module contains advanced benchmarking tools to evaluate
 the effects of spike collisions on spike sorting results, and to construct hybrid recordings for comparison.
 
@@ -242,135 +246,6 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for
 
     cmp_gt_HS.get_redundant_units(redundant_score=0.2)
 
-
-**Example: compare many sorters with a Ground Truth Study**
-
-We also have a high level class to compare many sorters against ground truth:
-:py:func:`~spikeinterface.comparison.GroundTruthStudy()`
-
-A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases
-like the different parameter sets.
-
-The study class proposes high-level tool functions to run many ground truth comparisons with many "cases"
-on many recordings and then collect and aggregate results in an easy way.
-
-The all mechanism is based on an intrinsic organization into a "study_folder" with several subfolders:
-
- * datasets: contains ground truth datasets
- * sorters : contains outputs of sorters
- * sortings: contains light copy of all sorting
- * metrics: contains metrics
- * ...
-
-
-.. code-block:: python
-
-    import matplotlib.pyplot as plt
-    import seaborn as sns
-
-    import spikeinterface.extractors as se
-    import spikeinterface.widgets as sw
-    from spikeinterface.comparison import GroundTruthStudy
-
-
-    # generate 2 simulated datasets (could be also mearec files)
-    rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42)
-    rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91)
-
-    datasets = {
-        "toy0": (rec0, gt_sorting0),
-        "toy1": (rec1, gt_sorting1),
-    }
-
-    # define some "cases" here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset
-    # so it is a two level study (sorter_name, dataset)
-    # this could be more complicated like (sorter_name, dataset, params)
-    cases = {
-        ("tdc2", "toy0"): {
-            "label": "tridesclous2 on tetrode0",
-            "dataset": "toy0",
-            "run_sorter_params": {
-                "sorter_name": "tridesclous2",
-            },
-        },
-        ("tdc2", "toy1"): {
-            "label": "tridesclous2 on tetrode1",
-            "dataset": "toy1",
-            "run_sorter_params": {
-                "sorter_name": "tridesclous2",
-            },
-        },
-
-        ("sc", "toy0"): {
-            "label": "spykingcircus2 on tetrode0",
-            "dataset": "toy0",
-            "run_sorter_params": {
-                "sorter_name": "spykingcircus",
-                "docker_image": True
-            },
-        },
-    }
-    # this initilizes a folder
-    study = GroundTruthStudy.create(study_folder=study_folder, datasets=datasets, cases=cases,
-                                    levels=["sorter_name", "dataset"])
-
-
-    # all cases in one function
-    study.run_sorters()
-
-    # Collect comparisons
-    #
-    # You can collect in one shot all results and run the
-    # GroundTruthComparison on it.
-    # So you can have fine access to all individual results.
- # - # Note: use exhaustive_gt=True when you know exactly how many - # units in the ground truth (for synthetic datasets) - - # run all comparisons and loop over the results - study.run_comparisons(exhaustive_gt=True) - for key, comp in study.comparisons.items(): - print('*' * 10) - print(key) - # raw counting of tp/fp/... - print(comp.count_score) - # summary - comp.print_summary() - perf_unit = comp.get_performance(method='by_unit') - perf_avg = comp.get_performance(method='pooled_with_average') - # some plots - m = comp.get_confusion_matrix() - w_comp = sw.plot_agreement_matrix(sorting_comparison=comp) - - # Collect synthetic dataframes and display - # As shown previously, the performance is returned as a pandas dataframe. - # The spikeinterface.comparison.get_performance_by_unit() function, - # gathers all the outputs in the study folder and merges them into a single dataframe. - # Same idea for spikeinterface.comparison.get_count_units() - - # this is a dataframe - perfs = study.get_performance_by_unit() - - # this is a dataframe - unit_counts = study.get_count_units() - - # we can also access run times - run_times = study.get_run_times() - print(run_times) - - # Easy plotting with seaborn - fig1, ax1 = plt.subplots() - sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1) - ax1.set_title('Run times') - - ############################################################################## - - fig2, ax2 = plt.subplots() - sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2) - ax2.set_title('Recall') - ax2.set_ylim(-0.1, 1.1) - - .. _symmetric: 2. Compare the output of two spike sorters (symmetric comparison) @@ -537,35 +412,3 @@ sorting analyzers from day 1 (:code:`analyzer_day1`) to day 5 (:code:`analyzer_d # match all m_tcmp = sc.compare_multiple_templates(waveform_list=analyzer_list, name_list=["D1", "D2", "D3", "D4", "D5"]) - - - -Benchmark spike collisions --------------------------- - -SpikeInterface also has a specific toolset to benchmark how well sorters are at recovering spikes in "collision". - -We have three classes to handle collision-specific comparisons, and also to quantify the effects on correlogram -estimation: - - * :py:class:`~spikeinterface.comparison.CollisionGTComparison` - * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison` - * :py:class:`~spikeinterface.comparison.CollisionGTStudy` - * :py:class:`~spikeinterface.comparison.CorrelogramGTStudy` - -For more details, checkout the following paper: - -`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?" `_ - - -Hybrid recording ----------------- - -To benchmark spike sorting results, we need ground-truth spiking activity. -This can be generated with artificial simulations, e.g., using `MEArec `_, or -alternatively by generating so-called "hybrid" recordings. - -The :py:mod:`~spikeinterface.comparison` module includes functions to generate such "hybrid" recordings: - - * :py:func:`~spikeinterface.comparison.create_hybrid_units_recording`: add new units to an existing recording - * :py:func:`~spikeinterface.comparison.create_hybrid_spikes_recording`: add new spikes to existing units in a recording diff --git a/doc/releases/0.101.2.rst b/doc/releases/0.101.2.rst new file mode 100644 index 0000000000..e54546ddfb --- /dev/null +++ b/doc/releases/0.101.2.rst @@ -0,0 +1,66 @@ +.. 
_release0.101.2: + +SpikeInterface 0.101.2 release notes +------------------------------------ + +4th October 2024 + +Minor release with bug fixes + +core: + +* Fix `random_spikes_selection()` (#3456) +* Expose `backend_options` at the analyzer level to set `storage_options` and `saving_options` (#3446) +* Avoid warnings in `SortingAnalyzer` (#3455) +* Fix `reset_global_job_kwargs` (#3452) +* Allow to save recordingless analyzer as (#3443) +* Fix compute analyzer pipeline with tmp recording (#3433) +* Fix bug in saving zarr recordings (#3432) +* Set `run_info` to `None` for `load_waveforms` (#3430) +* Fix integer overflow in parallel computing (#3426) +* Refactor `pandas` save load and `convert_dtypes` (#3412) +* Add spike-train based lazy `SortingGenerator` (#2227) + + +extractors: + +* Improve IBL recording extractors by PID (#3449) + +sorters: + +* Get default encoding for `Popen` (#3439) + +postprocessing: + +* Add `max_threads_per_process` and `mp_context` to pca by channel computation and PCA metrics (#3434) + +widgets: + +* Fix metrics widgets for convert_dtypes (#3417) +* Fix plot motion for multi-segment (#3414) + +motion correction: + +* Auto-cast recording to float prior to interpolation (#3415) + +documentation: + +* Add docstring for `generate_unit_locations` (#3418) +* Add `get_channel_locations` to the base recording API (#3403) + +continuous integration: + +* Enable testing arm64 Mac architecture in the CI (#3422) +* Add kachery_zone secret (#3416) + +testing: + +* Relax causal filter tests (#3445) + +Contributors: + +* @alejoe91 +* @h-mayorquin +* @jiumao2 +* @samuelgarcia +* @zm711 diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index c8038387f9..2851f8ab4a 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. 
toctree:: :maxdepth: 1 + releases/0.101.2.rst releases/0.101.1.rst releases/0.101.0.rst releases/0.100.8.rst @@ -44,6 +45,11 @@ Release notes releases/0.9.1.rst +Version 0.101.2 +=============== + +* Minor release with bug fixes + Version 0.101.1 =============== diff --git a/pyproject.toml b/pyproject.toml index c1c02db8db..a43ab63c8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.101.1" +version = "0.102.0" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -124,16 +124,16 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_extractors = [ # Functions to download data in neo test suite "pooch>=1.8.2", "datalad>=1.0.2", - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_preprocessing = [ @@ -173,8 +173,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -197,8 +197,8 @@ docs = [ "datalad>=1.0.2", # for release we need pypi, so this needs to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 97fb95b623..306c12d516 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -30,5 +30,5 @@ # This flag must be set to False for release # This avoids using versioning that contains ".dev0" (and this is a better choice) # This is mainly useful when using run_sorter in a container and spikeinterface install -# DEV_MODE = True -DEV_MODE = False +DEV_MODE = True +# DEV_MODE = False diff --git a/src/spikeinterface/benchmark/__init__.py b/src/spikeinterface/benchmark/__init__.py new file mode 100644 index 0000000000..3cf0c6a6f6 --- /dev/null +++ b/src/spikeinterface/benchmark/__init__.py @@ -0,0 +1,7 @@ +""" +Module to benchmark: + * sorters + * some sorting components (clustering, motion, template matching) +""" + +from .benchmark_sorter import SorterStudy diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/benchmark/benchmark_base.py similarity index 95% rename from 
src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py rename to src/spikeinterface/benchmark/benchmark_base.py index 4d6dd43bce..b9cbf269c8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -131,7 +131,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): return cls(study_folder) - def create_benchmark(self): + def create_benchmark(self, key): raise NotImplementedError def scan_folder(self): @@ -258,25 +258,9 @@ def get_run_times(self, case_keys=None): return df def plot_run_times(self, case_keys=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - run_times = self.get_run_times(case_keys=case_keys) - - colors = self.get_colors() - import matplotlib.pyplot as plt + from .benchmark_plot_tools import plot_run_times - fig, ax = plt.subplots() - labels = [] - for i, key in enumerate(case_keys): - labels.append(self.cases[key]["label"]) - rt = run_times.at[key, "run_times"] - ax.bar(i, rt, width=0.8, color=colors[key]) - ax.set_xticks(np.arange(len(case_keys))) - ax.set_xticklabels(labels, rotation=45.0) - return fig - - # ax = run_times.plot(kind="bar") - # return ax.figure + return plot_run_times(self, case_keys=case_keys) def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: @@ -462,10 +446,3 @@ def run(self): def compute_result(self): # run becnhmark result raise NotImplementedError - - -def _simpleaxis(ax): - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.get_xaxis().tick_bottom() - ax.get_yaxis().tick_left() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py similarity index 92% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py rename to src/spikeinterface/benchmark/benchmark_clustering.py index 92fcda35d9..1c731ecb64 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -11,8 +11,7 @@ import numpy as np - -from .benchmark_tools import BenchmarkStudy, Benchmark +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @@ -161,49 +160,21 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc return count_units - def plot_unit_counts(self, case_keys=None, figsize=None, **extra_kwargs): - from spikeinterface.widgets.widget_list import plot_study_unit_counts + # plotting by methods + def plot_unit_counts(self, **kwargs): + from .benchmark_plot_tools import plot_unit_counts - plot_study_unit_counts(self, case_keys, figsize=figsize, **extra_kwargs) + return plot_unit_counts(self, **kwargs) - def plot_agreements(self, case_keys=None, figsize=(15, 15)): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + return plot_agreement_matrix(self, **kwargs) - for count, key in enumerate(case_keys): - ax = axs[0, count] - ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + def plot_performances_vs_snr(self, **kwargs): + from 
.benchmark_plot_tools import plot_performances_vs_snr - return fig - - def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt - - fig, axes = plt.subplots(ncols=1, nrows=3, figsize=figsize) - - for count, k in enumerate(("accuracy", "recall", "precision")): - - ax = axes[count] - for key in case_keys: - label = self.cases[key]["label"] - - analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension("quality_metrics").get_data() - x = metrics["snr"].values - y = self.get_result(key)["gt_comparison"].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label) - ax.set_title(k) - - if count == 2: - ax.legend() - - return fig + return plot_performances_vs_snr(self, **kwargs) def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py similarity index 84% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py rename to src/spikeinterface/benchmark/benchmark_matching.py index ab1523d13a..c53567f460 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -9,11 +9,8 @@ ) import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.sortingcomponents.tools import remove_empty_templates -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.sparsity import compute_sparsity class MatchingBenchmark(Benchmark): @@ -64,41 +61,15 @@ def create_benchmark(self, key): benchmark = MatchingBenchmark(recording, gt_sorting, params) return benchmark - def plot_agreements(self, case_keys=None, figsize=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + return plot_agreement_matrix(self, **kwargs) - for count, key in enumerate(case_keys): - ax = axs[0, count] - ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + def plot_performances_vs_snr(self, **kwargs): + from .benchmark_plot_tools import plot_performances_vs_snr - def plot_performances_vs_snr(self, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): - if case_keys is None: - case_keys = list(self.cases.keys()) - - fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) - - for count, k in enumerate(metrics): - - ax = axs[count, 0] - for key in case_keys: - label = self.cases[key]["label"] - - analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension("quality_metrics").get_data() - x = metrics["snr"].values - y = self.get_result(key)["gt_comparison"].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label) - ax.set_title(k) - - if count == 2: - ax.legend() - - return fig + return plot_performances_vs_snr(self, **kwargs) def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: diff --git 
a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py similarity index 99% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py rename to src/spikeinterface/benchmark/benchmark_motion_estimation.py index ec7e1e24a8..abb2a51bae 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -8,7 +8,8 @@ import numpy as np from spikeinterface.core import get_noise_levels -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from .benchmark_base import Benchmark, BenchmarkStudy +from .benchmark_plot_tools import _simpleaxis from spikeinterface.sortingcomponents.motion import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py rename to src/spikeinterface/benchmark/benchmark_motion_interpolation.py index 38365adfd1..ab72a1f9bd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py @@ -10,7 +10,7 @@ from spikeinterface.curation import MergeUnitsSorting -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from .benchmark_base import Benchmark, BenchmarkStudy class MotionInterpolationBenchmark(Benchmark): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/benchmark/benchmark_peak_detection.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py rename to src/spikeinterface/benchmark/benchmark_peak_detection.py index 7d862343d2..77b5e0025c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_detection.py @@ -12,10 +12,9 @@ import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel class PeakDetectionBenchmark(Benchmark): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py similarity index 99% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py rename to src/spikeinterface/benchmark/benchmark_peak_localization.py index 05d142113b..399729fa29 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -6,7 +6,7 @@ compute_grid_convolution, ) import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from 
spikeinterface.core.sortinganalyzer import create_sorting_analyzer diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/benchmark/benchmark_peak_selection.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py rename to src/spikeinterface/benchmark/benchmark_peak_selection.py index 008de2d931..41edea156f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_selection.py @@ -6,15 +6,9 @@ from spikeinterface.comparison.comparisontools import make_matching_events from spikeinterface.core import get_noise_levels -import time -import string, random -import pylab as plt -import os import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy -from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from .benchmark_base import Benchmark, BenchmarkStudy class PeakSelectionBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py new file mode 100644 index 0000000000..a6e9b6dacc --- /dev/null +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -0,0 +1,243 @@ +import numpy as np + + +def _simpleaxis(ax): + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.get_xaxis().tick_bottom() + ax.get_yaxis().tick_left() + + +def plot_run_times(study, case_keys=None): + """ + Plot run times for a BenchmarkStudy. + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + """ + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + run_times = study.get_run_times(case_keys=case_keys) + + colors = study.get_colors() + + fig, ax = plt.subplots() + labels = [] + for i, key in enumerate(case_keys): + labels.append(study.cases[key]["label"]) + rt = run_times.at[key, "run_times"] + ax.bar(i, rt, width=0.8, color=colors[key]) + ax.set_xticks(np.arange(len(case_keys))) + ax.set_xticklabels(labels, rotation=45.0) + return fig + + +def plot_unit_counts(study, case_keys=None): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. 
+ """ + import matplotlib.pyplot as plt + from spikeinterface.widgets.utils import get_some_colors + + if case_keys is None: + case_keys = list(study.cases.keys()) + + count_units = study.get_count_units(case_keys=case_keys) + + fig, ax = plt.subplots() + + columns = count_units.columns.tolist() + columns.remove("num_gt") + columns.remove("num_sorter") + + ncol = len(columns) + + colors = get_some_colors(columns, color_engine="auto", map_name="hot") + colors["num_well_detected"] = "green" + + xticklabels = [] + for i, key in enumerate(case_keys): + for c, col in enumerate(columns): + x = i + 1 + c / (ncol + 1) + y = count_units.loc[key, col] + if not "well_detected" in col: + y = -y + + if i == 0: + label = col.replace("num_", "").replace("_", " ").title() + else: + label = None + + ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) + + xticklabels.append(study.cases[key]["label"]) + + ax.set_xticks(np.arange(len(case_keys)) + 1) + ax.set_xticklabels(xticklabels) + ax.legend() + + return fig + + +def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None): + """ + Plot performances over case for a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + mode : "ordered" | "snr" | "swarm", default: "ordered" + Which plot mode to use: + + * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy + * "snr": plot performance metrics vs snr + * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) + performance_names : list or tuple, default: ("accuracy", "precision", "recall") + Which performances to plot ("accuracy", "precision", "recall") + case_keys : list or None + A selection of cases to plot, if None, then all. 
+ """ + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns + + if case_keys is None: + case_keys = list(study.cases.keys()) + + perfs = study.get_performance_by_unit(case_keys=case_keys) + colors = study.get_colors() + + if mode in ("ordered", "snr"): + num_axes = len(performance_names) + fig, axs = plt.subplots(ncols=num_axes) + else: + fig, ax = plt.subplots() + + if mode == "ordered": + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + for key in case_keys: + label = study.cases[key]["label"] + val = perfs.xs(key).loc[:, performance_name].values + val = np.sort(val)[::-1] + ax.plot(val, label=label, c=colors[key]) + ax.set_title(performance_name) + if count == len(performance_names) - 1: + ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) + + elif mode == "snr": + metric_name = mode + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + + max_metric = 0 + for key in case_keys: + x = study.get_metrics(key).loc[:, metric_name].values + y = perfs.xs(key).loc[:, performance_name].values + label = study.cases[key]["label"] + ax.scatter(x, y, s=10, label=label, color=colors[key]) + max_metric = max(max_metric, np.max(x)) + ax.set_title(performance_name) + ax.set_xlim(0, max_metric * 1.05) + ax.set_ylim(0, 1.05) + if count == 0: + ax.legend(loc="lower right") + + elif mode == "swarm": + levels = perfs.index.names + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=performance_names, + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True, ax=ax) + + +def plot_agreement_matrix(study, ordered=True, case_keys=None): + """ + Plot agreement matri ces for cases in a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + ordered : bool + Order units with best agreement scores. + This enable to see agreement on a diagonal. 
+ """ + + import matplotlib.pyplot as plt + from spikeinterface.widgets import AgreementMatrixWidget + + if case_keys is None: + case_keys = list(study.cases.keys()) + + num_axes = len(case_keys) + fig, axs = plt.subplots(ncols=num_axes) + + for count, key in enumerate(case_keys): + ax = axs.flatten()[count] + comp = study.get_result(key)["gt_comparison"] + + unit_ticks = len(comp.sorting1.unit_ids) <= 16 + count_text = len(comp.sorting1.unit_ids) <= 16 + + AgreementMatrixWidget( + comp, ordered=ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax + ) + label = study.cases[key]["label"] + ax.set_xlabel(label) + + if count > 0: + ax.set_ylabel(None) + ax.set_yticks([]) + ax.set_xticks([]) + + +def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) + + for count, k in enumerate(metrics): + + ax = axs[count, 0] + for key in case_keys: + label = study.cases[key]["label"] + + analyzer = study.get_sorting_analyzer(key) + metrics = analyzer.get_extension("quality_metrics").get_data() + x = metrics["snr"].values + y = study.get_result(key)["gt_comparison"].get_performance()[k].values + ax.scatter(x, y, marker=".", label=label) + ax.set_title(k) + + ax.set_ylim(0, 1.05) + + if count == 2: + ax.legend() + + return fig diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py new file mode 100644 index 0000000000..f9267c785a --- /dev/null +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -0,0 +1,135 @@ +""" +This replace the previous `GroundTruthStudy` +""" + +import numpy as np +from ..core import NumpySorting +from .benchmark_base import Benchmark, BenchmarkStudy +from ..sorters import run_sorter +from spikeinterface.comparison import compare_sorter_to_ground_truth + + +# TODO later integrate CollisionGTComparison optionally in this class. + + +class SorterBenchmark(Benchmark): + def __init__(self, recording, gt_sorting, params, sorter_folder): + self.recording = recording + self.gt_sorting = gt_sorting + self.params = params + self.sorter_folder = sorter_folder + self.result = {} + + def run(self): + # run one sorter sorter_name is must be in params + raw_sorting = run_sorter(recording=self.recording, folder=self.sorter_folder, **self.params) + sorting = NumpySorting.from_sorting(raw_sorting) + self.result = {"sorting": sorting} + + def compute_result(self): + # run becnhmark result + sorting = self.result["sorting"] + comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + self.result["gt_comparison"] = comp + + _run_key_saved = [ + ("sorting", "sorting"), + ] + _result_key_saved = [ + ("gt_comparison", "pickle"), + ] + + +class SorterStudy(BenchmarkStudy): + """ + This class is used to tests several sorter in several situtation. + This replace the previous GroundTruthStudy with more flexibility. 
+ """ + + benchmark_class = SorterBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder) + return benchmark + + def get_performance_by_unit(self, case_keys=None): + import pandas as pd + + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + + perf = comp.get_performance(method="by_unit", output="pandas") + + if isinstance(key, str): + perf[self.levels] = key + elif isinstance(key, tuple): + for col, k in zip(self.levels, key): + perf[col] = k + + perf = perf.reset_index() + perf_by_unit.append(perf) + + perf_by_unit = pd.concat(perf_by_unit) + perf_by_unit = perf_by_unit.set_index(self.levels) + perf_by_unit = perf_by_unit.sort_index() + return perf_by_unit + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + columns = ["num_gt", "num_sorter", "num_well_detected"] + key0 = case_keys[0] + comp = self.get_result(key0)["gt_comparison"] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units + + # plotting as methods + def plot_unit_counts(self, **kwargs): + from .benchmark_plot_tools import plot_unit_counts + + return plot_unit_counts(self, **kwargs) + + def plot_performances(self, **kwargs): + from .benchmark_plot_tools import plot_performances + + return plot_performances(self, **kwargs) + + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix + + return plot_agreement_matrix(self, **kwargs) diff --git a/src/spikeinterface/benchmark/benchmark_tools.py b/src/spikeinterface/benchmark/benchmark_tools.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py similarity index 100% rename from src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py rename to src/spikeinterface/benchmark/tests/common_benchmark_testing.py diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py 
b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py similarity index 88% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py rename to src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index bc36fb607c..3f574fd058 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -3,11 +3,13 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.benchmark_clustering import ClusteringStudy +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset +from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel +from pathlib import Path + @pytest.mark.skip() def test_benchmark_clustering(create_cache_folder): @@ -78,4 +80,5 @@ def test_benchmark_clustering(create_cache_folder): if __name__ == "__main__": - test_benchmark_clustering() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_clustering(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py similarity index 86% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py rename to src/spikeinterface/benchmark/tests/test_benchmark_matching.py index 71a5f282a8..000a00faf5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py @@ -1,6 +1,7 @@ import pytest import shutil +from pathlib import Path from spikeinterface.core import ( @@ -8,11 +9,11 @@ compute_sparsity, ) -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_dataset, compute_gt_templates, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_matching import MatchingStudy +from spikeinterface.benchmark.benchmark_matching import MatchingStudy @pytest.mark.skip() @@ -72,4 +73,5 @@ def test_benchmark_matching(create_cache_folder): if __name__ == "__main__": - test_benchmark_matching() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_matching(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py similarity index 86% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py rename to src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py index 78a9eb7dbc..65cacfc8a0 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py @@ -2,12 +2,13 @@ import shutil +from pathlib import Path -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import 
MotionEstimationStudy +from spikeinterface.benchmark.benchmark_motion_estimation import MotionEstimationStudy @pytest.mark.skip() @@ -75,4 +76,5 @@ def test_benchmark_motion_estimaton(create_cache_folder): if __name__ == "__main__": - test_benchmark_motion_estimaton() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_motion_estimaton(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py similarity index 90% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py rename to src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py index 18def37d54..f7afd7a8bc 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py @@ -4,14 +4,14 @@ import numpy as np import shutil +from pathlib import Path - -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import ( +from spikeinterface.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy +from spikeinterface.benchmark.benchmark_motion_estimation import ( # get_unit_displacement, get_gt_motion_from_unit_displacement, ) @@ -139,4 +139,5 @@ def test_benchmark_motion_interpolation(create_cache_folder): if __name__ == "__main__": - test_benchmark_motion_interpolation() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_motion_interpolation(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py similarity index 87% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py rename to src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index dffe1529b7..d45ac0b4ce 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -1,10 +1,10 @@ import pytest import shutil +from pathlib import Path - -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_detection import PeakDetectionStudy +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset +from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @@ -69,5 +69,5 @@ def test_benchmark_peak_detection(create_cache_folder): if __name__ == "__main__": - # test_benchmark_peak_localization() - test_benchmark_peak_detection() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_peak_detection(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py 
b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py similarity index 79% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py rename to src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py index 23060c4ddb..3b6240cb10 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py @@ -1,12 +1,12 @@ import pytest import shutil +from pathlib import Path +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset - -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import PeakLocalizationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import UnitLocalizationStudy +from spikeinterface.benchmark.benchmark_peak_localization import PeakLocalizationStudy +from spikeinterface.benchmark.benchmark_peak_localization import UnitLocalizationStudy @pytest.mark.skip() @@ -28,7 +28,8 @@ def test_benchmark_peak_localization(create_cache_folder): "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, "params": { "method": method, - "method_kwargs": {"ms_before": 2}, + "ms_before": 2.0, + "method_kwargs": {}, }, } @@ -60,7 +61,7 @@ def test_benchmark_unit_locations(create_cache_folder): cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") - recording, gt_sorting = make_dataset() + recording, gt_sorting, gt_analyzer = make_dataset() # create study study_folder = cache_folder / "study_unit_locations" @@ -71,7 +72,7 @@ def test_benchmark_unit_locations(create_cache_folder): "label": f"{method} on toy", "dataset": "toy", "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, - "params": {"method": method, "method_kwargs": {"ms_before": 2}}, + "params": {"method": method, "ms_before": 2.0, "method_kwargs": {}}, } if study_folder.exists(): @@ -99,5 +100,6 @@ def test_benchmark_unit_locations(create_cache_folder): if __name__ == "__main__": - # test_benchmark_peak_localization() - test_benchmark_unit_locations() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + # test_benchmark_peak_localization(cache_folder) + test_benchmark_unit_locations(cache_folder) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py new file mode 100644 index 0000000000..a6eb090a9d --- /dev/null +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py @@ -0,0 +1,13 @@ +import pytest + +from pathlib import Path + + +@pytest.mark.skip() +def test_benchmark_peak_selection(create_cache_folder): + cache_folder = create_cache_folder + + +if __name__ == "__main__": + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_peak_selection(cache_folder) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py similarity index 57% rename from src/spikeinterface/comparison/tests/test_groundtruthstudy.py rename to src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index a92d6e9f77..2564d58d52 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ 
b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -4,12 +4,12 @@ from spikeinterface import generate_ground_truth_recording from spikeinterface.preprocessing import bandpass_filter -from spikeinterface.comparison import GroundTruthStudy +from spikeinterface.benchmark import SorterStudy @pytest.fixture(scope="module") def setup_module(tmp_path_factory): - study_folder = tmp_path_factory.mktemp("study_folder") + study_folder = tmp_path_factory.mktemp("sorter_study_folder") if study_folder.is_dir(): shutil.rmtree(study_folder) create_a_study(study_folder) @@ -36,63 +36,53 @@ def create_a_study(study_folder): ("tdc2", "no-preprocess", "tetrode"): { "label": "tridesclous2 without preprocessing and standard params", "dataset": "toy_tetrode", - "run_sorter_params": { + "params": { "sorter_name": "tridesclous2", }, - "comparison_params": {}, }, # ("tdc2", "with-preprocess", "probe32"): { "label": "tridesclous2 with preprocessing standar params", "dataset": "toy_probe32_preprocess", - "run_sorter_params": { + "params": { "sorter_name": "tridesclous2", }, - "comparison_params": {}, }, - # we comment this at the moement because SC2 is quite slow for testing - # ("sc2", "no-preprocess", "tetrode"): { - # "label": "spykingcircus2 without preprocessing standar params", - # "dataset": "toy_tetrode", - # "run_sorter_params": { - # "sorter_name": "spykingcircus2", - # }, - # "comparison_params": { - # }, - # }, } - study = GroundTruthStudy.create( + study = SorterStudy.create( study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] ) # print(study) -def test_GroundTruthStudy(setup_module): +def test_SorterStudy(setup_module): + # job_kwargs = dict(n_jobs=2, chunk_duration="1s") + study_folder = setup_module - study = GroundTruthStudy(study_folder) + study = SorterStudy(study_folder) print(study) - study.run_sorters(verbose=True) - - print(study.sortings) - - print(study.comparisons) - study.run_comparisons() - print(study.comparisons) + # # this run the sorters + # study.run() - study.create_sorting_analyzer_gt(n_jobs=-1) - - study.compute_metrics() + # # this run comparisons + # study.compute_results() + print(study) - for key in study.cases: - metrics = study.get_metrics(key) - print(metrics) + # this is from the base class + rt = study.get_run_times() + # rt = study.plot_run_times() + # import matplotlib.pyplot as plt + # plt.show() - study.get_performance_by_unit() - study.get_count_units() + perf_by_unit = study.get_performance_by_unit() + # print(perf_by_unit) + count_units = study.get_count_units() + # print(count_units) if __name__ == "__main__": - setup_module() - test_GroundTruthStudy() + study_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy" + create_a_study(study_folder) + test_SorterStudy(study_folder) diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py index 648ef4ed70..f4ada19f73 100644 --- a/src/spikeinterface/comparison/__init__.py +++ b/src/spikeinterface/comparison/__init__.py @@ -30,8 +30,8 @@ ) from .groundtruthstudy import GroundTruthStudy -from .collision import CollisionGTComparison, CollisionGTStudy -from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy +from .collision import CollisionGTComparison +from .correlogram import CorrelogramGTComparison from .hybrid import ( HybridSpikesRecording, diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 
--- a/src/spikeinterface/comparison/collision.py
+++ b/src/spikeinterface/comparison/collision.py
@@ -172,71 +172,75 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
         return similarities, recall_scores, pair_names


-class CollisionGTStudy(GroundTruthStudy):
-    def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
-        _kwargs = dict()
-        _kwargs.update(kwargs)
-        _kwargs["exhaustive_gt"] = exhaustive_gt
-        _kwargs["collision_lag"] = collision_lag
-        _kwargs["nbins"] = nbins
-        GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs)
-        self.exhaustive_gt = exhaustive_gt
-        self.collision_lag = collision_lag
-
-    def get_lags(self, key):
-        comp = self.comparisons[key]
-        fs = comp.sorting1.get_sampling_frequency()
-        lags = comp.bins / fs * 1000.0
-        return lags
-
-    def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9):
-        import sklearn
-
-        if case_keys is None:
-            case_keys = self.cases.keys()
-
-        self.all_similarities = {}
-        self.all_recall_scores = {}
-        self.good_only = good_only
-
-        for key in case_keys:
-            templates = self.get_templates(key)
-            flat_templates = templates.reshape(templates.shape[0], -1)
-            similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
-            comp = self.comparisons[key]
-            similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
-                similarity, good_only=good_only, min_accuracy=min_accuracy
-            )
-            self.all_similarities[key] = similarities
-            self.all_recall_scores[key] = recall_scores
-
-    def get_mean_over_similarity_range(self, similarity_range, key):
-        idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1])
-        all_similarities = self.all_similarities[key][idx]
-        all_recall_scores = self.all_recall_scores[key][idx]
-
-        order = np.argsort(all_similarities)
-        all_similarities = all_similarities[order]
-        all_recall_scores = all_recall_scores[order, :]
-
-        mean_recall_scores = np.nanmean(all_recall_scores, axis=0)
-
-        return mean_recall_scores
-
-    def get_lag_profile_over_similarity_bins(self, similarity_bins, key):
-        all_similarities = self.all_similarities[key]
-        all_recall_scores = self.all_recall_scores[key]
-
-        order = np.argsort(all_similarities)
-        all_similarities = all_similarities[order]
-        all_recall_scores = all_recall_scores[order, :]
-
-        result = {}
-
-        for i in range(similarity_bins.size - 1):
-            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
-            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
-            mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
-            result[(cmin, cmax)] = mean_recall_scores
-
-        return result
+# This is removed at the moment.
+# We may need to move this into the benchmark module one day.
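+# In the meantime, the per-pair collision analysis can still be run by hand with
+# CollisionGTComparison. A rough sketch of what precompute_scores_by_similarities()
+# did (assuming `comp` is a CollisionGTComparison and `templates` is an array of
+# shape (num_units, num_samples, num_channels)):
+#
+#     import sklearn.metrics
+#     flat_templates = templates.reshape(templates.shape[0], -1)
+#     similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
+#     similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
+#         similarity, good_only=False, min_accuracy=0.9
+#     )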
+# please do not delete this + +# class CollisionGTStudy(GroundTruthStudy): +# def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): +# _kwargs = dict() +# _kwargs.update(kwargs) +# _kwargs["exhaustive_gt"] = exhaustive_gt +# _kwargs["collision_lag"] = collision_lag +# _kwargs["nbins"] = nbins +# GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) +# self.exhaustive_gt = exhaustive_gt +# self.collision_lag = collision_lag + +# def get_lags(self, key): +# comp = self.comparisons[key] +# fs = comp.sorting1.get_sampling_frequency() +# lags = comp.bins / fs * 1000.0 +# return lags + +# def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): +# import sklearn + +# if case_keys is None: +# case_keys = self.cases.keys() + +# self.all_similarities = {} +# self.all_recall_scores = {} +# self.good_only = good_only + +# for key in case_keys: +# templates = self.get_templates(key) +# flat_templates = templates.reshape(templates.shape[0], -1) +# similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) +# comp = self.comparisons[key] +# similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( +# similarity, good_only=good_only, min_accuracy=min_accuracy +# ) +# self.all_similarities[key] = similarities +# self.all_recall_scores[key] = recall_scores + +# def get_mean_over_similarity_range(self, similarity_range, key): +# idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) +# all_similarities = self.all_similarities[key][idx] +# all_recall_scores = self.all_recall_scores[key][idx] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# mean_recall_scores = np.nanmean(all_recall_scores, axis=0) + +# return mean_recall_scores + +# def get_lag_profile_over_similarity_bins(self, similarity_bins, key): +# all_similarities = self.all_similarities[key] +# all_recall_scores = self.all_recall_scores[key] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# result = {} + +# for i in range(similarity_bins.size - 1): +# cmin, cmax = similarity_bins[i], similarity_bins[i + 1] +# amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) +# mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) +# result[(cmin, cmax)] = mean_recall_scores + +# return result diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 0cafef2c12..717d11a3fa 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -128,57 +128,60 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): return similarities, errors -class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons( - self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs - ): - _kwargs = dict() - _kwargs.update(kwargs) - _kwargs["exhaustive_gt"] = exhaustive_gt - _kwargs["window_ms"] = window_ms - _kwargs["bin_ms"] = bin_ms - _kwargs["well_detected_score"] = well_detected_score - GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) - self.exhaustive_gt = exhaustive_gt - - @property - def time_bins(self): - for 
key, value in self.comparisons.items():
-            return value.time_bins
-
-    def precompute_scores_by_similarities(self, case_keys=None, good_only=True):
-        import sklearn.metrics
-
-        if case_keys is None:
-            case_keys = self.cases.keys()
-
-        self.all_similarities = {}
-        self.all_errors = {}
-
-        for key in case_keys:
-            templates = self.get_templates(key)
-            flat_templates = templates.reshape(templates.shape[0], -1)
-            similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
-            comp = self.comparisons[key]
-            similarities, errors = comp.compute_correlogram_by_similarity(similarity)
-
-            self.all_similarities[key] = similarities
-            self.all_errors[key] = errors
-
-    def get_error_profile_over_similarity_bins(self, similarity_bins, key):
-        all_similarities = self.all_similarities[key]
-        all_errors = self.all_errors[key]
-
-        order = np.argsort(all_similarities)
-        all_similarities = all_similarities[order]
-        all_errors = all_errors[order, :]
-
-        result = {}
-
-        for i in range(similarity_bins.size - 1):
-            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
-            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
-            mean_errors = np.nanmean(all_errors[amin:amax], axis=0)
-            result[(cmin, cmax)] = mean_errors
-
-        return result
+# This is removed at the moment.
+# We may need to move this into the benchmark module one day.
+
+# class CorrelogramGTStudy(GroundTruthStudy):
+#     def run_comparisons(
+#         self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs
+#     ):
+#         _kwargs = dict()
+#         _kwargs.update(kwargs)
+#         _kwargs["exhaustive_gt"] = exhaustive_gt
+#         _kwargs["window_ms"] = window_ms
+#         _kwargs["bin_ms"] = bin_ms
+#         _kwargs["well_detected_score"] = well_detected_score
+#         GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs)
+#         self.exhaustive_gt = exhaustive_gt

+#     @property
+#     def time_bins(self):
+#         for key, value in self.comparisons.items():
+#             return value.time_bins

+#     def precompute_scores_by_similarities(self, case_keys=None, good_only=True):
+#         import sklearn.metrics

+#         if case_keys is None:
+#             case_keys = self.cases.keys()

+#         self.all_similarities = {}
+#         self.all_errors = {}

+#         for key in case_keys:
+#             templates = self.get_templates(key)
+#             flat_templates = templates.reshape(templates.shape[0], -1)
+#             similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
+#             comp = self.comparisons[key]
+#             similarities, errors = comp.compute_correlogram_by_similarity(similarity)

+#             self.all_similarities[key] = similarities
+#             self.all_errors[key] = errors

+#     def get_error_profile_over_similarity_bins(self, similarity_bins, key):
+#         all_similarities = self.all_similarities[key]
+#         all_errors = self.all_errors[key]

+#         order = np.argsort(all_similarities)
+#         all_similarities = all_similarities[order]
+#         all_errors = all_errors[order, :]

+#         result = {}

+#         for i in range(similarity_bins.size - 1):
+#             cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
+#             amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
+#             mean_errors = np.nanmean(all_errors[amin:amax], axis=0)
+#             result[(cmin, cmax)] = mean_errors

+#         return result
diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index 8929d6983c..df9e1420cb 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -1,441 +1,21 @@
-from __future__ import annotations
-
-from pathlib import Path
-import shutil
-import os
-import json
-import pickle
-
-import numpy as np
-
-from spikeinterface.core import load_extractor, create_sorting_analyzer, load_sorting_analyzer
-from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder
-
-from spikeinterface.qualitymetrics import compute_quality_metrics
-
-from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison
-
-
-# TODO later : save comparison in folders when comparison object will be able to serialize
-
-
-# This is to separate names when the key are tuples when saving folders
-# _key_separator = "_##_"
-_key_separator = "_-°°-_"
+_txt_error_message = """
+GroundTruthStudy has been replaced by SorterStudy, which has a similar API, but folder loading is not backward compatible.
+You can do:
+from spikeinterface.benchmark import SorterStudy
+study = SorterStudy.create(study_folder, datasets=..., cases=..., levels=...)
+study.run()  # this runs the sorters
+study.compute_results()  # this runs the comparisons
+# and then some plotting
+study.plot_agreements()
+study.plot_performances_vs_snr()
+...
+"""


 class GroundTruthStudy:
-    """
-    This class is an helper function to run any comparison on several "cases" for many ground-truth dataset.
-
-    "cases" refer to:
-      * several sorters for comparisons
-      * same sorter with differents parameters
-      * any combination of these (and more)
-
-    For increased flexibility, cases keys can be a tuple so that we can vary complexity along several
-    "levels" or "axis" (paremeters or sorters).
-    In this case, the result dataframes will have `MultiIndex` to handle the different levels.
-
-    A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see
-    :py:func:`~spikeinterface.core.generate.generate_ground_truth_recording()`).
-
-    This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions.
-    Note that the underlying folder structure is not backward compatible!
- - Parameters - ---------- - study_folder : str | Path - Path to folder containing `GroundTruthStudy` - """ - def __init__(self, study_folder): - self.folder = Path(study_folder) - - self.datasets = {} - self.cases = {} - self.sortings = {} - self.comparisons = {} - self.colors = None - - self.scan_folder() + raise RuntimeError(_txt_error_message) @classmethod def create(cls, study_folder, datasets={}, cases={}, levels=None): - # check that cases keys are homogeneous - key0 = list(cases.keys())[0] - if isinstance(key0, str): - assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" - if levels is None: - levels = "level0" - else: - assert isinstance(levels, str) - elif isinstance(key0, tuple): - assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" - num_levels = len(key0) - assert all( - len(key) == num_levels for key in cases.keys() - ), "Keys for cases are not homogeneous, tuple negth differ" - if levels is None: - levels = [f"level{i}" for i in range(num_levels)] - else: - levels = list(levels) - assert len(levels) == num_levels - else: - raise ValueError("Keys for cases must str or tuple") - - study_folder = Path(study_folder) - study_folder.mkdir(exist_ok=False, parents=True) - - (study_folder / "datasets").mkdir() - (study_folder / "datasets" / "recordings").mkdir() - (study_folder / "datasets" / "gt_sortings").mkdir() - (study_folder / "sorters").mkdir() - (study_folder / "sortings").mkdir() - (study_folder / "sortings" / "run_logs").mkdir() - (study_folder / "metrics").mkdir() - (study_folder / "comparisons").mkdir() - - for key, (rec, gt_sorting) in datasets.items(): - assert "/" not in key, "'/' cannot be in the key name!" - assert "\\" not in key, "'\\' cannot be in the key name!" 
- - # recordings are pickled - rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") - - # sortings are pickled + saved as NumpyFolderSorting - gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") - gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") - - info = {} - info["levels"] = levels - (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - - # cases is dumped to a pickle file, json is not possible because of the tuple key - (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) - - return cls(study_folder) - - def scan_folder(self): - if not (self.folder / "datasets").exists(): - raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") - - with open(self.folder / "info.json", "r") as f: - self.info = json.load(f) - - self.levels = self.info["levels"] - - for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): - key = rec_file.stem - rec = load_extractor(rec_file) - gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) - self.datasets[key] = (rec, gt_sorting) - - with open(self.folder / "cases.pickle", "rb") as f: - self.cases = pickle.load(f) - - self.sortings = {k: None for k in self.cases} - self.comparisons = {k: None for k in self.cases} - for key in self.cases: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - if sorting_folder.exists(): - self.sortings[key] = load_extractor(sorting_folder) - - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") - if comparison_file.exists(): - with open(comparison_file, mode="rb") as f: - try: - self.comparisons[key] = pickle.load(f) - except Exception: - pass - - def __repr__(self): - t = f"{self.__class__.__name__} {self.folder.stem} \n" - t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" - t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" - num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None]) - t += f" computed: {num_computed}\n" - - return t - - def key_to_str(self, key): - if isinstance(key, str): - return key - elif isinstance(key, tuple): - return _key_separator.join(key) - else: - raise ValueError("Keys for cases must str or tuple") - - def remove_sorting(self, key): - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - comparison_file = self.folder / "comparisons" / self.key_to_str(key) - self.sortings[key] = None - self.comparisons[key] = None - if sorting_folder.exists(): - shutil.rmtree(sorting_folder) - for f in (log_file, comparison_file): - if f.exists(): - f.unlink() - - def set_colors(self, colors=None, map_name="tab20"): - from spikeinterface.widgets import get_some_colors - - if colors is None: - case_keys = list(self.cases.keys()) - self.colors = get_some_colors( - case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0 - ) - else: - self.colors = colors - - def get_colors(self): - if self.colors is None: - self.set_colors() - return self.colors - - def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): - if case_keys is None: - case_keys = self.cases.keys() - - job_list = [] - for key in case_keys: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - sorting_exists = sorting_folder.exists() - - sorter_folder = self.folder / "sorters" / 
self.key_to_str(key) - sorter_folder_exists = sorter_folder.exists() - - if keep: - if sorting_exists: - continue - if sorter_folder_exists: - # the sorter folder exists but havent been copied to sortings folder - sorting = read_sorter_folder(sorter_folder, raise_error=False) - if sorting is not None: - # save and skip - self.copy_sortings(case_keys=[key]) - continue - - self.remove_sorting(key) - - if sorter_folder_exists: - shutil.rmtree(sorter_folder) - - params = self.cases[key]["run_sorter_params"].copy() - # this ensure that sorter_name is given - recording, _ = self.datasets[self.cases[key]["dataset"]] - sorter_name = params.pop("sorter_name") - job = dict( - sorter_name=sorter_name, - recording=recording, - output_folder=sorter_folder, - ) - job.update(params) - # the verbose is overwritten and global to all run_sorters - job["verbose"] = verbose - job["with_output"] = False - job_list.append(job) - - run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) - - # TODO later create a list in laucher for engine blocking and non-blocking - if engine not in ("slurm",): - self.copy_sortings(case_keys) - - def copy_sortings(self, case_keys=None, force=True): - if case_keys is None: - case_keys = self.cases.keys() - - for key in case_keys: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - sorter_folder = self.folder / "sorters" / self.key_to_str(key) - log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - - if (sorter_folder / "spikeinterface_log.json").exists(): - sorting = read_sorter_folder( - sorter_folder, raise_error=False, register_recording=False, sorting_info=False - ) - else: - sorting = None - - if sorting is not None: - if sorting_folder.exists(): - if force: - self.remove_sorting(key) - else: - continue - - sorting = sorting.save(format="numpy_folder", folder=sorting_folder) - self.sortings[key] = sorting - - # copy logs - shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file) - - def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): - if case_keys is None: - case_keys = self.cases.keys() - - for key in case_keys: - dataset_key = self.cases[key]["dataset"] - _, gt_sorting = self.datasets[dataset_key] - sorting = self.sortings[key] - if sorting is None: - self.comparisons[key] = None - continue - comp = comparison_class(gt_sorting, sorting, **kwargs) - self.comparisons[key] = comp - - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") - with open(comparison_file, mode="wb") as f: - pickle.dump(comp, f) - - def get_run_times(self, case_keys=None): - import pandas as pd - - if case_keys is None: - case_keys = self.cases.keys() - - log_folder = self.folder / "sortings" / "run_logs" - - run_times = {} - for key in case_keys: - log_file = log_folder / f"{self.key_to_str(key)}.json" - with open(log_file, mode="r") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - run_times[key] = run_time - - return pd.Series(run_times, name="run_time") - - def create_sorting_analyzer_gt(self, case_keys=None, random_params={}, waveforms_params={}, **job_kwargs): - if case_keys is None: - case_keys = self.cases.keys() - - base_folder = self.folder / "sorting_analyzer" - base_folder.mkdir(exist_ok=True) - - dataset_keys = [self.cases[key]["dataset"] for key in case_keys] - dataset_keys = set(dataset_keys) - for dataset_key in dataset_keys: - # the waveforms depend on the dataset key - folder = base_folder / 
self.key_to_str(dataset_key) - recording, gt_sorting = self.datasets[dataset_key] - sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) - sorting_analyzer.compute("random_spikes", **random_params) - sorting_analyzer.compute("templates", **job_kwargs) - sorting_analyzer.compute("noise_levels") - - def get_sorting_analyzer(self, case_key=None, dataset_key=None): - if case_key is not None: - dataset_key = self.cases[case_key]["dataset"] - - folder = self.folder / "sorting_analyzer" / self.key_to_str(dataset_key) - sorting_analyzer = load_sorting_analyzer(folder) - return sorting_analyzer - - # def get_templates(self, key, mode="average"): - # analyzer = self.get_sorting_analyzer(case_key=key) - # templates = sorting_analyzer.get_all_templates(mode=mode) - # return templates - - def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): - if case_keys is None: - case_keys = self.cases.keys() - - done = [] - for key in case_keys: - dataset_key = self.cases[key]["dataset"] - if dataset_key in done: - # some case can share the same waveform extractor - continue - done.append(dataset_key) - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if filename.exists(): - if force: - os.remove(filename) - else: - continue - analyzer = self.get_sorting_analyzer(key) - metrics = compute_quality_metrics(analyzer, metric_names=metric_names) - metrics.to_csv(filename, sep="\t", index=True) - - def get_metrics(self, key): - import pandas as pd - - dataset_key = self.cases[key]["dataset"] - - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if not filename.exists(): - return - metrics = pd.read_csv(filename, sep="\t", index_col=0) - dataset_key = self.cases[key]["dataset"] - recording, gt_sorting = self.datasets[dataset_key] - metrics.index = gt_sorting.unit_ids - return metrics - - def get_units_snr(self, key): - return self.get_metrics(key)["snr"] - - def get_performance_by_unit(self, case_keys=None): - import pandas as pd - - if case_keys is None: - case_keys = self.cases.keys() - - perf_by_unit = [] - for key in case_keys: - comp = self.comparisons.get(key, None) - assert comp is not None, "You need to do study.run_comparisons() first" - - perf = comp.get_performance(method="by_unit", output="pandas") - - if isinstance(key, str): - perf[self.levels] = key - elif isinstance(key, tuple): - for col, k in zip(self.levels, key): - perf[col] = k - - perf = perf.reset_index() - perf_by_unit.append(perf) - - perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(self.levels) - perf_by_unit = perf_by_unit.sort_index() - return perf_by_unit - - def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): - import pandas as pd - - if case_keys is None: - case_keys = list(self.cases.keys()) - - if isinstance(case_keys[0], str): - index = pd.Index(case_keys, name=self.levels) - else: - index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - - columns = ["num_gt", "num_sorter", "num_well_detected"] - comp = self.comparisons[case_keys[0]] - if comp.exhaustive_gt: - columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) - count_units = pd.DataFrame(index=index, columns=columns, dtype=int) - - for key in case_keys: - comp = self.comparisons.get(key, None) - assert comp is not None, "You need to do study.run_comparisons() first" - - gt_sorting = comp.sorting1 - sorting = 
comp.sorting2
-
-            count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
-            count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
-            count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)
-
-            if comp.exhaustive_gt:
-                count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
-                count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
-                count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
-                count_units.loc[key, "num_bad"] = comp.count_bad_units()
-
-        return count_units
+        raise RuntimeError(_txt_error_message)
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py
index 6d2d1cbb55..0316b3bab1 100644
--- a/src/spikeinterface/core/generate.py
+++ b/src/spikeinterface/core/generate.py
@@ -742,10 +742,10 @@ def synthesize_poisson_spike_vector(

     # Calculate the number of frames in the refractory period
     refractory_period_seconds = refractory_period_ms / 1000.0
-    refactory_period_frames = int(refractory_period_seconds * sampling_frequency)
+    refractory_period_frames = int(refractory_period_seconds * sampling_frequency)

-    is_refactory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
-    if is_refactory_period_too_long:
+    is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
+    if is_refractory_period_too_long:
         raise ValueError(
             f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}"
         )
@@ -764,9 +764,9 @@ def synthesize_poisson_spike_vector(
     binomial_p_modified = modified_firing_rate / sampling_frequency
     binomial_p_modified = np.minimum(binomial_p_modified, 1.0)

-    # Generate inter spike frames, add the refactory samples and accumulate for sorted spike frames
+    # Generate inter spike frames, add the refractory samples and accumulate for sorted spike frames
     inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max))
-    inter_spike_frames[:, 1:] += refactory_period_frames
+    inter_spike_frames[:, 1:] += refractory_period_frames
     spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames)
     spike_frames = spike_frames.ravel()
@@ -1054,6 +1054,176 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
     return spike_train


+from spikeinterface.core.basesorting import BaseSortingSegment, BaseSorting
+
+
+class SortingGenerator(BaseSorting):
+    def __init__(
+        self,
+        num_units: int = 20,
+        sampling_frequency: float = 30_000.0,  # in Hz
+        durations: List[float] = [10.325, 3.5],  # in s for 2 segments
+        firing_rates: float | np.ndarray = 3.0,
+        refractory_period_ms: float | np.ndarray = 4.0,  # in ms
+        seed: int = 0,
+    ):
+        """
+        A class to lazily generate synthetic sorting objects with Poisson spike trains.
+
+        We have two ways of representing spike trains in SpikeInterface:
+
+        - Spike vector (sample_index, unit_index)
+        - Dictionary of unit_id to spike times
+
+        This class simulates a sorting object that uses a representation based on unit IDs to lists of spike times,
+        rather than pre-computed spike vectors. It is intended for testing performance differences and functionalities
+        in data handling and analysis frameworks. For the normal use case of sorting objects with spike_vectors use the
+        `generate_sorting` function.
+
+        Parameters
+        ----------
+        num_units : int, optional
+            The number of distinct units (neurons) to simulate.
Default is 20. + sampling_frequency : float, optional + The sampling frequency of the spike data in Hz. Default is 30_000.0. + durations : list of float, optional + A list containing the duration in seconds for each segment of the sorting data. Default is [10.325, 3.5], + corresponding to 2 segments. + firing_rates : float or np.ndarray, optional + The firing rate(s) in Hz, which can be specified as a single value applicable to all units or as an array + with individual firing rates for each unit. Default is 3.0. + refractory_period_ms : float or np.ndarray, optional + The refractory period in milliseconds. Can be specified either as a single value for all units or as an + array with different values for each unit. Default is 4.0. + seed : int, default: 0 + The seed for the random number generator to ensure reproducibility. + + Raises + ------ + ValueError + If the refractory period is too long for the given firing rates, which could result in unrealistic + physiological conditions. + + Notes + ----- + This generator simulates the spike trains using a Poisson process. It takes into account the refractory periods + by adjusting the firing rates accordingly. See the notes on `synthesize_poisson_spike_vector` for more details. + + """ + + unit_ids = np.arange(num_units) + super().__init__(sampling_frequency, unit_ids) + + self.num_units = num_units + self.num_segments = len(durations) + self.firing_rates = firing_rates + self.durations = durations + self.refractory_period_seconds = refractory_period_ms / 1000.0 + + is_refractory_period_too_long = np.any(self.refractory_period_seconds >= 1.0 / firing_rates) + if is_refractory_period_too_long: + raise ValueError( + f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}" + ) + + seed = _ensure_seed(seed) + self.seed = seed + + for segment_index in range(self.num_segments): + segment_seed = self.seed + segment_index + segment = SortingGeneratorSegment( + num_units=num_units, + sampling_frequency=sampling_frequency, + duration=durations[segment_index], + firing_rates=firing_rates, + refractory_period_seconds=self.refractory_period_seconds, + seed=segment_seed, + t_start=None, + ) + self.add_sorting_segment(segment) + + self._kwargs = { + "num_units": num_units, + "sampling_frequency": sampling_frequency, + "durations": durations, + "firing_rates": firing_rates, + "refractory_period_ms": refractory_period_ms, + "seed": seed, + } + + +class SortingGeneratorSegment(BaseSortingSegment): + def __init__( + self, + num_units: int, + sampling_frequency: float, + duration: float, + firing_rates: float | np.ndarray, + refractory_period_seconds: float | np.ndarray, + seed: int, + t_start: Optional[float] = None, + ): + self.num_units = num_units + self.duration = duration + self.sampling_frequency = sampling_frequency + self.refractory_period_seconds = refractory_period_seconds + + if np.isscalar(firing_rates): + firing_rates = np.full(num_units, firing_rates, dtype="float64") + + self.firing_rates = firing_rates + + if np.isscalar(self.refractory_period_seconds): + self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") + + self.segment_seed = seed + self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} + self.num_samples = math.ceil(sampling_frequency * duration) + super().__init__(t_start) + + def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray: + unit_seed = 
self.units_seed[unit_id]
+        unit_index = self.parent_extractor.id_to_index(unit_id)
+
+        rng = np.random.default_rng(seed=unit_seed)
+
+        firing_rate = self.firing_rates[unit_index]
+        refractory_period = self.refractory_period_seconds[unit_index]
+
+        # p is the probability of a spike per tick of the sampling frequency
+        binomial_p = firing_rate / self.sampling_frequency
+        # We estimate how many spikes we will have in the duration
+        max_frames = int(self.duration * self.sampling_frequency) - 1
+        max_binomial_p = float(np.max(binomial_p))
+        num_spikes_expected = ceil(max_frames * max_binomial_p)
+        num_spikes_std = int(np.sqrt(num_spikes_expected * (1 - max_binomial_p)))
+        num_spikes_max = num_spikes_expected + 4 * num_spikes_std
+
+        # Increase the firing rate to take into account the refractory period
+        modified_firing_rate = firing_rate / (1 - firing_rate * refractory_period)
+        binomial_p_modified = modified_firing_rate / self.sampling_frequency
+        binomial_p_modified = np.minimum(binomial_p_modified, 1.0)
+
+        inter_spike_frames = rng.geometric(p=binomial_p_modified, size=num_spikes_max)
+        spike_frames = np.cumsum(inter_spike_frames)
+
+        refractory_period_frames = int(refractory_period * self.sampling_frequency)
+        spike_frames[1:] += refractory_period_frames
+
+        if start_frame is not None:
+            start_index = np.searchsorted(spike_frames, start_frame, side="left")
+        else:
+            start_index = 0
+
+        if end_frame is not None:
+            end_index = np.searchsorted(spike_frames[start_index:], end_frame, side="left")
+        else:
+            end_index = int(self.duration * self.sampling_frequency)
+
+        spike_frames = spike_frames[start_index:end_index]
+        return spike_frames
+
+
 ## Noise generator zone ##
 class NoiseGeneratorRecording(BaseRecording):
     """
@@ -1994,6 +2164,53 @@ def generate_unit_locations(
     distance_strict=False,
     seed=None,
 ):
+    """
+    Generate random 3D unit locations based on channel locations and distance constraints.
+
+    This function generates random 3D coordinates for a specified number of units,
+    ensuring the following:
+
+    1) the x, y and z coordinates of the units are within a specified range:
+        * x and y coordinates are within the minimum and maximum x and y coordinates of the channel_locations
+          plus `margin_um`.
+        * z coordinates are within a specified range `(minimum_z, maximum_z)`
+    2) the distance between any two units is greater than a specified minimum value
+
+    If the minimum distance constraint cannot be met within the allowed number of iterations,
+    the function can either raise an exception or issue a warning based on the `distance_strict` flag.
+
+    Parameters
+    ----------
+    num_units : int
+        Number of units to generate locations for.
+    channel_locations : numpy.ndarray
+        A 2D array of shape (num_channels, 2), where each row represents the (x, y) coordinates
+        of a channel.
+    margin_um : float, default: 20.0
+        The margin to add around the minimum and maximum x and y channel coordinates when
+        generating unit locations
+    minimum_z : float, default: 5.0
+        The minimum z-coordinate value for generated unit locations.
+    maximum_z : float, default: 40.0
+        The maximum z-coordinate value for generated unit locations.
+    minimum_distance : float, default: 20.0
+        The minimum allowable distance in micrometers between any two units
+    max_iteration : int, default: 100
+        The maximum number of iterations to attempt generating unit locations that meet
+        the minimum distance constraint.
+ distance_strict : bool, default: False + If True, the function will raise an exception if a solution meeting the distance + constraint cannot be found within the maximum number of iterations. If False, a warning + will be issued. + seed : int or None, optional + Random seed for reproducibility. If None, the seed is not set + + Returns + ------- + units_locations : numpy.ndarray + A 2D array of shape (num_units, 3), where each row represents the (x, y, z) coordinates + of a generated unit location. + """ rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index 23d60a5ac5..38f39c5481 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -97,8 +97,10 @@ def is_set_global_dataset_folder() -> bool: ######################################## +_default_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + global global_job_kwargs -global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) +global_job_kwargs = _default_job_kwargs.copy() global global_job_kwargs_set global_job_kwargs_set = False @@ -135,7 +137,7 @@ def reset_global_job_kwargs(): Reset the global job kwargs. """ global global_job_kwargs - global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + global_job_kwargs = _default_job_kwargs.copy() def is_set_global_job_kwargs_set() -> bool: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ceff8577d3..d90a20902d 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -1,22 +1,6 @@ """ -Pipeline on spikes/peaks/detected peaks - -Functions that can be chained: - * after peak detection - * already detected peaks - * spikes (labeled peaks) -to compute some additional features on-the-fly: - * peak localization - * peak-to-peak - * pca - * amplitude - * amplitude scaling - * ... 
-
-There are two ways for using theses "plugin nodes":
-  * during `peak_detect()`
-  * when peaks are already detected and reduced with `select_peaks()`
-  * on a sorting object
+
+
 """

 from __future__ import annotations
@@ -96,16 +80,26 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar


 class PeakSource(PipelineNode):
-    # base class for peak detector
+
     def get_trace_margin(self):
         raise NotImplementedError

     def get_dtype(self):
         return base_peak_dtype

+    def get_peak_slice(
+        self,
+        segment_index,
+        start_frame,
+        end_frame,
+    ):
+        # not needed for PeakDetector
+        raise NotImplementedError
+

+# this is used in sorting components
 class PeakDetector(PeakSource):
+    # base class for peak detector or template matching
     pass

@@ -127,11 +121,18 @@ def get_trace_margin(self):
     def get_dtype(self):
         return base_peak_dtype

-    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
-        # get local peaks
+    def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
         sl = self.segment_slices[segment_index]
         peaks_in_segment = self.peaks[sl]
         i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        return i0, i1
+
+    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
+        # get local peaks
+        sl = self.segment_slices[segment_index]
+        peaks_in_segment = self.peaks[sl]
+        # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        i0, i1 = peak_slice
         local_peaks = peaks_in_segment[i0:i1]

         # make sample index local to traces
@@ -212,8 +213,7 @@ def get_trace_margin(self):
     def get_dtype(self):
         return self._dtype

-    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
-        # get local peaks
+    def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
         sl = self.segment_slices[segment_index]
         peaks_in_segment = self.peaks[sl]
         if self.include_spikes_in_margin:
@@ -222,6 +222,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
             )
         else:
             i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        return i0, i1
+
+    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
+        # get local peaks
+        sl = self.segment_slices[segment_index]
+        peaks_in_segment = self.peaks[sl]
+        # if self.include_spikes_in_margin:
+        #     i0, i1 = np.searchsorted(
+        #         peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
+        #     )
+        # else:
+        #     i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        i0, i1 = peak_slice
+
         local_peaks = peaks_in_segment[i0:i1]

         # make sample index local to traces
@@ -467,16 +481,71 @@ def run_node_pipeline(
     nodes,
     job_kwargs,
     job_name="pipeline",
-    mp_context=None,
+    # mp_context=None,
     gather_mode="memory",
     gather_kwargs={},
     squeeze_output=True,
     folder=None,
     names=None,
     verbose=False,
+    skip_after_n_peaks=None,
 ):
     """
-    Common function to run pipeline with peak detector or already detected peak.
+    Machinery to run operations on peaks and traces in parallel.
+
+    This is useful in several use cases:
+      * in sortingcomponents : detect peaks and make some computations on them (localize, pca, ...)
+      * in sortingcomponents : replay some peaks and make some computations on them (localize, pca, ...)
+      * postprocessing : replay some spikes and make some computations on them (localize, pca, ...)
+
+    Here a "peak" is a spike without any label, just a detected event.
+ Here a "spike" is a spike with any a label so already sorted. + + The main idea is to have a graph of nodes. + Every node is doing a computaion of some peaks and related traces. + The first node is PeakSource so either a peak detector PeakDetector or peak/spike replay (PeakRetriever/SpikeRetriever) + + Every node can have one or several output that can be directed to other nodes (aka nodes have parents). + + Every node can optionally have a global output that will be gathered by the main process. + This is controlled by return_output = True. + + The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector. + These vectors can be in "memory" or in files ("npy") + + + Parameters + ---------- + + recording: Recording + + nodes: a list of PipelineNode + + job_kwargs: dict + The classical job_kwargs + job_name : str + The name of the pipeline used for the progress_bar + gather_mode : "memory" | "npz" + + gather_kwargs : dict + OPtions to control the "gather engine". See GatherToMemory or GatherToNpy. + squeeze_output : bool, default True + If only one output node then squeeze the tuple + folder : str | Path | None + Used for gather_mode="npz" + names : list of str + Names of outputs. + verbose : bool, default False + Verbosity. + skip_after_n_peaks : None | int + Skip the computation after n_peaks. + This is not an exact because internally this skip is done per worker in average. + + Returns + ------- + outputs: tuple of np.array | np.array + a tuple of vector for the output of nodes having return_output=True. + If squeeze_output=True and only one output then directly np.array. """ check_graph(nodes) @@ -484,6 +553,11 @@ def run_node_pipeline( job_kwargs = fix_job_kwargs(job_kwargs) assert all(isinstance(node, PipelineNode) for node in nodes) + if skip_after_n_peaks is not None: + skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"] + else: + skip_after_n_peaks_per_worker = None + if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": @@ -491,7 +565,7 @@ def run_node_pipeline( else: raise ValueError(f"wrong gather_mode : {gather_mode}") - init_args = (recording, nodes) + init_args = (recording, nodes, skip_after_n_peaks_per_worker) processor = ChunkRecordingExecutor( recording, @@ -510,12 +584,14 @@ def run_node_pipeline( return outs -def _init_peak_pipeline(recording, nodes): +def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker): # create a local dict per worker worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["nodes"] = nodes worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes) + worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker + worker_ctx["num_peaks"] = 0 return worker_ctx @@ -523,66 +599,88 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c recording = worker_ctx["recording"] max_margin = worker_ctx["max_margin"] nodes = worker_ctx["nodes"] + skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] recording_segment = recording._recording_segments[segment_index] - traces_chunk, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True - ) + node0 = nodes[0] - # compute the graph - pipeline_outputs = {} - for node in nodes: - node_parents = node.parents if node.parents else list() - node_input_args = tuple() - for parent in node_parents: - parent_output = 
-            parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
-            node_input_args += parent_outputs_tuple
-        if isinstance(node, PeakDetector):
-            # to handle compatibility peak detector is a special case
-            # with specific margin
-            # TODO later when in master: change this later
-            extra_margin = max_margin - node.get_trace_margin()
-            if extra_margin:
-                trace_detection = traces_chunk[extra_margin:-extra_margin]
+    if isinstance(node0, (SpikeRetriever, PeakRetriever)):
+        # in this case the PeakSource may have no peaks, so there is no need to load traces: just skip
+        peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
+        load_trace_and_compute = i0 < i1
+    else:
+        # PeakDetector always needs traces
+        load_trace_and_compute = True
+
+    if skip_after_n_peaks_per_worker is not None:
+        if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker:
+            load_trace_and_compute = False
+
+    if load_trace_and_compute:
+        traces_chunk, left_margin, right_margin = get_chunk_with_margin(
+            recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
+        )
+        # compute the graph
+        pipeline_outputs = {}
+        for node in nodes:
+            node_parents = node.parents if node.parents else list()
+            node_input_args = tuple()
+            for parent in node_parents:
+                parent_output = pipeline_outputs[parent]
+                parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
+                node_input_args += parent_outputs_tuple
+            if isinstance(node, PeakDetector):
+                # to handle compatibility peak detector is a special case
+                # with specific margin
+                # TODO later when in master: change this later
+                extra_margin = max_margin - node.get_trace_margin()
+                if extra_margin:
+                    trace_detection = traces_chunk[extra_margin:-extra_margin]
+                else:
+                    trace_detection = traces_chunk
+                node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
+                # set sample index to local
+                node_output[0]["sample_index"] += extra_margin
+            elif isinstance(node, PeakSource):
+                node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
             else:
-                trace_detection = traces_chunk
-            node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
-            # set sample index to local
-            node_output[0]["sample_index"] += extra_margin
-        elif isinstance(node, PeakSource):
-            node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
-        else:
-            # TODO later when in master: change the signature of all nodes (or maybe not!)
+                node_output = node.compute(traces_chunk, *node_input_args)
+            pipeline_outputs[node] = node_output
+
+            if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
+                worker_ctx["num_peaks"] += node_output[0].size
+
+        # propagate the output
+        pipeline_outputs_tuple = tuple()
+        for node in nodes:
+            # handle which buffers are given to the output
+            # this is controlled by node.return_output being a bool or tuple of bool
+            out = pipeline_outputs[node]
+            if isinstance(out, tuple):
+                if isinstance(node.return_output, bool) and node.return_output:
+                    pipeline_outputs_tuple += out
+                elif isinstance(node.return_output, tuple):
+                    for flag, e in zip(node.return_output, out):
+                        if flag:
+                            pipeline_outputs_tuple += (e,)
+            else:
+                if isinstance(node.return_output, bool) and node.return_output:
+                    pipeline_outputs_tuple += (out,)
+                elif isinstance(node.return_output, tuple):
+                    # this should not happen: maybe add a check somewhere before?
+                    pass
+
+        if isinstance(nodes[0], PeakDetector):
+            # the first out element is the peak vector
+            # we need to go back to absolute sample index
+            pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin

-    if isinstance(nodes[0], PeakDetector):
-        # the first out element is the peak vector
-        # we need to go back to absolut sample index
-        pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin
+        return pipeline_outputs_tuple

-    return pipeline_outputs_tuple
+    else:
+        # the gather will skip this output and not concatenate it
+        return


 class GatherToMemory:
@@ -595,6 +693,9 @@ def __init__(self):
         self.tuple_mode = None

     def __call__(self, res):
+        if res is None:
+            return
+
         if self.tuple_mode is None:
             # first loop only
             self.tuple_mode = isinstance(res, tuple)
@@ -655,6 +756,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
             self.final_shapes.append(None)

     def __call__(self, res):
+        if res is None:
+            return
+
         if self.tuple_mode is None:
             # first loop only
             self.tuple_mode = isinstance(res, tuple)
diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 5f33350820..213968a80b 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -197,17 +197,23 @@ def random_spikes_selection(
         cum_sizes = np.cumsum([0] + [s.size for s in spikes])

         # this fast when numba
-        spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids)
+        spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False)

         random_spikes_indices = []
         for unit_index, unit_id in enumerate(sorting.unit_ids):
             all_unit_indices = []
             for segment_index in range(sorting.get_num_segments()):
-                inds_in_seg = spike_indices[segment_index][unit_id] + cum_sizes[segment_index]
+                # this is a local index
+                inds_in_seg = spike_indices[segment_index][unit_id]
                 if margin_size is not None:
-                    inds_in_seg = inds_in_seg[inds_in_seg >= margin_size]
-                    inds_in_seg = inds_in_seg[inds_in_seg < (num_samples[segment_index] - margin_size)]
-                all_unit_indices.append(inds_in_seg)
+                    local_spikes = spikes[segment_index][inds_in_seg]
+                    mask = (local_spikes["sample_index"] >= margin_size) & (
+                        local_spikes["sample_index"] < (num_samples[segment_index] - margin_size)
+                    )
+                    inds_in_seg = inds_in_seg[mask]
+                # go back to absolute index
+                inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index]
+                all_unit_indices.append(inds_in_seg_abs)
             all_unit_indices = np.concatenate(all_unit_indices)
             selected_unit_indices = rng.choice(
                all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False
            )
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 4961db8524..55cbe6070a 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -11,6 +11,7 @@
 import shutil
 import warnings
 import importlib
+from copy import copy
 from packaging.version import parse
 from time import perf_counter

@@ -45,6 +46,7 @@ def create_sorting_analyzer(
     sparsity=None,
     return_scaled=True,
     overwrite=False,
+    backend_options=None,
     **sparsity_kwargs,
 ) -> "SortingAnalyzer":
     """
@@ -63,24 +65,29 @@ def create_sorting_analyzer(
     recording : Recording
         The recording object
     folder : str or Path or None, default: None
-        The folder where waveforms are cached
+        The folder where the analyzer is cached
     format : "memory | "binary_folder" | "zarr", default: "memory"
-        The mode to store waveforms. If "folder", waveforms are stored on disk in the specified folder.
+        The mode to store the analyzer. If "folder", the analyzer is stored on disk in the specified folder.
         The "folder" argument must be specified in case of mode "folder".
-        If "memory" is used, the waveforms are stored in RAM. Use this option carefully!
+        If "memory" is used, the analyzer is stored in RAM. Use this option carefully!
     sparse : bool, default: True
         If True, then a sparsity mask is computed using the `estimate_sparsity()` function using
         a few spikes to get an estimate of dense templates to create a ChannelSparsity object.
         Then, the sparsity will be propagated to all ResultExtention that handle sparsity (like wavforms, pca, ...)
         You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs)
     sparsity : ChannelSparsity or None, default: None
-        The sparsity used to compute waveforms. If this is given, `sparse` is ignored.
+        The sparsity used to compute extensions. If this is given, `sparse` is ignored.
     return_scaled : bool, default: True
         All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates".
         This prevent return_scaled being differents from different extensions and having wrong snr for instance.
     overwrite: bool, default: False
         If True, overwrite the folder if it already exists.
-
+    backend_options : dict | None, default: None
+        Keyword arguments for the backend specified by format. It can contain the following keys:
+        - storage_options: dict | None (fsspec storage options)
+        - saving_options: dict | None (additional saving options for creating and saving datasets,
+          e.g. compression/filters for zarr)
+    sparsity_kwargs : keyword arguments
+        Extra keyword arguments propagated to `estimate_sparsity()` (including job_kwargs).

     Returns
     -------
@@ -91,7 +98,7 @@ def create_sorting_analyzer(
     --------
     >>> import spikeinterface as si

-    >>> # Extract dense waveforms and save to disk with binary_folder format.
+    >>> # Create a dense analyzer and save to disk with binary_folder format.
     >>> sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="binary_folder", folder="/path/to_my/result")

     >>> # Can be reload
@@ -117,12 +124,14 @@ def create_sorting_analyzer(
     """
     if format != "memory":
         if format == "zarr":
-            folder = clean_zarr_folder_name(folder)
-        if Path(folder).is_dir():
-            if not overwrite:
-                raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.")
-            else:
-                shutil.rmtree(folder)
+            if not is_path_remote(folder):
+                folder = clean_zarr_folder_name(folder)
+        if not is_path_remote(folder):
+            if Path(folder).is_dir():
+                if not overwrite:
+                    raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.")
+                else:
+                    shutil.rmtree(folder)

     # handle sparsity
     if sparsity is not None:
@@ -144,27 +153,38 @@ def create_sorting_analyzer(
         return_scaled = False

     sorting_analyzer = SortingAnalyzer.create(
-        sorting, recording, format=format, folder=folder, sparsity=sparsity, return_scaled=return_scaled
+        sorting,
+        recording,
+        format=format,
+        folder=folder,
+        sparsity=sparsity,
+        return_scaled=return_scaled,
+        backend_options=backend_options,
     )

     return sorting_analyzer


-def load_sorting_analyzer(folder, load_extensions=True, format="auto", storage_options=None) -> "SortingAnalyzer":
+def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer":
     """
     Load a SortingAnalyzer object from disk.

     Parameters
     ----------
     folder : str or Path
-        The folder / zarr folder where the waveform extractor is stored
+        The folder / zarr folder where the analyzer is stored. If the folder is a remote path stored in the cloud,
+        the backend_options can be used to specify credentials. If the remote path is not accessible,
+        and backend_options is not provided, the function will try to load the object in anonymous mode (anon=True),
+        which makes it possible to load data from open buckets.
     load_extensions : bool, default: True
         Load all extensions or not.
     format : "auto" | "binary_folder" | "zarr"
         The format of the folder.
-    storage_options : dict | None, default: None
-        The storage options to specify credentials to remote zarr bucket.
-        For open buckets, it doesn't need to be specified.
+    backend_options : dict | None, default: None
+        Options for the storage backend.
+        The dictionary can contain the following keys:
+        - storage_options: dict | None (fsspec storage options)
+        - saving_options: dict | None (additional saving options for creating and saving datasets)

     Returns
     -------
@@ -172,7 +192,20 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", storage_o
         The loaded SortingAnalyzer
     """

-    return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, storage_options=storage_options)
+    if is_path_remote(folder) and backend_options is None:
+        try:
+            return SortingAnalyzer.load(
+                folder, load_extensions=load_extensions, format=format, backend_options=backend_options
+            )
+        except Exception as e:
+            backend_options = dict(storage_options=dict(anon=True))
+            return SortingAnalyzer.load(
+                folder, load_extensions=load_extensions, format=format, backend_options=backend_options
+            )
+    else:
+        return SortingAnalyzer.load(
+            folder, load_extensions=load_extensions, format=format, backend_options=backend_options
+        )


 class SortingAnalyzer:
@@ -205,7 +238,7 @@ def __init__(
         format=None,
         sparsity=None,
         return_scaled=True,
-        storage_options=None,
+        backend_options=None,
     ):
         # very fast init because checks are done in load and create
         self.sorting = sorting
@@ -215,10 +248,18 @@ def __init__(
         self.format = format
         self.sparsity = sparsity
         self.return_scaled = return_scaled
-        self.storage_options = storage_options
+
         # this is used to store temporary recording
         self._temporary_recording = None

+        # backend-specific kwargs for different formats, which can be used to
+        # set some parameters for saving (e.g., compression)
+        #
+        # - storage_options: dict | None (fsspec storage options)
+        # - saving_options: dict | None
+        #   (additional saving options for creating and saving datasets, e.g.
compression/filters for zarr) + self._backend_options = {} if backend_options is None else backend_options + # extensions are not loaded at init self.extensions = dict() @@ -228,13 +269,18 @@ def __repr__(self) -> str: nchan = self.get_num_channels() nunits = self.get_num_units() txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}" + if self.format != "memory": + if is_path_remote(str(self.folder)): + txt += f" (remote)" if self.is_sparse(): txt += " - sparse" if self.has_recording(): txt += " - has recording" if self.has_temporary_recording(): txt += " - has temporary recording" - ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys()) + ext_txt = f"Loaded {len(self.extensions)} extensions" + if len(self.extensions) > 0: + ext_txt += f": {', '.join(self.extensions.keys())}" txt += "\n" + ext_txt return txt @@ -253,7 +299,9 @@ def create( folder=None, sparsity=None, return_scaled=True, + backend_options=None, ): + assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" # some checks if sorting.sampling_frequency != recording.sampling_frequency: if math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5): @@ -277,22 +325,35 @@ def create( if format == "memory": sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_scaled, rec_attributes=None) elif format == "binary_folder": - cls.create_binary_folder(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) - sorting_analyzer = cls.load_from_binary_folder(folder, recording=recording) - sorting_analyzer.folder = Path(folder) + sorting_analyzer = cls.create_binary_folder( + folder, + sorting, + recording, + sparsity, + return_scaled, + rec_attributes=None, + backend_options=backend_options, + ) elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - folder = clean_zarr_folder_name(folder) - cls.create_zarr(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) - sorting_analyzer = cls.load_from_zarr(folder, recording=recording) - sorting_analyzer.folder = Path(folder) + if not is_path_remote(folder): + folder = clean_zarr_folder_name(folder) + sorting_analyzer = cls.create_zarr( + folder, + sorting, + recording, + sparsity, + return_scaled, + rec_attributes=None, + backend_options=backend_options, + ) else: raise ValueError("SortingAnalyzer.create: wrong format") return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None): + def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None): """ Load folder or zarr. The recording can be given if the recording location has changed. 
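
For reference, a minimal usage sketch of the backend_options plumbing introduced above, from the public API side (the S3 URL below is hypothetical; storage_options are forwarded to fsspec by the zarr backend):

    import spikeinterface as si

    # Pass fsspec credentials explicitly through backend_options.
    # If backend_options is omitted and the remote path is not accessible,
    # load_sorting_analyzer retries in anonymous mode (anon=True), which
    # works for open buckets.
    analyzer = si.load_sorting_analyzer(
        "s3://my-bucket/analyzer.zarr",
        backend_options=dict(storage_options=dict(anon=True)),
    )
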
@@ -306,18 +367,15 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", stora format = "binary_folder" if format == "binary_folder": - sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_binary_folder( + folder, recording=recording, backend_options=backend_options + ) elif format == "zarr": sorting_analyzer = SortingAnalyzer.load_from_zarr( - folder, recording=recording, storage_options=storage_options + folder, recording=recording, backend_options=backend_options ) - if is_path_remote(str(folder)): - sorting_analyzer.folder = folder - # in this case we only load extensions when needed - else: - sorting_analyzer.folder = Path(folder) - + if not is_path_remote(str(folder)): if load_extensions: sorting_analyzer.load_all_saved_extension() @@ -349,11 +407,9 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): + def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, backend_options): # used by create and save_as - assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" - folder = Path(folder) if folder.is_dir(): raise ValueError(f"Folder already exists {folder}") @@ -369,26 +425,34 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scale json.dump(check_json(info), f, indent=4) # save a copy of the sorting - # NumpyFolderSorting.write_sorting(sorting, folder / "sorting") sorting.save(folder=folder / "sorting") - # save recording and sorting provenance - if recording.check_serializability("json"): - recording.dump(folder / "recording.json", relative_to=folder) - elif recording.check_serializability("pickle"): - recording.dump(folder / "recording.pickle", relative_to=folder) + if recording is not None: + # save recording and sorting provenance + if recording.check_serializability("json"): + recording.dump(folder / "recording.json", relative_to=folder) + elif recording.check_serializability("pickle"): + recording.dump(folder / "recording.pickle", relative_to=folder) + else: + warnings.warn("The Recording is not serializable! The recording link will be lost for future load") + else: + assert rec_attributes is not None, "recording or rec_attributes must be provided" + warnings.warn("Recording not provided, instntiating SortingAnalyzer in recordingless mode.") if sorting.check_serializability("json"): sorting.dump(folder / "sorting_provenance.json", relative_to=folder) elif sorting.check_serializability("pickle"): sorting.dump(folder / "sorting_provenance.pickle", relative_to=folder) + else: + warnings.warn( + "The sorting provenance is not serializable! 
The sorting provenance link will be lost for future load" + ) # dump recording attributes probegroup = None rec_attributes_file = folder / "recording_info" / "recording_attributes.json" rec_attributes_file.parent.mkdir() if rec_attributes is None: - assert recording is not None rec_attributes = get_rec_attributes(recording) rec_attributes_file.write_text(json.dumps(check_json(rec_attributes), indent=4), encoding="utf8") probegroup = recording.get_probegroup() @@ -411,8 +475,10 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scale with open(settings_file, mode="w") as f: json.dump(check_json(settings), f, indent=4) + return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) + @classmethod - def load_from_binary_folder(cls, folder, recording=None): + def load_from_binary_folder(cls, folder, recording=None, backend_options=None): folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -483,34 +549,50 @@ def load_from_binary_folder(cls, folder, recording=None): format="binary_folder", sparsity=sparsity, return_scaled=return_scaled, + backend_options=backend_options, ) + sorting_analyzer.folder = folder return sorting_analyzer def _get_zarr_root(self, mode="r+"): import zarr - if is_path_remote(str(self.folder)): - mode = "r" + assert mode in ("r+", "a", "r"), "mode must be 'r+', 'a' or 'r'" + + storage_options = self._backend_options.get("storage_options", {}) # we open_consolidated only if we are in read mode if mode in ("r+", "a"): - zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options) + try: + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=storage_options) + except Exception as e: + # this could happen in remote mode, and it's a way to check if the folder is still there + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=storage_options) else: - zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options) + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=storage_options) return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): + def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, backend_options): # used by create and save_as import zarr import numcodecs + from .zarrextractors import add_sorting_to_zarr_group - folder = clean_zarr_folder_name(folder) + if is_path_remote(folder): + remote = True + else: + remote = False + if not remote: + folder = clean_zarr_folder_name(folder) + if folder.is_dir(): + raise ValueError(f"Folder already exists {folder}") - if folder.is_dir(): - raise ValueError(f"Folder already exists {folder}") + backend_options = {} if backend_options is None else backend_options + storage_options = backend_options.get("storage_options", {}) + saving_options = backend_options.get("saving_options", {}) - zarr_root = zarr.open(folder, mode="w") + zarr_root = zarr.open(folder, mode="w", storage_options=storage_options) info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) @@ -519,37 +601,39 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at zarr_root.attrs["settings"] = check_json(settings) # the recording - rec_dict = recording.to_dict(relative_to=folder, recursive=True) - - if 
recording.check_serializability("json"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) - zarr_rec = np.array([check_json(rec_dict)], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) - elif recording.check_serializability("pickle"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) - zarr_rec = np.array([rec_dict], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) + relative_to = folder if not remote else None + if recording is not None: + rec_dict = recording.to_dict(relative_to=relative_to, recursive=True) + if recording.check_serializability("json"): + # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) + zarr_rec = np.array([check_json(rec_dict)], dtype=object) + zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) + elif recording.check_serializability("pickle"): + # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) + zarr_rec = np.array([rec_dict], dtype=object) + zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) + else: + warnings.warn("The Recording is not serializable! The recording link will be lost for future load") else: - warnings.warn( - "SortingAnalyzer with zarr : the Recording is not json serializable, the recording link will be lost for future load" - ) + assert rec_attributes is not None, "recording or rec_attributes must be provided" + warnings.warn("Recording not provided, instntiating SortingAnalyzer in recordingless mode.") # sorting provenance - sort_dict = sorting.to_dict(relative_to=folder, recursive=True) + sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True) if sorting.check_serializability("json"): zarr_sort = np.array([check_json(sort_dict)], dtype=object) zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON()) elif sorting.check_serializability("pickle"): zarr_sort = np.array([sort_dict], dtype=object) zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) - - # else: - # warnings.warn("SortingAnalyzer with zarr : the sorting provenance is not json serializable, the sorting provenance link will be lost for futur load") + else: + warnings.warn( + "The sorting provenance is not serializable! The sorting provenance link will be lost for future load" + ) recording_info = zarr_root.create_group("recording_info") if rec_attributes is None: - assert recording is not None rec_attributes = get_rec_attributes(recording) probegroup = recording.get_probegroup() else: @@ -562,24 +646,23 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) if sparsity is not None: - zarr_root.create_dataset("sparsity_mask", data=sparsity.mask) + zarr_root.create_dataset("sparsity_mask", data=sparsity.mask, **saving_options) - # write sorting copy - from .zarrextractors import add_sorting_to_zarr_group - - # Alessio : we need to find a way to propagate compressor for all steps. - # kwargs = dict(compressor=...) 
- zarr_kwargs = dict() - add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **zarr_kwargs) + add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **saving_options) recording_info = zarr_root.create_group("extensions") zarr.consolidate_metadata(zarr_root.store) + return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) + @classmethod - def load_from_zarr(cls, folder, recording=None, storage_options=None): + def load_from_zarr(cls, folder, recording=None, backend_options=None): import zarr + backend_options = {} if backend_options is None else backend_options + storage_options = backend_options.get("storage_options", {}) + zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) si_info = zarr_root.attrs["spikeinterface_info"] @@ -605,11 +688,13 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): # load recording if possible if recording is None: - rec_dict = zarr_root["recording"][0] - try: - recording = load_extractor(rec_dict, base_folder=folder) - except: - recording = None + rec_field = zarr_root.get("recording") + if rec_field is not None: + rec_dict = rec_field[0] + try: + recording = load_extractor(rec_dict, base_folder=folder) + except: + recording = None else: # TODO maybe maybe not??? : do we need to check attributes match internal rec_attributes # Note this will make the loading too slow @@ -640,8 +725,9 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, - storage_options=storage_options, + backend_options=backend_options, ) + sorting_analyzer.folder = folder return sorting_analyzer @@ -683,6 +769,7 @@ def _save_or_select_or_merge( sparsity_overlap=0.75, verbose=False, new_unit_ids=None, + backend_options=None, **job_kwargs, ) -> "SortingAnalyzer": """ @@ -712,8 +799,13 @@ def _save_or_select_or_merge( The new unit ids for merged units. Required if `merge_unit_groups` is not None. verbose : bool, default: False If True, output is verbose. - job_kwargs : dict - Keyword arguments for parallelization. + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. compression/filters for zarr) + job_kwargs : keyword arguments + Keyword arguments for the job parallelization. Returns ------- @@ -787,6 +879,8 @@ def _save_or_select_or_merge( # TODO: sam/pierre would create a curation field / curation.json with the applied merges. # What do you think? 
+ backend_options = {} if backend_options is None else backend_options + if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( @@ -797,20 +891,28 @@ def _save_or_select_or_merge( # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" folder = Path(folder) - SortingAnalyzer.create_binary_folder( - folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes + new_sorting_analyzer = SortingAnalyzer.create_binary_folder( + folder, + sorting_provenance, + recording, + sparsity, + self.return_scaled, + self.rec_attributes, + backend_options=backend_options, ) - new_sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) - new_sorting_analyzer.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" folder = clean_zarr_folder_name(folder) - SortingAnalyzer.create_zarr( - folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes + new_sorting_analyzer = SortingAnalyzer.create_zarr( + folder, + sorting_provenance, + recording, + sparsity, + self.return_scaled, + self.rec_attributes, + backend_options=backend_options, ) - new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) - new_sorting_analyzer.folder = folder else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") @@ -848,7 +950,7 @@ def _save_or_select_or_merge( return new_sorting_analyzer - def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": + def save_as(self, format="memory", folder=None, backend_options=None) -> "SortingAnalyzer": """ Save SortingAnalyzer object into another format. Uselful for memory to zarr or memory to binary. @@ -863,10 +965,15 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": The output folder if `format` is "zarr" or "binary_folder" format : "memory" | "binary_folder" | "zarr", default: "memory" The new backend format to use + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. 
compression/filters for zarr) """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder) + return self._save_or_select_or_merge(format=format, folder=folder, backend_options=backend_options) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -1029,7 +1136,15 @@ def copy(self): def is_read_only(self) -> bool: if self.format == "memory": return False - return not os.access(self.folder, os.W_OK) + elif self.format == "binary_folder": + return not os.access(self.folder, os.W_OK) + else: + if not is_path_remote(str(self.folder)): + return not os.access(self.folder, os.W_OK) + else: + # in this case we don't know if the file is read only so an error + # will be raised if we try to save/append + return False ## map attribute and property zone @@ -1965,7 +2080,8 @@ def load_data(self): continue ext_data_name = ext_data_file.stem if ext_data_file.suffix == ".json": - ext_data = json.load(ext_data_file.open("r")) + with ext_data_file.open("r") as f: + ext_data = json.load(f) elif ext_data_file.suffix == ".npy": # The lazy loading of an extension is complicated because if we compute again # and have a link to the old buffer on windows then it fails @@ -1977,7 +2093,8 @@ def load_data(self): ext_data = pd.read_csv(ext_data_file, index_col=0) elif ext_data_file.suffix == ".pkl": - ext_data = pickle.load(ext_data_file.open("rb")) + with ext_data_file.open("rb") as f: + ext_data = pickle.load(f) else: continue self.data[ext_data_name] = ext_data @@ -2015,7 +2132,7 @@ def copy(self, new_sorting_analyzer, unit_ids=None): new_extension.data = self.data else: new_extension.data = self._select_extension_data(unit_ids) - new_extension.run_info = self.run_info.copy() + new_extension.run_info = copy(self.run_info) new_extension.save() return new_extension @@ -2033,7 +2150,7 @@ def merge( new_extension.data = self._merge_extension_data( merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs ) - new_extension.run_info = self.run_info.copy() + new_extension.run_info = copy(self.run_info) new_extension.save() return new_extension @@ -2051,24 +2168,24 @@ def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): self._save_run_info() - self._save_data(**kwargs) + self._save_data() if self.format == "zarr": import zarr zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) - def save(self, **kwargs): + def save(self): self._save_params() self._save_importing_provenance() self._save_run_info() - self._save_data(**kwargs) + self._save_data() if self.format == "zarr": import zarr zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) - def _save_data(self, **kwargs): + def _save_data(self): if self.format == "memory": return @@ -2107,14 +2224,14 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - import zarr import numcodecs + saving_options = self.sorting_analyzer._backend_options.get("saving_options", {}) extension_group = self._get_zarr_extension_group(mode="r+") - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() + # if compression is not externally given, we use the default + if "compressor" not in saving_options: + saving_options["compressor"] = get_default_zarr_compressor() for ext_data_name, ext_data in self.data.items(): if ext_data_name in 
extension_group: @@ -2124,13 +2241,19 @@ def _save_data(self, **kwargs): name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() ) elif isinstance(ext_data, np.ndarray): - extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): df_group = extension_group.create_group(ext_data_name) # first we save the index - df_group.create_dataset(name="index", data=ext_data.index.to_numpy()) + indices = ext_data.index.to_numpy() + if indices.dtype.kind == "O": + indices = indices.astype(str) + df_group.create_dataset(name="index", data=indices) for col in ext_data.columns: - df_group.create_dataset(name=col, data=ext_data[col].to_numpy()) + col_data = ext_data[col].to_numpy() + if col_data.dtype.kind == "O": + col_data = col_data.astype(str) + df_group.create_dataset(name=col, data=col_data) df_group.attrs["dataframe"] = True else: # any object @@ -2187,7 +2310,7 @@ def delete(self): def reset(self): """ - Reset the waveform extension. + Reset the extension. Delete the sub folder and create a new empty one. """ self._reset_extension_folder() @@ -2202,7 +2325,8 @@ def set_params(self, save=True, **params): """ # this ensure data is also deleted and corresponds to params # this also ensure the group is created - self._reset_extension_folder() + if save: + self._reset_extension_folder() params = self._set_params(**params) self.params = params @@ -2251,15 +2375,16 @@ def _save_importing_provenance(self): extension_group.attrs["info"] = info def _save_run_info(self): - run_info = self.run_info.copy() - - if self.format == "binary_folder": - extension_folder = self._get_binary_extension_folder() - run_info_file = extension_folder / "run_info.json" - run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8") - elif self.format == "zarr": - extension_group = self._get_zarr_extension_group(mode="r+") - extension_group.attrs["run_info"] = run_info + if self.run_info is not None: + run_info = self.run_info.copy() + + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + run_info_file = extension_folder / "run_info.json" + run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8") + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode="r+") + extension_group.attrs["run_info"] = run_info def get_pipeline_nodes(self): assert ( diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 844abf2ccf..cb7debf3e0 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -10,6 +10,7 @@ generate_recording, generate_sorting, NoiseGeneratorRecording, + SortingGenerator, TransformSorting, generate_recording_by_size, InjectTemplatesRecording, @@ -94,6 +95,73 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float: return memory +def test_memory_sorting_generator(): + # Test that get_traces does not consume more memory than allocated. 
+
+    bytes_to_MiB_factor = 1024**2
+    relative_tolerance = 0.05  # relative tolerance of 5 percent
+
+    sampling_frequency = 30000  # Hz
+    durations = [60.0]
+    num_units = 1000
+    seed = 0
+
+    before_instantiation_MiB = measure_memory_allocation() / bytes_to_MiB_factor
+    sorting = SortingGenerator(
+        num_units=num_units,
+        sampling_frequency=sampling_frequency,
+        durations=durations,
+        seed=seed,
+    )
+    after_instantiation_MiB = measure_memory_allocation() / bytes_to_MiB_factor
+    memory_usage_MiB = after_instantiation_MiB - before_instantiation_MiB
+    ratio = memory_usage_MiB / before_instantiation_MiB
+    expected_allocation_MiB = 0
+    assert (
+        ratio <= 1.0 + relative_tolerance
+    ), f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}"
+
+
+def test_sorting_generator_consistency_across_calls():
+    sampling_frequency = 30000  # Hz
+    durations = [1.0]
+    num_units = 3
+    seed = 0
+
+    sorting = SortingGenerator(
+        num_units=num_units,
+        sampling_frequency=sampling_frequency,
+        durations=durations,
+        seed=seed,
+    )
+
+    for unit_id in sorting.get_unit_ids():
+        spike_train = sorting.get_unit_spike_train(unit_id=unit_id)
+        spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id)
+
+        assert np.allclose(spike_train, spike_train_again)
+
+
+def test_sorting_generator_consistency_within_trains():
+    sampling_frequency = 30000  # Hz
+    durations = [1.0]
+    num_units = 3
+    seed = 0
+
+    sorting = SortingGenerator(
+        num_units=num_units,
+        sampling_frequency=sampling_frequency,
+        durations=durations,
+        seed=seed,
+    )
+
+    for unit_id in sorting.get_unit_ids():
+        spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000)
+        spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000)
+
+        assert np.allclose(spike_train, spike_train_again)
+
+
 def test_noise_generator_memory():
     # Test that get_traces does not consume more memory than allocated.
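
For reference, a minimal sketch of the lazy behaviour these tests pin down (assuming SortingGenerator is exposed from spikeinterface.core.generate, as the import block above suggests):

    from spikeinterface.core.generate import SortingGenerator

    # Spike trains are derived on the fly from the seed, so instantiation
    # allocates almost nothing and repeated calls return identical trains.
    sorting = SortingGenerator(num_units=5, sampling_frequency=30000.0, durations=[2.0], seed=0)
    unit_id = sorting.get_unit_ids()[0]
    assert (sorting.get_unit_spike_train(unit_id=unit_id) == sorting.get_unit_spike_train(unit_id=unit_id)).all()
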
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 8d788acbad..deef2291c6 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -83,8 +83,12 @@ def test_run_node_pipeline(cache_folder_creation): extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + # print(peaks.size) peak_retriever = PeakRetriever(recording, peaks) + # this test when no spikes in last chunks + peak_retriever_few = PeakRetriever(recording, peaks[: peaks.size // 2]) + # channel index is from template spike_retriever_T = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds @@ -100,7 +104,7 @@ def test_run_node_pipeline(cache_folder_creation): ) # test with 3 differents first nodes - for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): + for loop, peak_source in enumerate((peak_retriever, peak_retriever_few, spike_retriever_T, spike_retriever_S)): # one step only : squeeze output nodes = [ peak_source, @@ -139,10 +143,12 @@ def test_run_node_pipeline(cache_folder_creation): num_peaks = peaks.shape[0] num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels # gather npy mode @@ -185,5 +191,38 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) +def test_skip_after_n_peaks(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) + + # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) + + spikes = sorting.to_spike_vector() + + # create peaks from spikes + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + + peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + # print(peaks.size) + + node0 = PeakRetriever(recording, peaks) + node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) + nodes = [node0, node1] + + skip_after_n_peaks = 30 + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks + ) + + assert some_amplitudes.size >= skip_after_n_peaks + assert some_amplitudes.size < spikes.size + + +# the following is for testing locally with python or ipython. It is not used in ci or with pytest. 
if __name__ == "__main__": - test_run_node_pipeline() + # folder = Path("./cache_folder/core") + # test_run_node_pipeline(folder) + + test_skip_after_n_peaks() diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 34bb3a221d..7d26773ac3 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -162,8 +162,8 @@ def test_generate_unit_ids_for_merge_group(): if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() - # test_random_spikes_selection() + test_random_spikes_selection() - test_apply_merges_to_sorting() - test_get_ids_after_merging() + # test_apply_merges_to_sorting() + # test_get_ids_after_merging() # test_generate_unit_ids_for_merge_group() diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 5c7e267cc6..35ab18b5f2 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -10,6 +10,7 @@ load_sorting_analyzer, get_available_analyzer_extensions, get_default_analyzer_extension_params, + get_default_zarr_compressor, ) from spikeinterface.core.sortinganalyzer import ( register_result_extension, @@ -99,16 +100,25 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" - if folder.exists(): - shutil.rmtree(folder) + default_compressor = get_default_zarr_compressor() sorting_analyzer = create_sorting_analyzer( - sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None, overwrite=True ) sorting_analyzer.compute(["random_spikes", "templates"]) sorting_analyzer = load_sorting_analyzer(folder, format="auto") _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + # check that compression is applied + assert ( + sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressor.codec_id + == default_compressor.codec_id + ) + assert ( + sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id + == default_compressor.codec_id + ) + # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) @@ -117,11 +127,45 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1 assert 1 not in remove_units_sorting_analyer.unit_ids - folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" - if folder.exists(): - shutil.rmtree(folder) - sorting_analyzer = create_sorting_analyzer( - sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None, return_scaled=False + # test no compression + sorting_analyzer_no_compression = create_sorting_analyzer( + sorting, + recording, + format="zarr", + folder=folder, + sparse=False, + sparsity=None, + return_scaled=False, + overwrite=True, + backend_options={"saving_options": {"compressor": None}}, + ) + print(sorting_analyzer_no_compression._backend_options) + sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) + assert ( + 
sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressor + is None + ) + assert sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressor is None + + # test a different compressor + from numcodecs import LZMA + + lzma_compressor = LZMA() + folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" + sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( + format="zarr", folder=folder, backend_options={"saving_options": {"compressor": lzma_compressor}} + ) + assert ( + sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressor.codec_id + == LZMA.codec_id + ) + assert ( + sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id + == LZMA.codec_id ) @@ -326,7 +370,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): else: folder = None sorting_analyzer5 = sorting_analyzer.merge_units( - merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, mode="hard" + merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, merging_mode="hard" ) # test compute with extension-specific params diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 17f1ac08b3..ff552dfb54 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -12,6 +12,19 @@ from .core_tools import define_function_from_class, check_json from .job_tools import split_job_kwargs from .recording_tools import determine_cast_unsigned +from .core_tools import is_path_remote + + +def anononymous_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: dict | None = None): + if is_path_remote(str(folder_path)) and storage_options is None: + try: + root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + except Exception as e: + storage_options = {"anon": True} + root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + else: + root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + return root class ZarrRecordingExtractor(BaseRecording): @@ -21,7 +34,11 @@ class ZarrRecordingExtractor(BaseRecording): Parameters ---------- folder_path : str or Path - Path to the zarr root folder + Path to the zarr root folder. This can be a local path or a remote path (s3:// or gcs://). + If the path is a remote path, the storage_options can be provided to specify credentials. + If the remote path is not accessible and backend_options is not provided, + the function will try to load the object in anonymous mode (anon=True), + which enables to load data from open buckets. storage_options : dict or None Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. 
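
For reference, a minimal usage sketch of this anonymous fallback (the bucket URL below is hypothetical):

    from spikeinterface.core.zarrextractors import ZarrRecordingExtractor

    # With storage_options=None and an unreachable remote path, the open is
    # retried with anon=True, which succeeds for publicly readable buckets.
    recording = ZarrRecordingExtractor("s3://my-open-bucket/recording.zarr")
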
@@ -35,7 +52,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) - self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + self._root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options) sampling_frequency = self._root.attrs.get("sampling_frequency", None) num_segments = self._root.attrs.get("num_segments", None) @@ -81,7 +98,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) nbytes_segment = self._root[trace_name].nbytes nbytes_stored_segment = self._root[trace_name].nbytes_stored - cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment + if nbytes_stored_segment > 0: + cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment + else: + cr_by_segment[segment_index] = np.nan total_nbytes += nbytes_segment total_nbytes_stored += nbytes_stored_segment @@ -105,7 +125,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) if annotations is not None: self.annotate(**annotations) # annotate compression ratios - cr = total_nbytes / total_nbytes_stored + if total_nbytes_stored > 0: + cr = total_nbytes / total_nbytes_stored + else: + cr = np.nan self.annotate(compression_ratio=cr, compression_ratio_segments=cr_by_segment) self._kwargs = {"folder_path": folder_path_kwarg, "storage_options": storage_options} @@ -150,7 +173,11 @@ class ZarrSortingExtractor(BaseSorting): Parameters ---------- folder_path : str or Path - Path to the zarr root file + Path to the zarr root file. This can be a local path or a remote path (s3:// or gcs://). + If the path is a remote path, the storage_options can be provided to specify credentials. + If the remote path is not accessible and backend_options is not provided, + the function will try to load the object in anonymous mode (anon=True), + which enables to load data from open buckets. storage_options : dict or None Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. zarr_group : str or None, default: None @@ -165,7 +192,8 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) - zarr_root = self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + zarr_root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options) + if zarr_group is None: self._root = zarr_root else: @@ -243,7 +271,7 @@ def read_zarr( """ # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is. # for the futur SortingAnalyzer we will have this 2 fields!!! 
- root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options) if "channel_ids" in root.keys(): return read_zarr_recording(folder_path, storage_options=storage_options) elif "unit_ids" in root.keys(): @@ -329,8 +357,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G zarr_group.attrs["num_segments"] = int(num_segments) zarr_group.create_dataset(name="unit_ids", data=sorting.unit_ids, compressor=None) - if "compressor" not in kwargs: - compressor = get_default_zarr_compressor() + compressor = kwargs.get("compressor", get_default_zarr_compressor()) # save sub fields spikes_group = zarr_group.create_group(name="spikes") diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 5dd549347d..317ea21cce 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -105,6 +105,8 @@ def get_stream_names(eid: str, cache_folder: Optional[Union[Path, str]] = None, An instance of the ONE API to use for data loading. If not provided, a default instance is created using the default parameters. If you need to use a specific instance, you can create it using the ONE API and pass it here. + stream_type : "ap" | "lf" | None, default: None + The stream type to load, required when pid is provided and stream_name is not. Returns ------- @@ -140,6 +142,7 @@ def __init__( remove_cached: bool = True, stream: bool = True, one: "one.api.OneAlyx" = None, + stream_type: str | None = None, ): try: from brainbox.io.one import SpikeSortingLoader @@ -154,20 +157,24 @@ def __init__( one = IblRecordingExtractor._get_default_one(cache_folder=cache_folder) if pid is not None: + assert stream_type is not None, "When providing a PID, you must also provide a stream type." eid, _ = one.pid2eid(pid) - - stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) - if len(stream_names) > 1: - assert ( - stream_name is not None - ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." - assert stream_name in stream_names, ( - f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " - f"Please choose one of {stream_names}." - ) + pids, probes = one.eid2pid(eid) + pname = probes[pids.index(pid)] + stream_name = f"{pname}.{stream_type}" else: - stream_name = stream_names[0] - pname, stream_type = stream_name.split(".") + stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) + if len(stream_names) > 1: + assert ( + stream_name is not None + ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." + assert stream_name in stream_names, ( + f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " + f"Please choose one of {stream_names}." 
+ ) + else: + stream_name = stream_names[0] + pname, stream_type = stream_name.split(".") self.ssl = SpikeSortingLoader(one=one, eid=eid, pid=pid, pname=pname) if pid is None: diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py index 0cd0fb0fb5..b9410bc021 100644 --- a/src/spikeinterface/full.py +++ b/src/spikeinterface/full.py @@ -25,3 +25,4 @@ from .widgets import * from .exporters import * from .generation import * +from .benchmark import * diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 1871c11b85..809f2c5bba 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -1,12 +1,14 @@ from __future__ import annotations -import shutil -import pickle import warnings -import tempfile +import platform from pathlib import Path from tqdm.auto import tqdm +from concurrent.futures import ProcessPoolExecutor +import multiprocessing as mp +from threadpoolctl import threadpool_limits + import numpy as np from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -314,11 +316,13 @@ def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + mp_context = job_kwargs["mp_context"] # fit model/models # TODO : make parralel for by_channel_global and concatenated if mode == "by_channel_local": - pca_models = self._fit_by_channel_local(n_jobs, progress_bar) + pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context) for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids): self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] pca_model = pca_models @@ -411,12 +415,16 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): ) processor.run() - def _fit_by_channel_local(self, n_jobs, progress_bar): + def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context): from sklearn.decomposition import IncrementalPCA - from concurrent.futures import ProcessPoolExecutor p = self.params + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids # there is one PCA per channel for independent fit per channel @@ -436,13 +444,18 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): pca = pca_models[chan_ind] pca.partial_fit(wfs[:, :, wf_ind]) else: - # parallel + # create list of args to parallelize. 
For convenience, the max_threads_per_process is passed + # as last argument items = [ - (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds) + (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process) + for wf_ind, chan_ind in enumerate(channel_inds) ] n_jobs = min(n_jobs, len(items)) - with ProcessPoolExecutor(max_workers=n_jobs) as executor: + with ProcessPoolExecutor( + max_workers=n_jobs, + mp_context=mp.get_context(mp_context), + ) as executor: results = executor.map(_partial_fit_one_channel, items) for chan_ind, pca_model_updated in results: pca_models[chan_ind] = pca_model_updated @@ -674,6 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte def _partial_fit_one_channel(args): - chan_ind, pca_model, wf_chan = args - pca_model.partial_fit(wf_chan) - return chan_ind, pca_model + chan_ind, pca_model, wf_chan, max_threads_per_process = args + + if max_threads_per_process is None: + pca_model.partial_fit(wf_chan) + return chan_ind, pca_model + else: + with threadpool_limits(limits=int(max_threads_per_process)): + pca_model.partial_fit(wf_chan) + return chan_ind, pca_model diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0e70b1f494..cfa9d89fea 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -7,6 +7,13 @@ from ..core.template_tools import get_dense_templates_array from ..core.sparsity import ChannelSparsity +try: + import numba + + HAVE_NUMBA = True +except ImportError: + HAVE_NUMBA = False + class ComputeTemplateSimilarity(AnalyzerExtension): """Compute similarity between templates with several methods. 
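
For reference, a minimal sketch of the array-level entry point that the refactor below reorganizes (shapes follow the (num_templates, num_samples, num_channels) convention of this module; the numba kernel is selected automatically when numba is importable):

    import numpy as np
    from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array

    rng = np.random.default_rng(0)
    templates = rng.random((8, 60, 16), dtype=np.float32)

    # Pairwise cosine similarity between templates, scanning +/- 4 sample shifts
    # and keeping the best (maximum-similarity) lag for each pair.
    similarity = compute_similarity_with_templates_array(
        templates, templates, method="cosine", support="union", num_shifts=4
    )
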
@@ -147,54 +154,15 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def compute_similarity_with_templates_array( - templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None -): - import sklearn.metrics.pairwise +def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): - if method == "cosine_similarity": - method = "cosine" - - all_metrics = ["cosine", "l1", "l2"] - - if method not in all_metrics: - raise ValueError(f"compute_template_similarity (method {method}) not exists") - - assert ( - templates_array.shape[1] == other_templates_array.shape[1] - ), "The number of samples in the templates should be the same for both arrays" - assert ( - templates_array.shape[2] == other_templates_array.shape[2] - ), "The number of channels in the templates should be the same for both arrays" num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - num_channels = templates_array.shape[2] other_num_templates = other_templates_array.shape[0] - same_array = np.array_equal(templates_array, other_templates_array) - - mask = None - if sparsity is not None and other_sparsity is not None: - if support == "intersection": - mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) - elif support == "union": - mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) - units_overlaps = np.sum(mask, axis=2) > 0 - mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) - mask[~units_overlaps] = False - if mask is not None: - units_overlaps = np.sum(mask, axis=2) > 0 - overlapping_templates = {} - for i in range(num_templates): - overlapping_templates[i] = np.flatnonzero(units_overlaps[i]) - else: - # here we make a dense mask and overlapping templates - overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)} - mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) - - assert num_shifts < num_samples, "max_lag is too large" num_shifts_both_sides = 2 * num_shifts + 1 distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + same_array = np.array_equal(templates_array, other_templates_array) # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t # So the matrix can be computed only for negative lags and be transposed @@ -210,8 +178,9 @@ def compute_similarity_with_templates_array( tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] - tgt_templates = tgt_sliced_templates[overlapping_templates[i]] - for gcount, j in enumerate(overlapping_templates[i]): + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + tgt_templates = tgt_sliced_templates[overlapping_templates] + for gcount, j in enumerate(overlapping_templates): # symmetric values are handled later if same_array and j < i: # no need exhaustive looping when same template @@ -222,23 +191,156 @@ def compute_similarity_with_templates_array( if method == "l1": norm_i = np.sum(np.abs(src)) norm_j = np.sum(np.abs(tgt)) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item() + distances[count, i, j] = np.sum(np.abs(src - tgt)) distances[count, i, j] /= norm_i + norm_j elif method == "l2": norm_i = 
np.linalg.norm(src, ord=2) norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item() + distances[count, i, j] = np.linalg.norm(src - tgt, ord=2) distances[count, i, j] /= norm_i + norm_j - else: - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances( - src, tgt, metric="cosine" - ).item() + elif method == "cosine": + norm_i = np.linalg.norm(src, ord=2) + norm_j = np.linalg.norm(tgt, ord=2) + distances[count, i, j] = np.sum(src * tgt) + distances[count, i, j] /= norm_i * norm_j + distances[count, i, j] = 1 - distances[count, i, j] if same_array: distances[count, j, i] = distances[count, i, j] if same_array and num_shifts != 0: distances[num_shifts_both_sides - count - 1] = distances[count].T + return distances + + +if HAVE_NUMBA: + + from math import sqrt + + @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) + def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + other_num_templates = other_templates_array.shape[0] + + num_shifts_both_sides = 2 * num_shifts + 1 + distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + same_array = np.array_equal(templates_array, other_templates_array) + + # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t + # So the matrix can be computed only for negative lags and be transposed + + if same_array: + # optimisation when array are the same because of symetry in shift + shift_loop = list(range(-num_shifts, 1)) + else: + shift_loop = list(range(-num_shifts, num_shifts + 1)) + + if method == "l1": + metric = 0 + elif method == "l2": + metric = 1 + elif method == "cosine": + metric = 2 + + for count in range(len(shift_loop)): + shift = shift_loop[count] + src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] + tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] + for i in numba.prange(num_templates): + src_template = src_sliced_templates[i] + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + tgt_templates = tgt_sliced_templates[overlapping_templates] + for gcount in range(len(overlapping_templates)): + + j = overlapping_templates[gcount] + # symmetric values are handled later + if same_array and j < i: + # no need exhaustive looping when same template + continue + src = src_template[:, mask[i, j]].flatten() + tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten() + + norm_i = 0 + norm_j = 0 + distances[count, i, j] = 0 + + for k in range(len(src)): + if metric == 0: + norm_i += abs(src[k]) + norm_j += abs(tgt[k]) + distances[count, i, j] += abs(src[k] - tgt[k]) + elif metric == 1: + norm_i += src[k] ** 2 + norm_j += tgt[k] ** 2 + distances[count, i, j] += (src[k] - tgt[k]) ** 2 + elif metric == 2: + distances[count, i, j] += src[k] * tgt[k] + norm_i += src[k] ** 2 + norm_j += tgt[k] ** 2 + + if metric == 0: + distances[count, i, j] /= norm_i + norm_j + elif metric == 1: + norm_i = sqrt(norm_i) + norm_j = sqrt(norm_j) + distances[count, i, j] = sqrt(distances[count, i, j]) + distances[count, i, j] /= norm_i + norm_j + elif metric == 2: + norm_i = sqrt(norm_i) + norm_j = sqrt(norm_j) + distances[count, i, j] /= norm_i * norm_j + distances[count, i, j] = 1 - distances[count, i, j] + + if same_array: + distances[count, j, i] = distances[count, i, j] + + if 
same_array and num_shifts != 0: + distances[num_shifts_both_sides - count - 1] = distances[count].T + + return distances + + _compute_similarity_matrix = _compute_similarity_matrix_numba +else: + _compute_similarity_matrix = _compute_similarity_matrix_numpy + + +def compute_similarity_with_templates_array( + templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None +): + + if method == "cosine_similarity": + method = "cosine" + + all_metrics = ["cosine", "l1", "l2"] + + if method not in all_metrics: + raise ValueError(f"compute_template_similarity (method {method}) not exists") + + assert ( + templates_array.shape[1] == other_templates_array.shape[1] + ), "The number of samples in the templates should be the same for both arrays" + assert ( + templates_array.shape[2] == other_templates_array.shape[2] + ), "The number of channels in the templates should be the same for both arrays" + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + num_channels = templates_array.shape[2] + other_num_templates = other_templates_array.shape[0] + + mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + + if sparsity is not None and other_sparsity is not None: + if support == "intersection": + mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + elif support == "union": + mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + units_overlaps = np.sum(mask, axis=2) > 0 + mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + mask[~units_overlaps] = False + + assert num_shifts < num_samples, "max_lag is too large" + distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method) distances = np.min(distances, axis=0) similarity = 1 - distances diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 66d84c9565..0431c8d675 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -93,7 +93,6 @@ def test_equal_results_correlograms(window_and_bin_ms): ) assert np.array_equal(result_numpy, result_numba) - assert np.array_equal(result_numpy, result_numba) @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 4de86be32b..7a509c410f 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -18,6 +18,18 @@ class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): def test_extension(self, params): self.run_extension_tests(ComputePrincipalComponents, params=params) + def test_multi_processing(self): + """ + Test the extension works with multiple processes. 
+ """ + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=False, extension_class=ComputePrincipalComponents + ) + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2) + sorting_analyzer.compute( + "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn" + ) + def test_mode_concatenated(self): """ Replicate the "extension_function_params_list" test outside of diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index cc6797c262..20d8373981 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -7,7 +7,23 @@ ) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array +from spikeinterface.postprocessing.template_similarity import ( + compute_similarity_with_templates_array, + _compute_similarity_matrix_numpy, +) + +try: + import numba + + HAVE_NUMBA = True + from spikeinterface.postprocessing.template_similarity import _compute_similarity_matrix_numba +except ModuleNotFoundError as err: + HAVE_NUMBA = False + +import pytest +from pytest import param + +SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): @@ -72,6 +88,35 @@ def test_compute_similarity_with_templates_array(params): print(similarity.shape) +pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") + + +@pytest.mark.parametrize( + "params", + [ + dict(method="cosine", num_shifts=8), + dict(method="l1", num_shifts=0), + dict(method="l2", num_shifts=0), + dict(method="cosine", num_shifts=0), + ], +) +def test_equal_results_numba(params): + """ + Test that the 2 methods have same results with some varied time bins + that are not tested in other tests. + """ + + rng = np.random.default_rng(seed=2205) + templates_array = rng.random(size=(4, 20, 5), dtype=np.float32) + other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32) + mask = np.ones((4, 2, 5), dtype=bool) + + result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params) + result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params) + + assert np.allclose(result_numpy, result_numba, 1e-3) + + if __name__ == "__main__": from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset from spikeinterface.core import estimate_sparsity diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 9df60af3db..bf723c84b9 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -46,7 +46,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data): filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2) # Then, change all kwargs to ensure they are propagated # and check the backwards version. 
@@ -66,7 +66,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data): filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2) def test_causal_filter_custom_coeff(self, recording_and_data): """ @@ -89,7 +89,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data): filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2, equal_nan=True) # Next, in "sos" mode options["filter_mode"] = "sos" @@ -100,7 +100,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data): filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2, equal_nan=True) def test_causal_kwarg_error_raised(self, recording_and_data): """ diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 7c099a2f74..4c68dfea59 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -2,15 +2,16 @@ from __future__ import annotations - +import warnings from copy import deepcopy - -import numpy as np +import platform from tqdm.auto import tqdm -from concurrent.futures import ProcessPoolExecutor +import numpy as np -import warnings +import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits from .misc_metrics import compute_num_spikes, compute_firing_rates @@ -56,6 +57,8 @@ def compute_pc_metrics( seed=None, n_jobs=1, progress_bar=False, + mp_context=None, + max_threads_per_process=None, ) -> dict: """ Calculate principal component derived metrics. @@ -144,17 +147,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = ( - pcs_flat, - labels, - non_nn_metrics, - unit_id, - unit_ids, - qm_params, - seed, - n_spikes_all_units, - fr_all_units, - ) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -167,7 +160,15 @@ def compute_pc_metrics( for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric elif run_in_parallel and non_nn_metrics: - with ProcessPoolExecutor(n_jobs) as executor: + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" 
+            elif mp_context == "fork" and platform.system() == "Darwin":
+                warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
+
+            with ProcessPoolExecutor(
+                max_workers=n_jobs,
+                mp_context=mp.get_context(mp_context),
+            ) as executor:
                 results = executor.map(pca_metrics_one_unit, items)
                 if progress_bar:
                     results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics")
@@ -976,26 +977,19 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):


 def pca_metrics_one_unit(args):
-    (
-        pcs_flat,
-        labels,
-        metric_names,
-        unit_id,
-        unit_ids,
-        qm_params,
-        seed,
-        # we_folder,
-        n_spikes_all_units,
-        fr_all_units,
-    ) = args
-
-    # if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names:
-    #     we = load_waveforms(we_folder)
+    (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args
+
+    if max_threads_per_process is None:
+        return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
+    else:
+        with threadpool_limits(limits=int(max_threads_per_process)):
+            return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
+
+
+def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params):
     pc_metrics = {}

     # metrics
     if "isolation_distance" in metric_names or "l_ratio" in metric_names:
-
         try:
             isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id)
         except:
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 3b6c6d3e50..b6a50d60f5 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -234,7 +234,8 @@ def _run(self, verbose=False, **job_kwargs):
         )

         existing_metrics = []
-        qm_extension = self.sorting_analyzer.get_extension("quality_metrics")
+        # fetch the extension via the dict only, to avoid a full reload from disk after a params reset
+        qm_extension = self.sorting_analyzer.extensions.get("quality_metrics", None)

         if (
             delete_existing_metrics is False
             and qm_extension is not None
diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
index 6ddeb02689..f2e912c6b4 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
@@ -1,9 +1,7 @@
 import pytest
 import numpy as np

-from spikeinterface.qualitymetrics import (
-    compute_pc_metrics,
-)
+from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list


 def test_calculate_pc_metrics(small_sorting_analyzer):
@@ -22,3 +20,24 @@
         assert not np.all(np.isnan(res2[metric_name].values))

         assert np.array_equal(res1[metric_name].values, res2[metric_name].values)
+
+
+def test_pca_metrics_multi_processing(small_sorting_analyzer):
+    sorting_analyzer = small_sorting_analyzer
+
+    metric_names = get_quality_pca_metric_list()
+    metric_names.remove("nn_isolation")
+    metric_names.remove("nn_noise_overlap")
+
+    print("Computing PCA metrics with 1 thread per process")
+    res1 = compute_pc_metrics(
+        sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True
+    )
+    print("Computing PCA metrics with 2 threads per process")
+    res2 = compute_pc_metrics(
+        sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
+    )
+    print("Computing PCA metrics with spawn context")
+    res3 = compute_pc_metrics(
+        sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, mp_context="spawn", progress_bar=True
+    )
diff --git a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py
index 333bcdbc32..df6e3821bb 100644
--- a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py
+++ b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py
@@ -4,12 +4,18 @@

 from spikeinterface.sorters import Spykingcircus2Sorter

+from pathlib import Path
+

 class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
     SorterClass = Spykingcircus2Sorter


 if __name__ == "__main__":
+    from spikeinterface import set_global_job_kwargs
+
+    set_global_job_kwargs(n_jobs=1, progress_bar=False)
     test = SpykingCircus2SorterCommonTestSuite()
+    test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
     test.setUp()
     test.test_with_run()
diff --git a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py
index 58d6c15c8d..b256dd1328 100644
--- a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py
@@ -4,6 +4,8 @@

 from spikeinterface.sorters import Tridesclous2Sorter

+from pathlib import Path
+

 class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
     SorterClass = Tridesclous2Sorter

@@ -11,5 +13,6 @@ class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase

 if __name__ == "__main__":
     test = Tridesclous2SorterCommonTestSuite()
+    test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
     test.setUp()
     test.test_with_run()
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index 57755cd759..a180fb4e02 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -226,7 +226,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         matching_method = params["matching"]["method"]
         matching_params = params["matching"]["method_kwargs"].copy()
         matching_params["templates"] = templates
-        matching_params["noise_levels"] = noise_levels
+        if params["matching"]["method"] in ("tdc-peeler",):
+            matching_params["noise_levels"] = noise_levels
         spikes = find_spikes_from_templates(
             recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
         )
diff --git a/src/spikeinterface/sorters/utils/shellscript.py b/src/spikeinterface/sorters/utils/shellscript.py
index 286445dd2d..24f353bf00 100644
--- a/src/spikeinterface/sorters/utils/shellscript.py
+++ b/src/spikeinterface/sorters/utils/shellscript.py
@@ -86,15 +86,15 @@ def start(self) -> None:
         if self._verbose:
             print("RUNNING SHELL SCRIPT: " + cmd)
         self._start_time = time.time()
+        encoding = sys.getdefaultencoding()
         self._process = subprocess.Popen(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True, encoding=encoding
         )
         with open(script_log_path, "w+") as script_log_file:
             for line in self._process.stdout:
                 script_log_file.write(line)
-                if (
-                    self._verbose
-                ):  # Print onto console depending on the verbose property
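For reference while reading the test above: with the two new kwargs, a parallel quality-metrics call now looks like the following. This is a sketch only, assuming sorting_analyzer is a SortingAnalyzer whose "principal_components" extension has already been computed::

    from spikeinterface.qualitymetrics import compute_pc_metrics

    pc_metrics = compute_pc_metrics(
        sorting_analyzer,
        metric_names=["isolation_distance", "l_ratio"],
        n_jobs=4,
        mp_context="spawn",  # "fork" is rejected on Windows and warned against on macOS
        max_threads_per_process=1,  # threadpool_limits is applied inside each worker process
        progress_bar=True,
    )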
passed on from the sorter class + if self._verbose: + # Print onto console depending on the verbose property passed on from the sorter class print(line) def wait(self, timeout=None) -> Optional[int]: diff --git a/src/spikeinterface/sortingcomponents/benchmark/__init__.py b/src/spikeinterface/sortingcomponents/benchmark/__init__.py deleted file mode 100644 index ad6d444bdb..0000000000 --- a/src/spikeinterface/sortingcomponents/benchmark/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Module to benchmark some sorting components: - * clustering - * motion - * template matching -""" diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py deleted file mode 100644 index a9e404292d..0000000000 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - - -@pytest.mark.skip() -def test_benchmark_peak_selection(create_cache_folder): - cache_folder = create_cache_folder - pass - - -if __name__ == "__main__": - test_benchmark_peak_selection() diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 18fb7d198e..65b2308996 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -601,21 +601,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, sub_recording = recording.frame_slice(t_start, t_stop) local_params.update({"ignore_inds": ignore_inds + [i]}) - spikes, computed = find_spikes_from_templates( + + spikes, more_outputs = find_spikes_from_templates( sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs ) - local_params.update( - { - "overlaps": computed["overlaps"], - "normed_templates": computed["normed_templates"], - "norms": computed["norms"], - "temporal": computed["temporal"], - "spatial": computed["spatial"], - "singular": computed["singular"], - "units_overlaps": computed["units_overlaps"], - "unit_overlaps_indices": computed["unit_overlaps_indices"], - } - ) + local_params["precomputed"] = more_outputs valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin) if np.sum(valid) > 0: diff --git a/src/spikeinterface/sortingcomponents/matching/base.py b/src/spikeinterface/sortingcomponents/matching/base.py new file mode 100644 index 0000000000..0e60a9e864 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/matching/base.py @@ -0,0 +1,48 @@ +import numpy as np +from spikeinterface.core import Templates +from spikeinterface.core.node_pipeline import PeakDetector + +_base_matching_dtype = [ + ("sample_index", "int64"), + ("channel_index", "int64"), + ("cluster_index", "int64"), + ("amplitude", "float64"), + ("segment_index", "int64"), +] + + +class BaseTemplateMatching(PeakDetector): + def __init__(self, recording, templates, return_output=True, parents=None): + # TODO make a sharedmem of template here + # TODO maybe check that channel_id are the same with recording + + assert isinstance( + templates, Templates + ), f"The templates supplied is of type {type(templates)} and must be a Templates" + self.templates = templates + PeakDetector.__init__(self, recording, return_output=return_output, parents=parents) + + def get_dtype(self): + return np.dtype(_base_matching_dtype) + + def 
get_trace_margin(self):
+        raise NotImplementedError
+
+    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
+        spikes = self.compute_matching(traces, start_frame, end_frame, segment_index)
+        spikes["segment_index"] = segment_index
+
+        margin = self.get_trace_margin()
+        if margin > 0 and spikes.size > 0:
+            keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin))
+            spikes = spikes[keep]
+
+        # the node pipeline needs a tuple to be returned
+        return (spikes,)
+
+    def compute_matching(self, traces, start_frame, end_frame, segment_index):
+        raise NotImplementedError
+
+    def get_extra_outputs(self):
+        # can be overwritten if a method needs to output some extra variables as a dict
+        return None
diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index ad7391a297..a3624f4296 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -17,7 +17,7 @@
     ("segment_index", "int64"),
 ]

-from .main import BaseTemplateMatchingEngine
+from .base import BaseTemplateMatching


 def compress_templates(
@@ -89,7 +89,7 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities):
     return new_overlaps


-class CircusOMPSVDPeeler(BaseTemplateMatchingEngine):
+class CircusOMPSVDPeeler(BaseTemplateMatching):
     """
     Orthogonal Matching Pursuit inspired from Spyking Circus sorter
@@ -121,147 +121,148 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine):
     -----
     """

-    _default_params = {
-        "amplitudes": [0.6, np.inf],
-        "stop_criteria": "max_failures",
-        "max_failures": 10,
-        "omp_min_sps": 0.1,
-        "relative_error": 5e-5,
-        "templates": None,
-        "rank": 5,
-        "ignore_inds": [],
-        "vicinity": 3,
-    }
+    _more_output_keys = [
+        "norms",
+        "temporal",
+        "spatial",
+        "singular",
+        "units_overlaps",
+        "unit_overlaps_indices",
+        "normed_templates",
+        "overlaps",
+    ]
+
+    def __init__(
+        self,
+        recording,
+        return_output=True,
+        parents=None,
+        templates=None,
+        amplitudes=[0.6, np.inf],
+        stop_criteria="max_failures",
+        max_failures=10,
+        omp_min_sps=0.1,
+        relative_error=5e-5,
+        rank=5,
+        ignore_inds=[],
+        vicinity=3,
+        precomputed=None,
+    ):
+
+        BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output, parents=parents)
+
+        self.num_channels = recording.get_num_channels()
+        self.num_samples = templates.num_samples
+        self.nbefore = templates.nbefore
+        self.nafter = templates.nafter
+        self.sampling_frequency = recording.get_sampling_frequency()
+        self.vicinity = vicinity * self.num_samples
+
+        self.amplitudes = amplitudes
+        self.stop_criteria = stop_criteria
+        self.max_failures = max_failures
+        self.omp_min_sps = omp_min_sps
+        self.relative_error = relative_error
+        self.rank = rank
+
+        self.num_templates = len(templates.unit_ids)
+
+        if precomputed is None:
+            self._prepare_templates()
+        else:
+            for key in self._more_output_keys:
+                assert precomputed[key] is not None, "If precomputed is given, the key '%s' must also be there" % key
+                setattr(self, key, precomputed[key])
+
+        self.ignore_inds = np.array(ignore_inds)
+
+        self.unit_overlaps_tables = {}
+        for i in range(self.num_templates):
+            self.unit_overlaps_tables[i] = np.zeros(self.num_templates, dtype=int)
+            self.unit_overlaps_tables[i][self.unit_overlaps_indices[i]] = np.arange(len(self.unit_overlaps_indices[i]))

-    @classmethod
-    def _prepare_templates(cls, d):
-        templates = d["templates"]
-        num_templates = len(d["templates"].unit_ids)
+        if self.vicinity > 0:
+            self.margin =
self.vicinity + else: + self.margin = 2 * self.num_samples + + def _prepare_templates(self): - assert d["stop_criteria"] in ["max_failures", "omp_min_sps", "relative_error"] + assert self.stop_criteria in ["max_failures", "omp_min_sps", "relative_error"] - sparsity = templates.sparsity.mask + sparsity = self.templates.sparsity.mask units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) - d["units_overlaps"] = units_overlaps > 0 - d["unit_overlaps_indices"] = {} - for i in range(num_templates): - (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) + self.units_overlaps = units_overlaps > 0 + self.unit_overlaps_indices = {} + for i in range(self.num_templates): + self.unit_overlaps_indices[i] = np.flatnonzero(self.units_overlaps[i]) - templates_array = templates.get_dense_templates().copy() + templates_array = self.templates.get_dense_templates().copy() # Then we keep only the strongest components - d["temporal"], d["singular"], d["spatial"], templates_array = compress_templates(templates_array, d["rank"]) + self.temporal, self.singular, self.spatial, templates_array = compress_templates(templates_array, self.rank) - d["normed_templates"] = np.zeros(templates_array.shape, dtype=np.float32) - d["norms"] = np.zeros(num_templates, dtype=np.float32) + self.normed_templates = np.zeros(templates_array.shape, dtype=np.float32) + self.norms = np.zeros(self.num_templates, dtype=np.float32) # And get the norms, saving compressed templates for CC matrix - for count in range(num_templates): + for count in range(self.num_templates): template = templates_array[count][:, sparsity[count]] - d["norms"][count] = np.linalg.norm(template) - d["normed_templates"][count][:, sparsity[count]] = template / d["norms"][count] + self.norms[count] = np.linalg.norm(template) + self.normed_templates[count][:, sparsity[count]] = template / self.norms[count] - d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] - d["temporal"] = np.flip(d["temporal"], axis=1) + self.temporal /= self.norms[:, np.newaxis, np.newaxis] + self.temporal = np.flip(self.temporal, axis=1) - d["overlaps"] = [] - d["max_similarity"] = np.zeros((num_templates, num_templates), dtype=np.float32) - for i in range(num_templates): - num_overlaps = np.sum(d["units_overlaps"][i]) - overlapping_units = np.where(d["units_overlaps"][i])[0] + self.overlaps = [] + self.max_similarity = np.zeros((self.num_templates, self.num_templates), dtype=np.float32) + for i in range(self.num_templates): + num_overlaps = np.sum(self.units_overlaps[i]) + overlapping_units = np.flatnonzero(self.units_overlaps[i]) # Reconstruct unit template from SVD Matrices - data = d["temporal"][i] * d["singular"][i][np.newaxis, :] - template_i = np.matmul(data, d["spatial"][i, :, :]) + data = self.temporal[i] * self.singular[i][np.newaxis, :] + template_i = np.matmul(data, self.spatial[i, :, :]) template_i = np.flipud(template_i) - unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) + unit_overlaps = np.zeros([num_overlaps, 2 * self.num_samples - 1], dtype=np.float32) for count, j in enumerate(overlapping_units): overlapped_channels = sparsity[j] visible_i = template_i[:, overlapped_channels] - spatial_filters = d["spatial"][j, :, overlapped_channels] + spatial_filters = self.spatial[j, :, overlapped_channels] spatially_filtered_template = np.matmul(visible_i, spatial_filters) - visible_i = spatially_filtered_template * d["singular"][j] + visible_i = spatially_filtered_template * 
self.singular[j] for rank in range(visible_i.shape[1]): - unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], self.temporal[j][:, rank], mode="full") - d["max_similarity"][i, j] = np.max(unit_overlaps[count]) + self.max_similarity[i, j] = np.max(unit_overlaps[count]) - d["overlaps"].append(unit_overlaps) + self.overlaps.append(unit_overlaps) - if d["amplitudes"] is None: - distances = np.sort(d["max_similarity"], axis=1)[:, ::-1] + if self.amplitudes is None: + distances = np.sort(self.max_similarity, axis=1)[:, ::-1] distances = 1 - distances[:, 1] / 2 - d["amplitudes"] = np.zeros((num_templates, 2)) - d["amplitudes"][:, 0] = distances - d["amplitudes"][:, 1] = np.inf - - d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) - d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) - d["singular"] = d["singular"].T[:, :, np.newaxis] - return d - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls._default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) + self.amplitudes = np.zeros((self.num_templates, 2)) + self.amplitudes[:, 0] = distances + self.amplitudes[:, 1] = np.inf - d["num_channels"] = recording.get_num_channels() - d["num_samples"] = d["templates"].num_samples - d["nbefore"] = d["templates"].nbefore - d["nafter"] = d["templates"].nafter - d["sampling_frequency"] = recording.get_sampling_frequency() - d["vicinity"] *= d["num_samples"] + self.spatial = np.moveaxis(self.spatial, [0, 1, 2], [1, 0, 2]) + self.temporal = np.moveaxis(self.temporal, [0, 1, 2], [1, 2, 0]) + self.singular = self.singular.T[:, :, np.newaxis] - if "overlaps" not in d: - d = cls._prepare_templates(d) - else: - for key in [ - "norms", - "temporal", - "spatial", - "singular", - "units_overlaps", - "unit_overlaps_indices", - ]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key - - d["num_templates"] = len(d["templates"].templates_array) - d["ignore_inds"] = np.array(d["ignore_inds"]) - - d["unit_overlaps_tables"] = {} - for i in range(d["num_templates"]): - d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) - d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - if kwargs["vicinity"] > 0: - margin = kwargs["vicinity"] - else: - margin = 2 * kwargs["num_samples"] - return margin + def get_extra_outputs(self): + output = {} + for key in self._more_output_keys: + output[key] = getattr(self, key) + return output + + def get_trace_margin(self): + return self.margin - @classmethod - def main_function(cls, traces, d): + def compute_matching(self, traces, start_frame, end_frame, segment_index): import scipy.spatial import scipy @@ -269,50 +270,45 @@ def main_function(cls, traces, d): (nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32) - num_templates = d["num_templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - overlaps_array = d["overlaps"] - norms = d["norms"] + overlaps_array = self.overlaps + omp_tol = np.finfo(np.float32).eps - 
num_samples = d["nafter"] + d["nbefore"] + num_samples = self.nafter + self.nbefore neighbor_window = num_samples - 1 - if isinstance(d["amplitudes"], list): - min_amplitude, max_amplitude = d["amplitudes"] + if isinstance(self.amplitudes, list): + min_amplitude, max_amplitude = self.amplitudes else: - min_amplitude, max_amplitude = d["amplitudes"][:, 0], d["amplitudes"][:, 1] + min_amplitude, max_amplitude = self.amplitudes[:, 0], self.amplitudes[:, 1] min_amplitude = min_amplitude[:, np.newaxis] max_amplitude = max_amplitude[:, np.newaxis] - ignore_inds = d["ignore_inds"] - vicinity = d["vicinity"] num_timesteps = len(traces) num_peaks = num_timesteps - num_samples + 1 - conv_shape = (num_templates, num_peaks) + conv_shape = (self.num_templates, num_peaks) scalar_products = np.zeros(conv_shape, dtype=np.float32) # Filter using overlap-and-add convolution - if len(ignore_inds) > 0: - not_ignored = ~np.isin(np.arange(num_templates), ignore_inds) - spatially_filtered_data = np.matmul(d["spatial"][:, not_ignored, :], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"][:, not_ignored, :] + if len(self.ignore_inds) > 0: + not_ignored = ~np.isin(np.arange(self.num_templates), self.ignore_inds) + spatially_filtered_data = np.matmul(self.spatial[:, not_ignored, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * self.singular[:, not_ignored, :] objective_by_rank = scipy.signal.oaconvolve( - scaled_filtered_data, d["temporal"][:, not_ignored, :], axes=2, mode="valid" + scaled_filtered_data, self.temporal[:, not_ignored, :], axes=2, mode="valid" ) scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0) - scalar_products[ignore_inds] = -np.inf + scalar_products[self.ignore_inds] = -np.inf else: - spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"] - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") + spatially_filtered_data = np.matmul(self.spatial, traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * self.singular + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, self.temporal, axes=2, mode="valid") scalar_products += np.sum(objective_by_rank, axis=0) num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - M = np.zeros((num_templates, num_templates), dtype=np.float32) + M = np.zeros((self.num_templates, self.num_templates), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -325,13 +321,13 @@ def main_function(cls, traces, d): all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) - if d["stop_criteria"] == "omp_min_sps": - stop_criteria = d["omp_min_sps"] * np.maximum(d["norms"], np.sqrt(num_channels * num_samples)) - elif d["stop_criteria"] == "max_failures": + if self.stop_criteria == "omp_min_sps": + stop_criteria = self.omp_min_sps * np.maximum(self.norms, np.sqrt(self.num_channels * num_samples)) + elif self.stop_criteria == "max_failures": num_valids = 0 - nb_failures = d["max_failures"] - elif d["stop_criteria"] == "relative_error": - if len(ignore_inds) > 0: + nb_failures = self.max_failures + elif self.stop_criteria == "relative_error": + if len(self.ignore_inds) > 0: new_error = np.linalg.norm(scalar_products[not_ignored]) else: new_error = 
np.linalg.norm(scalar_products) @@ -350,8 +346,8 @@ def main_function(cls, traces, d): myindices = selection[0, idx] local_overlaps = overlaps_array[best_cluster_ind] - overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] - table = d["unit_overlaps_tables"][best_cluster_ind] + overlapping_templates = self.unit_overlaps_indices[best_cluster_ind] + table = self.unit_overlaps_tables[best_cluster_ind] if num_selection == M.shape[0]: Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) @@ -362,7 +358,7 @@ def main_function(cls, traces, d): a, b = myindices[mask], myline[mask] M[num_selection, idx[mask]] = local_overlaps[table[a], b] - if vicinity == 0: + if self.vicinity == 0: scipy.linalg.solve_triangular( M[:num_selection, :num_selection], M[num_selection, :num_selection], @@ -378,7 +374,7 @@ def main_function(cls, traces, d): break M[num_selection, num_selection] = np.sqrt(Lkk) else: - is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + is_in_vicinity = np.where(np.abs(delta_t) < self.vicinity)[0] if len(is_in_vicinity) > 0: L = M[is_in_vicinity, :][:, is_in_vicinity] @@ -403,15 +399,15 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - if vicinity == 0: + if self.vicinity == 0: all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - all_amplitudes /= norms[selection[0]] + all_amplitudes /= self.norms[selection[0]] else: is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) all_amplitudes = np.append(all_amplitudes, np.float32(1)) L = M[is_in_vicinity, :][:, is_in_vicinity] all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) - all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + all_amplitudes[is_in_vicinity] /= self.norms[selection[0][is_in_vicinity]] diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] @@ -419,10 +415,10 @@ def main_function(cls, traces, d): for i in modified: tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i] * norms[tmp_best] + diff_amp = diff_amplitudes[i] * self.norms[tmp_best] local_overlaps = overlaps_array[tmp_best] - overlapping_templates = d["units_overlaps"][tmp_best] + overlapping_templates = self.units_overlaps[tmp_best] if not tmp_peak in neighbors.keys(): idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] @@ -436,44 +432,47 @@ def main_function(cls, traces, d): scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add # We stop when updates do not modify the chosen spikes anymore - if d["stop_criteria"] == "omp_min_sps": + if self.stop_criteria == "omp_min_sps": is_valid = scalar_products > stop_criteria[:, np.newaxis] do_loop = np.any(is_valid) - elif d["stop_criteria"] == "max_failures": + elif self.stop_criteria == "max_failures": is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) new_num_valids = np.sum(is_valid) if (new_num_valids - num_valids) > 0: - nb_failures = d["max_failures"] + nb_failures = self.max_failures else: nb_failures -= 1 num_valids = new_num_valids do_loop = nb_failures > 0 - elif d["stop_criteria"] == "relative_error": + elif self.stop_criteria == "relative_error": previous_error = new_error - if len(ignore_inds) > 0: + if len(self.ignore_inds) > 0: new_error = np.linalg.norm(scalar_products[not_ignored]) else: new_error = 
np.linalg.norm(scalar_products)

             delta_error = np.abs(new_error / previous_error - 1)
-            do_loop = delta_error > d["relative_error"]
+            do_loop = delta_error > self.relative_error

         is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude)
         valid_indices = np.where(is_valid)

         num_spikes = len(valid_indices[0])
-        spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"]
+        spikes["sample_index"][:num_spikes] = valid_indices[1] + self.nbefore
         spikes["channel_index"][:num_spikes] = 0
         spikes["cluster_index"][:num_spikes] = valid_indices[0]
         spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]]

         spikes = spikes[:num_spikes]
-        order = np.argsort(spikes["sample_index"])
-        spikes = spikes[order]
+        if spikes.size > 0:
+            order = np.argsort(spikes["sample_index"])
+            spikes = spikes[order]

         return spikes


-class CircusPeeler(BaseTemplateMatchingEngine):
+class CircusPeeler(BaseTemplateMatching):
     """
     Greedy Template-matching ported from the Spyking Circus sorter
@@ -519,115 +518,25 @@ class CircusPeeler(BaseTemplateMatchingEngine):
     """

-    _default_params = {
-        "peak_sign": "neg",
-        "exclude_sweep_ms": 0.1,
-        "jitter_ms": 0.1,
-        "detect_threshold": 5,
-        "noise_levels": None,
-        "random_chunk_kwargs": {},
-        "max_amplitude": 1.5,
-        "min_amplitude": 0.5,
-        "use_sparse_matrix_threshold": 0.25,
-        "templates": None,
-    }
-
-    @classmethod
-    def _prepare_templates(cls, d):
-        import scipy.spatial
-        import scipy
-
-        templates = d["templates"]
-        num_samples = d["num_samples"]
-        num_channels = d["num_channels"]
-        num_templates = d["num_templates"]
-        use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"]
-
-        d["norms"] = np.zeros(num_templates, dtype=np.float32)
-
-        all_units = d["templates"].unit_ids
-
-        sparsity = templates.sparsity.mask
+    def __init__(
+        self,
+        recording,
+        return_output=True,
+        parents=None,
+        templates=None,
+        peak_sign="neg",
+        exclude_sweep_ms=0.1,
+        jitter_ms=0.1,
+        detect_threshold=5,
+        noise_levels=None,
+        random_chunk_kwargs={},
+        max_amplitude=1.5,
+        min_amplitude=0.5,
+        use_sparse_matrix_threshold=0.25,
+    ):
+
+        BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output, parents=parents)

-        templates_array = templates.get_dense_templates()
-        d["sparsities"] = {}
-        d["normed_templates"] = {}
-
-        for count, unit_id in enumerate(all_units):
-            (d["sparsities"][count],) = np.nonzero(sparsity[count])
-            d["norms"][count] = np.linalg.norm(templates_array[count])
-            templates_array[count] /= d["norms"][count]
-            d["normed_templates"][count] = templates_array[count][:, sparsity[count]]
-
-        templates_array = templates_array.reshape(num_templates, -1)
-
-        nnz = np.sum(templates_array != 0) / (num_templates * num_samples * num_channels)
-        if nnz <= use_sparse_matrix_threshold:
-            templates_array = scipy.sparse.csr_matrix(templates_array)
-            print(f"Templates are automatically sparsified (sparsity level is {nnz})")
-            d["is_dense"] = False
-        else:
-            d["is_dense"] = True
-
-        d["circus_templates"] = templates_array
-
-        return d
-
-    # @classmethod
-    # def _mcc_error(cls, bounds, good, bad):
-    #     fn = np.sum((good < bounds[0]) | (good > bounds[1]))
-    #     fp = np.sum((bounds[0] <= bad) & (bad <= bounds[1]))
-    #     tp = np.sum((bounds[0] <= good) & (good <= bounds[1]))
-    #     tn = np.sum((bad < bounds[0]) | (bad > bounds[1]))
-    #     denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
-    #     if denom > 0:
-    #         mcc = 1 - (tp * tn - fp * fn) / np.sqrt(denom)
-    #     else:
-    #         mcc = 1
-    #     return mcc
-
-    # @classmethod
-    # def _cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha):
-    #     # We want a minimal error, with the larger bounds that are possible
-    #     cost = alpha * cls._mcc_error(bounds, good, bad) + (1 - alpha) * np.abs(
-    #         (1 - (bounds[1] - bounds[0]) / delta_amplitude)
-    #     )
-    #     return cost
-
-    # @classmethod
-    # def _optimize_amplitudes(cls, noise_snippets, d):
-    #     parameters = d
-    #     waveform_extractor = parameters["waveform_extractor"]
-    #     templates = parameters["templates"]
-    #     num_templates = parameters["num_templates"]
-    #     max_amplitude = parameters["max_amplitude"]
-    #     min_amplitude = parameters["min_amplitude"]
-    #     alpha = 0.5
-    #     norms = parameters["norms"]
-    #     all_units = list(waveform_extractor.sorting.unit_ids)
-
-    #     parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32)
-    #     noise = templates.dot(noise_snippets) / norms[:, np.newaxis]
-
-    #     all_amps = {}
-    #     for count, unit_id in enumerate(all_units):
-    #         waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True)
-    #         snippets = waveform.reshape(waveform.shape[0], -1).T
-    #         amps = templates.dot(snippets) / norms[:, np.newaxis]
-    #         good = amps[count, :].flatten()
-
-    #         sub_amps = amps[np.concatenate((np.arange(count), np.arange(count + 1, num_templates))), :]
-    #         bad = sub_amps[sub_amps >= good]
-    #         bad = np.concatenate((bad, noise[count]))
-    #         cost_kwargs = [good, bad, max_amplitude - min_amplitude, alpha]
-    #         cost_bounds = [(min_amplitude, 1), (1, max_amplitude)]
-    #         res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs)
-    #         parameters["amplitudes"][count] = res.x
-
-    #     return d
-
-    @classmethod
-    def initialize_and_check_kwargs(cls, recording, kwargs):
         try:
             from sklearn.feature_extraction.image import extract_patches_2d
@@ -636,108 +545,93 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
             HAVE_SKLEARN = True
         except ImportError:
             HAVE_SKLEARN = False

         assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work"
-        d = cls._default_params.copy()
-        d.update(kwargs)

-        # assert isinstance(d['waveform_extractor'], WaveformExtractor)
-        for v in ["use_sparse_matrix_threshold"]:
-            assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]"
+        assert (use_sparse_matrix_threshold >= 0) and (
+            use_sparse_matrix_threshold <= 1
+        ), "use_sparse_matrix_threshold should be in [0, 1]"

-        d["num_channels"] = recording.get_num_channels()
-        d["num_samples"] = d["templates"].num_samples
-        d["num_templates"] = len(d["templates"].unit_ids)
+        self.num_channels = recording.get_num_channels()
+        self.num_samples = templates.num_samples
+        self.num_templates = len(templates.unit_ids)

-        if d["noise_levels"] is None:
+        if noise_levels is None:
             print("CircusPeeler : noise should be computed outside")
-            d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False)
+            noise_levels = get_noise_levels(recording, **random_chunk_kwargs, return_scaled=False)

-        d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"]
+        self.abs_threholds = noise_levels * detect_threshold

-        if "overlaps" not in d:
-            d = cls._prepare_templates(d)
-            d["overlaps"] = compute_overlaps(
-                d["normed_templates"],
-                d["num_samples"],
-                d["num_channels"],
-                d["sparsities"],
-            )
+        self.use_sparse_matrix_threshold = use_sparse_matrix_threshold
+        self._prepare_templates()
+        self.overlaps = compute_overlaps(
+            self.normed_templates,
+            self.num_samples,
+            self.num_channels,
+            self.sparsities,
+        )
+
+
self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) + + self.nbefore = templates.nbefore + self.nafter = templates.nafter + self.patch_sizes = (templates.num_samples, self.num_channels) + self.sym_patch = self.nbefore == self.nafter + self.jitter = int(jitter_ms * recording.get_sampling_frequency() / 1000.0) + + self.amplitudes = np.zeros((self.num_templates, 2), dtype=np.float32) + self.amplitudes[:, 0] = min_amplitude + self.amplitudes[:, 1] = max_amplitude + + self.margin = max(self.nbefore, self.nafter) * 2 + self.peak_sign = peak_sign + + def _prepare_templates(self): + import scipy.spatial + import scipy + + self.norms = np.zeros(self.num_templates, dtype=np.float32) + + all_units = self.templates.unit_ids + + sparsity = self.templates.sparsity.mask + + templates_array = self.templates.get_dense_templates() + self.sparsities = {} + self.normed_templates = {} + + for count, unit_id in enumerate(all_units): + self.sparsities[count] = np.flatnonzero(sparsity[count]) + self.norms[count] = np.linalg.norm(templates_array[count]) + templates_array[count] /= self.norms[count] + self.normed_templates[count] = templates_array[count][:, sparsity[count]] + + templates_array = templates_array.reshape(self.num_templates, -1) + + nnz = np.sum(templates_array != 0) / (self.num_templates * self.num_samples * self.num_channels) + if nnz <= self.use_sparse_matrix_threshold: + templates_array = scipy.sparse.csr_matrix(templates_array) + print(f"Templates are automatically sparsified (sparsity level is {nnz})") + self.is_dense = False else: - for key in ["circus_templates", "norms"]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key + self.is_dense = True - d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) + self.circus_templates = templates_array - d["nbefore"] = d["templates"].nbefore - d["nafter"] = d["templates"].nafter - d["patch_sizes"] = ( - d["templates"].num_samples, - d["num_channels"], - ) - d["sym_patch"] = d["nbefore"] == d["nafter"] - d["jitter"] = int(d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0) - - d["amplitudes"] = np.zeros((d["num_templates"], 2), dtype=np.float32) - d["amplitudes"][:, 0] = d["min_amplitude"] - d["amplitudes"][:, 1] = d["max_amplitude"] - # num_segments = recording.get_num_segments() - # if d["waveform_extractor"]._params["max_spikes_per_unit"] is None: - # num_snippets = 1000 - # else: - # num_snippets = 2 * d["waveform_extractor"]._params["max_spikes_per_unit"] - - # num_chunks = num_snippets // num_segments - # noise_snippets = get_random_data_chunks( - # recording, num_chunks_per_segment=num_chunks, chunk_size=d["num_samples"], seed=42 - # ) - # noise_snippets = ( - # noise_snippets.reshape(num_chunks, d["num_samples"], d["num_channels"]) - # .reshape(num_chunks, -1) - # .T - # ) - # parameters = cls._optimize_amplitudes(noise_snippets, d) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) - return margin - - @classmethod - def main_function(cls, traces, d): - peak_sign = d["peak_sign"] - abs_threholds = d["abs_threholds"] - exclude_sweep_size = d["exclude_sweep_size"] - templates = d["circus_templates"] - num_templates = d["num_templates"] - overlaps = d["overlaps"] - 
margin = d["margin"] - norms = d["norms"] - jitter = d["jitter"] - patch_sizes = d["patch_sizes"] - num_samples = d["nafter"] + d["nbefore"] - neighbor_window = num_samples - 1 - amplitudes = d["amplitudes"] - sym_patch = d["sym_patch"] + def get_trace_margin(self): + return self.margin + + def compute_matching(self, traces, start_frame, end_frame, segment_index): + + neighbor_window = self.num_samples - 1 - peak_traces = traces[margin // 2 : -margin // 2, :] + peak_traces = traces[self.margin // 2 : -self.margin // 2, :] peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( - peak_traces, peak_sign, abs_threholds, exclude_sweep_size + peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size ) from sklearn.feature_extraction.image import extract_patches_2d - if jitter > 0: - jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) - jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * jitter) + if self.jitter > 0: + jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-self.jitter, self.jitter) + jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * self.jitter) mask = (jittered_peaks > 0) & (jittered_peaks < len(peak_traces)) jittered_peaks = jittered_peaks[mask] jittered_channels = jittered_channels[mask] @@ -749,26 +643,26 @@ def main_function(cls, traces, d): num_peaks = len(peak_sample_index) - if sym_patch: - snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_index] - peak_sample_index += margin // 2 + if self.sym_patch: + snippets = extract_patches_2d(traces, self.patch_sizes)[peak_sample_index] + peak_sample_index += self.margin // 2 else: - peak_sample_index += margin // 2 - snippet_window = np.arange(-d["nbefore"], d["nafter"]) + peak_sample_index += self.margin // 2 + snippet_window = np.arange(-self.nbefore, self.nafter) snippets = traces[peak_sample_index[:, np.newaxis] + snippet_window] if num_peaks > 0: snippets = snippets.reshape(num_peaks, -1) - scalar_products = templates.dot(snippets.T) + scalar_products = self.circus_templates.dot(snippets.T) else: - scalar_products = np.zeros((num_templates, 0), dtype=np.float32) + scalar_products = np.zeros((self.num_templates, 0), dtype=np.float32) num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + idx_lookup = np.arange(scalar_products.size).reshape(self.num_templates, -1) - min_sps = (amplitudes[:, 0] * norms)[:, np.newaxis] - max_sps = (amplitudes[:, 1] * norms)[:, np.newaxis] + min_sps = (self.amplitudes[:, 0] * self.norms)[:, np.newaxis] + max_sps = (self.amplitudes[:, 1] * self.norms)[:, np.newaxis] is_valid = (scalar_products > min_sps) & (scalar_products < max_sps) @@ -787,7 +681,7 @@ def main_function(cls, traces, d): idx_neighbor = peak_data[is_valid_nn[0] : is_valid_nn[1]] + neighbor_window if not best_cluster_ind in cached_overlaps.keys(): - cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() + cached_overlaps[best_cluster_ind] = self.overlaps[best_cluster_ind].toarray() to_add = -best_amplitude * cached_overlaps[best_cluster_ind][:, idx_neighbor] @@ -802,7 +696,7 @@ def main_function(cls, traces, d): is_valid = (scalar_products > min_sps) & (scalar_products < max_sps) - spikes["amplitude"][:num_spikes] /= norms[spikes["cluster_index"][:num_spikes]] + spikes["amplitude"][:num_spikes] /= self.norms[spikes["cluster_index"][:num_spikes]] spikes = spikes[:num_spikes] order = np.argsort(spikes["sample_index"]) diff 
--git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py
index 6e5267cb70..f423d55e2a 100644
--- a/src/spikeinterface/sortingcomponents/matching/main.py
+++ b/src/spikeinterface/sortingcomponents/matching/main.py
@@ -3,8 +3,11 @@
 from threadpoolctl import threadpool_limits
 import numpy as np

-from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs
-from spikeinterface.core import get_chunk_with_margin
+from spikeinterface.core.job_tools import fix_job_kwargs
+from spikeinterface.core.node_pipeline import run_node_pipeline


 def find_spikes_from_templates(
@@ -21,7 +24,7 @@
     method_kwargs : dict, optional
         Keyword arguments for the chosen method
     extra_outputs : bool
-        If True then method_kwargs is also returned
+        If True, a dict with the method's extra outputs is also returned
     **job_kwargs : dict
         Parameters for ChunkRecordingExecutor
     verbose : Bool, default: False
@@ -31,9 +34,8 @@
     -------
     spikes : ndarray
         Spikes found from templates.
-    method_kwargs:
+    outputs:
         Optionally returned when extra_outputs=True, for debugging purposes.
-
     """
     from .method_list import matching_methods
@@ -42,117 +44,19 @@
     job_kwargs = fix_job_kwargs(job_kwargs)

     method_class = matching_methods[method]
+    node0 = method_class(recording, **method_kwargs)
+    nodes = [node0]

-    # initialize
-    method_kwargs = method_class.initialize_and_check_kwargs(recording, method_kwargs)
-
-    # add
-    method_kwargs["margin"] = method_class.get_margin(recording, method_kwargs)
-
-    # serialiaze for worker
-    method_kwargs_seralized = method_class.serialize_method_kwargs(method_kwargs)
-
-    # and run
-    func = _find_spikes_chunk
-    init_func = _init_worker_find_spikes
-    init_args = (recording, method, method_kwargs_seralized)
-    processor = ChunkRecordingExecutor(
+    spikes = run_node_pipeline(
         recording,
-        func,
-        init_func,
-        init_args,
-        handle_returns=True,
+        nodes,
+        job_kwargs,
         job_name=f"find spikes ({method})",
-        verbose=verbose,
-        **job_kwargs,
+        gather_mode="memory",
+        squeeze_output=True,
     )
-    spikes = processor.run()
-
-    spikes = np.concatenate(spikes)

     if extra_outputs:
-        return spikes, method_kwargs
+        outputs = node0.get_extra_outputs()
+        return spikes, outputs
     else:
         return spikes
-
-
-def _init_worker_find_spikes(recording, method, method_kwargs):
-    """Initialize worker for finding spikes."""
-
-    from .method_list import matching_methods
-
-    method_class = matching_methods[method]
-    method_kwargs = method_class.unserialize_in_worker(method_kwargs)
-
-    # create a local dict per worker
-    worker_ctx = {}
-    worker_ctx["recording"] = recording
-    worker_ctx["method"] = method
-    worker_ctx["method_kwargs"] = method_kwargs
-    worker_ctx["function"] = method_class.main_function
-
-    return worker_ctx
-
-
-def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx):
-    """Find spikes from a chunk of data."""
-
-    # recover variables of the worker
-    recording = worker_ctx["recording"]
-    method = worker_ctx["method"]
-    method_kwargs = worker_ctx["method_kwargs"]
-    margin = method_kwargs["margin"]
-
-    # load trace in memory given some margin
-    recording_segment = recording._recording_segments[segment_index]
-    traces, left_margin, right_margin = get_chunk_with_margin(
-        recording_segment, start_frame, end_frame, None, margin, add_zeros=True
-    )
-
-    function =
worker_ctx["function"] - - with threadpool_limits(limits=1): - spikes = function(traces, method_kwargs) - - # remove spikes in margin - if margin > 0: - keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin)) - spikes = spikes[keep] - - spikes["sample_index"] += start_frame - margin - spikes["segment_index"] = segment_index - return spikes - - -# generic class for template engine -class BaseTemplateMatchingEngine: - default_params = {} - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - """This function runs before loops""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def serialize_method_kwargs(cls, kwargs): - """This function serializes kwargs to distribute them to workers""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def unserialize_in_worker(cls, recording, kwargs): - """This function unserializes kwargs in workers""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def get_margin(cls, recording, kwargs): - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def main_function(cls, traces, method_kwargs): - """This function returns the number of samples for the chunk margins""" - # need to be implemented in subclass - raise NotImplementedError diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 0dc71d789b..26f093c187 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,115 +4,68 @@ import numpy as np -from spikeinterface.core import get_noise_levels, get_channel_distances, get_random_data_chunks +from spikeinterface.core import get_noise_levels, get_channel_distances from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive -from spikeinterface.core.template import Templates - -spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] - - -from .main import BaseTemplateMatchingEngine - - -class NaiveMatching(BaseTemplateMatchingEngine): - """ - This is a naive template matching that does not resolve collision - and does not take in account sparsity. - It just minimizes the distance to templates for detected peaks. - - It is implemented for benchmarking against this low quality template matching. - And also as an example how to deal with methods_kwargs, margin, intit, func, ... 
- """ - - default_params = { - "templates": None, - "peak_sign": "neg", - "exclude_sweep_ms": 0.1, - "detect_threshold": 5, - "noise_levels": None, - "radius_um": 100, - "random_chunk_kwargs": {}, - } - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls.default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) - - templates = d["templates"] - if d["noise_levels"] is None: - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] - - channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["radius_um"] +from .base import BaseTemplateMatching, _base_matching_dtype - d["nbefore"] = templates.nbefore - d["nafter"] = templates.nafter - d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) +class NaiveMatching(BaseTemplateMatching): + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + exclude_sweep_ms=0.1, + detect_threshold=5, + noise_levels=None, + radius_um=100.0, + random_chunk_kwargs={}, + ): - return d + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - @classmethod - def get_margin(cls, recording, kwargs): - margin = max(kwargs["nbefore"], kwargs["nafter"]) - return margin + self.templates_array = self.templates.get_dense_templates() - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def main_function(cls, traces, method_kwargs): - peak_sign = method_kwargs["peak_sign"] - abs_threholds = method_kwargs["abs_threholds"] - exclude_sweep_size = method_kwargs["exclude_sweep_size"] - neighbours_mask = method_kwargs["neighbours_mask"] - templates_array = method_kwargs["templates"].get_dense_templates() + if noise_levels is None: + noise_levels = get_noise_levels(recording, **random_chunk_kwargs, return_scaled=False) + self.abs_threholds = noise_levels * detect_threshold + self.peak_sign = peak_sign + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) + self.nbefore = self.templates.nbefore + self.nafter = self.templates.nafter + self.margin = max(self.nbefore, self.nafter) - nbefore = method_kwargs["nbefore"] - nafter = method_kwargs["nafter"] + def get_trace_margin(self): + return self.margin - margin = method_kwargs["margin"] + def compute_matching(self, traces, start_frame, end_frame, segment_index): - if margin > 0: - peak_traces = traces[margin:-margin, :] + if self.margin > 0: + peak_traces = traces[self.margin : -self.margin, :] else: peak_traces = traces peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, peak_sign, abs_threholds, exclude_sweep_size, neighbours_mask + peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += margin + peak_sample_ind += self.margin - spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) + spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind - 
spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template + spikes["channel_index"] = peak_chan_ind # naively take the closest template for i in range(peak_sample_ind.size): - i0 = peak_sample_ind[i] - nbefore - i1 = peak_sample_ind[i] + nafter + i0 = peak_sample_ind[i] - self.nbefore + i1 = peak_sample_ind[i] + self.nafter waveforms = traces[i0:i1, :] - dist = np.sum(np.sum((templates_array - waveforms[None, :, :]) ** 2, axis=1), axis=1) + dist = np.sum(np.sum((self.templates_array - waveforms[None, :, :]) ** 2, axis=1), axis=1) cluster_index = np.argmin(dist) spikes["cluster_index"][i] = cluster_index diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index e66929e2b1..56457fe2fa 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -11,15 +11,8 @@ from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.core.template import Templates -spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] +from .base import BaseTemplateMatching, _base_matching_dtype -from .main import BaseTemplateMatchingEngine try: import numba @@ -30,7 +23,7 @@ HAVE_NUMBA = False -class TridesclousPeeler(BaseTemplateMatchingEngine): +class TridesclousPeeler(BaseTemplateMatching): """ Template-matching ported from Tridesclous sorter. @@ -45,87 +38,73 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): spike collision when templates have high similarity. """ - default_params = { - "templates": None, - "peak_sign": "neg", - "peak_shift_ms": 0.2, - "detect_threshold": 5, - "noise_levels": None, - "radius_um": 100, - "num_closest": 5, - "sample_shift": 3, - "ms_before": 0.8, - "ms_after": 1.2, - "num_peeler_loop": 2, - "num_template_try": 1, - } - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - assert HAVE_NUMBA, "TridesclousPeeler needs numba to be installed" - - d = cls.default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + peak_shift_ms=0.2, + detect_threshold=5, + noise_levels=None, + radius_um=100.0, + num_closest=5, + sample_shift=3, + ms_before=0.8, + ms_after=1.2, + num_peeler_loop=2, + num_template_try=1, + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) + + # maybe in base? 
diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index e66929e2b1..56457fe2fa 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -11,15 +11,8 @@ from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.core.template import Templates -spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] +from .base import BaseTemplateMatching, _base_matching_dtype -from .main import BaseTemplateMatchingEngine try: import numba @@ -30,7 +23,7 @@ HAVE_NUMBA = False -class TridesclousPeeler(BaseTemplateMatchingEngine): +class TridesclousPeeler(BaseTemplateMatching): """ Template-matching ported from Tridesclous sorter. @@ -45,87 +38,73 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): spike collision when templates have high similarity. """ - default_params = { - "templates": None, - "peak_sign": "neg", - "peak_shift_ms": 0.2, - "detect_threshold": 5, - "noise_levels": None, - "radius_um": 100, - "num_closest": 5, - "sample_shift": 3, - "ms_before": 0.8, - "ms_after": 1.2, - "num_peeler_loop": 2, - "num_template_try": 1, - } - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - assert HAVE_NUMBA, "TridesclousPeeler needs numba to be installed" - - d = cls.default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + peak_shift_ms=0.2, + detect_threshold=5, + noise_levels=None, + radius_um=100.0, + num_closest=5, + sample_shift=3, + ms_before=0.8, + ms_after=1.2, + num_peeler_loop=2, + num_template_try=1, + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output, parents=parents) + + assert HAVE_NUMBA, "TridesclousPeeler needs numba to be installed" + + # TODO: consider moving the dense templates cache to the base class + self.templates_array = templates.get_dense_templates() - templates = d["templates"] unit_ids = templates.unit_ids - channel_ids = templates.channel_ids + channel_ids = recording.channel_ids + + sr = recording.sampling_frequency - sr = templates.sampling_frequency + self.nbefore = templates.nbefore + self.nafter = templates.nafter - d["nbefore"] = templates.nbefore - d["nafter"] = templates.nafter - templates_array = templates.get_dense_templates() + self.peak_sign = peak_sign - nbefore_short = int(d["ms_before"] * sr / 1000.0) - nafter_short = int(d["ms_before"] * sr / 1000.0) + nbefore_short = int(ms_before * sr / 1000.0) + nafter_short = int(ms_after * sr / 1000.0) assert nbefore_short <= templates.nbefore assert nafter_short <= templates.nafter - d["nbefore_short"] = nbefore_short - d["nafter_short"] = nafter_short + self.nbefore_short = nbefore_short + self.nafter_short = nafter_short s0 = templates.nbefore - nbefore_short s1 = -(templates.nafter - nafter_short) if s1 == 0: s1 = None - templates_short = templates_array[:, slice(s0, s1), :].copy() - d["templates_short"] = templates_short + # TODO: check whether the copy can be avoided + self.templates_short = self.templates_array[:, slice(s0, s1), :].copy() - d["peak_shift"] = int(d["peak_shift_ms"] / 1000 * sr) + self.peak_shift = int(peak_shift_ms / 1000 * sr) - if d["noise_levels"] is None: - print("TridesclousPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording) + assert noise_levels is not None, "TridesclousPeeler: noise_levels must be computed externally and provided" - d["abs_thresholds"] = d["noise_levels"] * d["detect_threshold"] + self.abs_thresholds = noise_levels * detect_threshold channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["radius_um"] - - sparsity = compute_sparsity( - templates, method="best_channels" - ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) - template_sparsity_inds = sparsity.unit_id_to_channel_indices - template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - for unit_index, unit_id in enumerate(unit_ids): - chan_inds = template_sparsity_inds[unit_id] - template_sparsity[unit_index, chan_inds] = True + self.neighbours_mask = channel_distance < radius_um - d["template_sparsity"] = template_sparsity + if templates.sparsity is not None: + self.template_sparsity = templates.sparsity.mask + else: + self.template_sparsity = np.ones((unit_ids.size, channel_ids.size), dtype=bool) - extremum_channel = get_template_extremum_channel(templates, peak_sign=d["peak_sign"], outputs="index") + extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") # as numpy vector - extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") - d["extremum_channel"] = extremum_channel + self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") channel_locations = templates.probe.contact_positions - - # TODO try it with real locaion - unit_locations = channel_locations[extremum_channel] - # ~ print(unit_locations) + unit_locations = channel_locations[self.extremum_channel] # distance between units import scipy @@ -138,15 +117,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): order = np.argsort(unit_distances[unit_ind, :]) closest_u = np.arange(unit_ids.size)[order].tolist() closest_u.remove(unit_ind) - closest_u = np.array(closest_u[: d["num_closest"]]) + closest_u = np.array(closest_u[:num_closest]) # compute unitary
discriminent vector - (chans,) = np.nonzero(d["template_sparsity"][unit_ind, :]) - template_sparse = templates_array[unit_ind, :, :][:, chans] + (chans,) = np.nonzero(self.template_sparsity[unit_ind, :]) + template_sparse = self.templates_array[unit_ind, :, :][:, chans] closest_vec = [] # against N closets for u in closest_u: - vec = templates_array[u, :, :][:, chans] - template_sparse + vec = self.templates_array[u, :, :][:, chans] - template_sparse vec /= np.sum(vec**2) closest_vec.append((u, vec)) # against noise @@ -154,47 +133,38 @@ def initialize_and_check_kwargs(cls, recording, kwargs): closest_units.append(closest_vec) - d["closest_units"] = closest_units + self.closest_units = closest_units # distance channel from unit import scipy distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") - near_cluster_mask = distances < d["radius_um"] + near_cluster_mask = distances < radius_um # nearby cluster for each channel - possible_clusters_by_channel = [] + self.possible_clusters_by_channel = [] for channel_index in range(distances.shape[0]): (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) - possible_clusters_by_channel.append(cluster_inds) + self.possible_clusters_by_channel.append(cluster_inds) - d["possible_clusters_by_channel"] = possible_clusters_by_channel - d["possible_shifts"] = np.arange(-d["sample_shift"], d["sample_shift"] + 1, dtype="int64") + self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") - return d + self.num_peeler_loop = num_peeler_loop + self.num_template_try = num_template_try - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs + self.margin = max(self.nbefore, self.nafter) * 2 - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs + def get_trace_margin(self): + return self.margin - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * (kwargs["nbefore"] + kwargs["nafter"]) - return margin - - @classmethod - def main_function(cls, traces, d): + def compute_matching(self, traces, start_frame, end_frame, segment_index): traces = traces.copy() all_spikes = [] level = 0 while True: - spikes = _tdc_find_spikes(traces, d, level=level) + # spikes = _tdc_find_spikes(traces, d, level=level) + spikes = self._find_spikes_one_level(traces, level=level) keep = spikes["cluster_index"] >= 0 if not np.any(keep): @@ -203,7 +173,7 @@ def main_function(cls, traces, d): level += 1 - if level == d["num_peeler_loop"]: + if level == self.num_peeler_loop: break if len(all_spikes) > 0: @@ -211,139 +181,131 @@ def main_function(cls, traces, d): order = np.argsort(all_spikes["sample_index"]) all_spikes = all_spikes[order] else: - all_spikes = np.zeros(0, dtype=spike_dtype) + all_spikes = np.zeros(0, dtype=_base_matching_dtype) return all_spikes + def _find_spikes_one_level(self, traces, level=0): -def _tdc_find_spikes(traces, d, level=0): - peak_sign = d["peak_sign"] - templates = d["templates"] - templates_short = d["templates_short"] - templates_array = templates.get_dense_templates() - - margin = d["margin"] - possible_clusters_by_channel = d["possible_clusters_by_channel"] - - peak_traces = traces[margin // 2 : -margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, peak_sign, d["abs_thresholds"], d["peak_shift"], d["neighbours_mask"] - ) - peak_sample_ind += margin // 2 - - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] - order = 
np.argsort(np.abs(peak_amplitude))[::-1] - peak_sample_ind = peak_sample_ind[order] - peak_chan_ind = peak_chan_ind[order] - - spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) - spikes["sample_index"] = peak_sample_ind - spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template - - possible_shifts = d["possible_shifts"] - distances_shift = np.zeros(possible_shifts.size) - - for i in range(peak_sample_ind.size): - sample_index = peak_sample_ind[i] - - chan_ind = peak_chan_ind[i] - possible_clusters = possible_clusters_by_channel[chan_ind] - - if possible_clusters.size > 0: - # ~ s0 = sample_index - d['nbefore'] - # ~ s1 = sample_index + d['nafter'] - - # ~ wf = traces[s0:s1, :] - - s0 = sample_index - d["nbefore_short"] - s1 = sample_index + d["nafter_short"] - wf_short = traces[s0:s1, :] - - ## pure numpy with cluster spasity - # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) - - ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) - # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) - - ## numba with cluster+channel spasity - union_channels = np.any(d["template_sparsity"][possible_clusters, :], axis=0) - # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) - distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) - - # DEBUG - # ~ ind = np.argmin(distances) - # ~ cluster_index = possible_clusters[ind] - - for ind in np.argsort(distances)[: d["num_template_try"]]: - cluster_index = possible_clusters[ind] - - chan_sparsity = d["template_sparsity"][cluster_index, :] - template_sparse = templates_array[cluster_index, :, :][:, chan_sparsity] - - # find best shift - - ## pure numpy version - # for s, shift in enumerate(possible_shifts): - # wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity] - # distances_shift[s] = np.sum((template_sparse - wf_shift)**2) - # ind_shift = np.argmin(distances_shift) - # shift = possible_shifts[ind_shift] - - ## numba version - numba_best_shift( - traces, - templates_array[cluster_index, :, :], - sample_index, - d["nbefore"], - possible_shifts, - distances_shift, - chan_sparsity, - ) - ind_shift = np.argmin(distances_shift) - shift = possible_shifts[ind_shift] - - sample_index = sample_index + shift - s0 = sample_index - d["nbefore"] - s1 = sample_index + d["nafter"] - wf_sparse = traces[s0:s1, chan_sparsity] - - # accept or not - - centered = wf_sparse - template_sparse - accepted = True - for other_ind, other_vector in d["closest_units"][cluster_index]: - v = np.sum(centered * other_vector) - if np.abs(v) > 0.5: - accepted = False + peak_traces = traces[self.margin // 2 : -self.margin // 2, :] + peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( + peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask + ) + peak_sample_ind += self.margin // 2 + + peak_amplitude = traces[peak_sample_ind, peak_chan_ind] + order = np.argsort(np.abs(peak_amplitude))[::-1] + peak_sample_ind = peak_sample_ind[order] + peak_chan_ind = peak_chan_ind[order] + + spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) + spikes["sample_index"] = peak_sample_ind + spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template + + possible_shifts = self.possible_shifts + 
distances_shift = np.zeros(possible_shifts.size) + + for i in range(peak_sample_ind.size): + sample_index = peak_sample_ind[i] + + chan_ind = peak_chan_ind[i] + possible_clusters = self.possible_clusters_by_channel[chan_ind] + + if possible_clusters.size > 0: + # ~ s0 = sample_index - d['nbefore'] + # ~ s1 = sample_index + d['nafter'] + + # ~ wf = traces[s0:s1, :] + + s0 = sample_index - self.nbefore_short + s1 = sample_index + self.nafter_short + wf_short = traces[s0:s1, :] + + ## pure numpy with cluster spasity + # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) + + ## pure numpy with cluster+channel spasity + # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) + # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) + + ## numba with cluster+channel spasity + union_channels = np.any(self.template_sparsity[possible_clusters, :], axis=0) + # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) + distances = numba_sparse_dist(wf_short, self.templates_short, union_channels, possible_clusters) + + # DEBUG + # ~ ind = np.argmin(distances) + # ~ cluster_index = possible_clusters[ind] + + for ind in np.argsort(distances)[: self.num_template_try]: + cluster_index = possible_clusters[ind] + + chan_sparsity = self.template_sparsity[cluster_index, :] + template_sparse = self.templates_array[cluster_index, :, :][:, chan_sparsity] + + # find best shift + + ## pure numpy version + # for s, shift in enumerate(possible_shifts): + # wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity] + # distances_shift[s] = np.sum((template_sparse - wf_shift)**2) + # ind_shift = np.argmin(distances_shift) + # shift = possible_shifts[ind_shift] + + ## numba version + numba_best_shift( + traces, + self.templates_array[cluster_index, :, :], + sample_index, + self.nbefore, + possible_shifts, + distances_shift, + chan_sparsity, + ) + ind_shift = np.argmin(distances_shift) + shift = possible_shifts[ind_shift] + + sample_index = sample_index + shift + s0 = sample_index - self.nbefore + s1 = sample_index + self.nafter + wf_sparse = traces[s0:s1, chan_sparsity] + + # accept or not + + centered = wf_sparse - template_sparse + accepted = True + for other_ind, other_vector in self.closest_units[cluster_index]: + v = np.sum(centered * other_vector) + if np.abs(v) > 0.5: + accepted = False + break + + if accepted: + # ~ if ind != np.argsort(distances)[0]: + # ~ print('not first one', np.argsort(distances), ind) break if accepted: - # ~ if ind != np.argsort(distances)[0]: - # ~ print('not first one', np.argsort(distances), ind) - break + amplitude = 1.0 - if accepted: - amplitude = 1.0 + # remove template + template = self.templates_array[cluster_index, :, :] + s0 = sample_index - self.nbefore + s1 = sample_index + self.nafter + traces[s0:s1, :] -= template * amplitude - # remove template - template = templates_array[cluster_index, :, :] - s0 = sample_index - d["nbefore"] - s1 = sample_index + d["nafter"] - traces[s0:s1, :] -= template * amplitude + else: + cluster_index = -1 + amplitude = 0.0 else: cluster_index = -1 amplitude = 0.0 - else: - cluster_index = -1 - amplitude = 0.0 - - spikes["cluster_index"][i] = cluster_index - spikes["amplitude"][i] = amplitude + spikes["cluster_index"][i] = cluster_index + spikes["amplitude"][i] = amplitude - return spikes + return spikes if HAVE_NUMBA: diff --git 
a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 99de6fcd4e..2531a922da 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -4,7 +4,8 @@ from dataclasses import dataclass from typing import List, Tuple, Optional -from .main import BaseTemplateMatchingEngine + +from .base import BaseTemplateMatching, _base_matching_dtype from spikeinterface.core.template import Templates @@ -197,8 +198,9 @@ def from_parameters_and_templates(cls, params, templates): return template_meta +# important: this is different from the Sparsity in spikeinterface.core @dataclass -class Sparsity: +class WobbleSparsity: """Variables that describe channel sparsity. Parameters @@ -226,7 +228,7 @@ def from_parameters_and_templates(cls, params, templates): Returns ------- - sparsity : Sparsity + sparsity : WobbleSparsity Dataclass object for aggregating channel sparsity variables together. """ visible_channels = np.ptp(templates, axis=1) > params.visibility_threshold @@ -250,7 +252,7 @@ def from_templates(cls, params, templates): Returns ------- - sparsity : Sparsity + sparsity : WobbleSparsity Dataclass object for aggregating channel sparsity variables together. """ visible_channels = templates.sparsity.mask @@ -297,7 +299,7 @@ def __post_init__(self): self.temporal, self.singular, self.spatial, self.temporal_jittered = self.compressed_templates -class WobbleMatch(BaseTemplateMatchingEngine): +class WobbleMatch(BaseTemplateMatching): """Template matching method from the Paninski lab. Templates are jittered or "wobbled" in time and amplitude to capture variability in spike amplitude and @@ -331,53 +333,30 @@ class WobbleMatch(BaseTemplateMatchingEngine): - "peaks" are considered spikes if their amplitude clears the threshold parameter """ - default_params = { - "templates": None, - } - spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), - ] - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - """Initialize the objective and precompute various useful objects. - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object. - kwargs : dict - Keyword arguments for matching method. - - Returns - ------- - d : dict - Updated Keyword arguments.
- """ - d = cls.default_params.copy() + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + parameters={}, + ): - required_kwargs_keys = ["templates"] - for required_key in required_kwargs_keys: - assert required_key in kwargs, f"`{required_key}` is a required key in the kwargs" + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - parameters = kwargs.get("parameters", {}) - templates = kwargs["templates"] - assert isinstance(templates, Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) templates_array = templates.get_dense_templates().astype(np.float32, casting="safe") # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) template_meta = TemplateMetadata.from_parameters_and_templates(params, templates_array) if not templates.are_templates_sparse(): - sparsity = Sparsity.from_parameters_and_templates(params, templates_array) + sparsity = WobbleSparsity.from_parameters_and_templates(params, templates_array) else: - sparsity = Sparsity.from_templates(params, templates) + sparsity = WobbleSparsity.from_templates(params, templates) # Perform initial computations on templates necessary for computing the objective sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates_array, 0) @@ -394,84 +373,47 @@ def initialize_and_check_kwargs(cls, recording, kwargs): norm_squared=norm_squared, ) - # Pack initial data into kwargs - kwargs["params"] = params - kwargs["template_meta"] = template_meta - kwargs["sparsity"] = sparsity - kwargs["template_data"] = template_data - kwargs["nbefore"] = templates.nbefore - kwargs["nafter"] = templates.nafter - d.update(kwargs) - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - # This function does nothing without a waveform extractor -- candidate for refactor - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - # This function does nothing without a waveform extractor -- candidate for refactor - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - """Get margin for chunking recording. + self.params = params + self.template_meta = template_meta + self.sparsity = sparsity + self.template_data = template_data + self.nbefore = templates.nbefore + self.nafter = templates.nafter - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object. - kwargs : dict - Keyword arguments for matching method. - - Returns - ------- - margin : int - Buffer in samples on each side of a chunk. - """ - buffer_ms = 10 - # margin = int(buffer_ms*1e-3 * recording.sampling_frequency) - margin = 300 # To ensure equivalence with spike-psvae version of the algorithm - return margin + # buffer_ms = 10 + # self.margin = int(buffer_ms*1e-3 * recording.sampling_frequency) + self.margin = 300 # To ensure equivalence with spike-psvae version of the algorithm - @classmethod - def main_function(cls, traces, method_kwargs): - """Detect spikes in traces using the template matching algorithm. + def get_trace_margin(self): + return self.margin - Parameters - ---------- - traces : ndarray (chunk_len + 2*margin, num_channels) - Voltage traces for a chunk of the recording. - method_kwargs : dict - Keyword arguments for matching method. 
+ def compute_matching(self, traces, start_frame, end_frame, segment_index): - Returns - ------- - spikes : ndarray (num_spikes,) - Resulting spike train. - """ # Unpack method_kwargs - nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] - template_meta = method_kwargs["template_meta"] - params = method_kwargs["params"] - sparsity = method_kwargs["sparsity"] - template_data = method_kwargs["template_data"] + # nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] + # template_meta = method_kwargs["template_meta"] + # params = method_kwargs["params"] + # sparsity = method_kwargs["sparsity"] + # template_data = method_kwargs["template_data"] # Check traces assert traces.dtype == np.float32, "traces must be specified as np.float32" # Compute objective - objective = compute_objective(traces, template_data, params.approx_rank) - objective_normalized = 2 * objective - template_data.norm_squared[:, np.newaxis] + objective = compute_objective(traces, self.template_data, self.params.approx_rank) + objective_normalized = 2 * objective - self.template_data.norm_squared[:, np.newaxis] # Compute spike train spike_trains, scalings, distance_metrics = [], [], [] - for i in range(params.max_iter): + for i in range(self.params.max_iter): # find peaks - spike_train, scaling, distance_metric = cls.find_peaks( - objective, objective_normalized, np.array(spike_trains), params, template_data, template_meta + spike_train, scaling, distance_metric = self.find_peaks( + objective, + objective_normalized, + np.array(spike_trains), + self.params, + self.template_data, + self.template_meta, ) if len(spike_train) == 0: break @@ -482,15 +424,22 @@ def main_function(cls, traces, method_kwargs): distance_metrics.extend(list(distance_metric)) # subtract newly detected spike train from traces (via the objective) - objective, objective_normalized = cls.subtract_spike_train( - spike_train, scaling, template_data, objective, objective_normalized, params, template_meta, sparsity + objective, objective_normalized = self.subtract_spike_train( + spike_train, + scaling, + self.template_data, + objective, + objective_normalized, + self.params, + self.template_meta, + self.sparsity, ) spike_train = np.array(spike_trains) scalings = np.array(scalings) distance_metric = np.array(distance_metrics) if len(spike_train) == 0: # no spikes found - return np.zeros(0, dtype=cls.spike_dtype) + return np.zeros(0, dtype=_base_matching_dtype) # order spike times index = np.argsort(spike_train[:, 0]) @@ -499,8 +448,8 @@ def main_function(cls, traces, method_kwargs): distance_metric = distance_metric[index] # adjust spike_train - spike_train[:, 0] += nbefore # beginning of template --> center of template - spike_train[:, 1] //= params.jitter_factor # jittered_index --> template_index + spike_train[:, 0] += self.nbefore # beginning of template --> center of template + spike_train[:, 1] //= self.params.jitter_factor # jittered_index --> template_index # TODO : Benchmark spike amplitudes # Find spike amplitudes / channels @@ -512,7 +461,7 @@ def main_function(cls, traces, method_kwargs): channel_inds.append(best_ch) # assign result to spikes array - spikes = np.zeros(spike_train.shape[0], dtype=cls.spike_dtype) + spikes = np.zeros(spike_train.shape[0], dtype=_base_matching_dtype) spikes["sample_index"] = spike_train[:, 0] spikes["cluster_index"] = spike_train[:, 1] spikes["channel_index"] = channel_inds @@ -622,7 +571,7 @@ def subtract_spike_train( Dataclass object for aggregating the parameters together. 
template_meta : TemplateMetadata Dataclass object for aggregating template metadata together. - sparsity : Sparsity + sparsity : WobbleSparsity Dataclass object for aggregating channel sparsity variables together. Returns diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 4fe90dd7bc..ad8897df91 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -631,47 +631,31 @@ def __init__( self.conv_margin = prototype.shape[0] assert peak_sign in ("both", "neg", "pos") - idx = np.argmax(np.abs(prototype)) + self.nbefore = int(ms_before * recording.sampling_frequency / 1000) if peak_sign == "neg": - assert prototype[idx] < 0, "Prototype should have a negative peak" + assert prototype[self.nbefore] < 0, "Prototype should have a negative peak" peak_sign = "pos" elif peak_sign == "pos": - assert prototype[idx] > 0, "Prototype should have a positive peak" - elif peak_sign == "both": - raise NotImplementedError("Matched filtering not working with peak_sign=both yet!") + assert prototype[self.nbefore] > 0, "Prototype should have a positive peak" self.peak_sign = peak_sign - self.nbefore = int(ms_before * recording.sampling_frequency / 1000) + self.prototype = np.flip(prototype) / np.linalg.norm(prototype) + contact_locations = recording.get_channel_locations() dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2) - weights, self.z_factors = get_convolution_weights(dist, **weight_method) + self.weights, self.z_factors = get_convolution_weights(dist, **weight_method) + self.num_z_factors = len(self.z_factors) + self.num_channels = recording.get_num_channels() + self.num_templates = self.num_channels + if peak_sign == "both": + self.weights = np.hstack((self.weights, self.weights)) + self.weights[:, self.num_templates :, :] *= -1 + self.num_templates *= 2 - num_channels = recording.get_num_channels() - num_templates = num_channels * len(self.z_factors) - weights = weights.reshape(num_templates, -1) - - templates = weights[:, None, :] * prototype[None, :, None] - templates -= templates.mean(axis=(1, 2))[:, None, None] - temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) - temporal = temporal[:, :, :rank] - singular = singular[:, :rank] - spatial = spatial[:, :rank, :] - templates = np.matmul(temporal * singular[:, np.newaxis, :], spatial) - norms = np.linalg.norm(templates, axis=(1, 2)) - del templates - - temporal /= norms[:, np.newaxis, np.newaxis] - temporal = np.flip(temporal, axis=1) - spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2]) - temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0]) - singular = singular.T[:, :, np.newaxis] - - self.temporal = temporal - self.spatial = spatial - self.singular = singular + self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) - conv_random_data = self.get_convolved_traces(random_data, temporal, spatial, singular) + conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) medians = medians[:, None] noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817 @@ -688,16 +672,13 @@ def get_trace_margin(self): def compute(self, traces, start_frame, end_frame, segment_index, max_margin): assert HAVE_NUMBA, "You need to install numba" - conv_traces = 
self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular) + conv_traces = self.get_convolved_traces(traces) conv_traces /= self.abs_thresholds[:, None] conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin] traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size] - num_z_factors = len(self.z_factors) - num_templates = traces.shape[1] - - traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1]) - conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1]) + traces_center = traces_center.reshape(self.num_z_factors, self.num_templates, traces_center.shape[1]) + conv_traces = conv_traces.reshape(self.num_z_factors, self.num_templates, conv_traces.shape[1]) peak_mask = traces_center > 1 peak_mask = _numba_detect_peak_matched_filtering( @@ -708,11 +689,13 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): self.abs_thresholds, self.peak_sign, self.neighbours_mask, - num_templates, + self.num_channels, ) # Find peaks and correct for time shift z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) + if self.peak_sign == "both": + peak_chan_ind = peak_chan_ind % self.num_channels # If we want to estimate z # peak_chan_ind = peak_chan_ind % num_channels @@ -739,16 +722,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # return is always a tuple return (local_peaks,) - def get_convolved_traces(self, traces, temporal, spatial, singular): + def get_convolved_traces(self, traces): import scipy.signal - num_timesteps, num_templates = len(traces), temporal.shape[1] - num_peaks = num_timesteps - self.conv_margin + 1 - scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32) - spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * singular - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid") - scalar_products += np.sum(objective_by_rank, axis=0) + tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") + scalar_products = np.dot(self.weights, tmp) return scalar_products @@ -873,37 +851,28 @@ def _numba_detect_peak_neg( @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_matched_filtering( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels ): num_z = traces_center.shape[0] + num_templates = traces_center.shape[1] for template_ind in range(num_templates): for z in range(num_z): for s in range(peak_mask.shape[2]): if not peak_mask[z, template_ind, s]: continue for neighbour in range(num_templates): - if not neighbours_mask[template_ind, neighbour]: - continue for j in range(num_z): + if not neighbours_mask[template_ind % num_channels, neighbour % num_channels]: + continue for i in range(exclude_sweep_size): - if template_ind >= neighbour: - if z >= j: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] >= traces_center[j, neighbour, s] - ) - else: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] > traces_center[j, neighbour, s] - ) - elif template_ind < neighbour: - if z > j: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] > traces_center[j, neighbour, s] - ) - else: - peak_mask[z, template_ind, s] &= ( - 
traces_center[z, template_ind, s] > traces_center[j, neighbour, s] - ) + if template_ind >= neighbour and z >= j: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] >= traces_center[j, neighbour, s] + ) + else: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) peak_mask[z, template_ind, s] &= ( traces_center[z, template_ind, s] > traces[j, neighbour, s + i] ) diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index fa30ba3483..7c34f5948d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -328,19 +328,38 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ) assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np) + peaks_local_mf_filtering_both = detect_peaks( + recording, + method="matched_filtering", + peak_sign="both", + detect_threshold=5, + exclude_sweep_ms=0.1, + prototype=prototype, + ms_before=1.0, + **job_kwargs, + ) + assert len(peaks_local_mf_filtering_both) > len(peaks_local_mf_filtering) + DEBUG = False if DEBUG: import matplotlib.pyplot as plt - peaks = peaks_local_mf_filtering + peaks_local = peaks_by_channel_np + peaks_mf_neg = peaks_local_mf_filtering + peaks_mf_both = peaks_local_mf_filtering_both + labels = ["locally_exclusive", "mf_neg", "mf_both"] - sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"] + fig, ax = plt.subplots() chan_offset = 500 traces = recording.get_traces().copy() traces += np.arange(traces.shape[1])[None, :] * chan_offset - fig, ax = plt.subplots() ax.plot(traces, color="k") - ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, color="r") + + for count, peaks in enumerate([peaks_local, peaks_mf_neg, peaks_mf_both]): + sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"] + ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, label=labels[count]) + + ax.legend() plt.show() diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index dab19809be..cbf1d29932 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -10,6 +10,7 @@ job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) +# job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True) def get_sorting_analyzer(): @@ -40,19 +41,25 @@ def test_find_spikes_from_templates(method, sorting_analyzer): noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() # sorting_analyzer - method_kwargs_all = {"templates": templates, "noise_levels": noise_levels} + method_kwargs_all = { + "templates": templates, + } method_kwargs = {} + if method in ("naive", "tdc-peeler", "circus"): + method_kwargs["noise_levels"] = noise_levels + # method_kwargs["wobble"] = { # "templates": waveform_extractor.get_all_templates(), # "nbefore": waveform_extractor.nbefore, # "nafter": waveform_extractor.nafter, # } - sampling_frequency = recording.get_sampling_frequency() + method_kwargs.update(method_kwargs_all) + spikes, info = find_spikes_from_templates( + recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + ) - 
method_kwargs_ = method_kwargs.get(method, {}) - method_kwargs_.update(method_kwargs_all) - spikes = find_spikes_from_templates(recording, method=method, method_kwargs=method_kwargs_, **job_kwargs) + # print(info) # DEBUG = True @@ -65,15 +72,15 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # gt_sorting = sorting_analyzer.sorting - # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) + # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency) - # metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) + # ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) # fig, ax = plt.subplots() # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) # si.plot_agreement_matrix(comp, ax=ax) # ax.set_title(method) - # plt.show() + # plt.show() if __name__ == "__main__": @@ -81,6 +88,6 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # method = "naive" # method = "tdc-peeler" # method = "circus" - # method = "circus-omp-svd" - method = "wobble" + method = "circus-omp-svd" + # method = "wobble" test_find_spikes_from_templates(method, sorting_analyzer) diff --git a/src/spikeinterface/sortingcomponents/tests/test_wobble.py b/src/spikeinterface/sortingcomponents/tests/test_wobble.py index 5e6be02409..d6d1e1e0b9 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_wobble.py +++ b/src/spikeinterface/sortingcomponents/tests/test_wobble.py @@ -143,7 +143,7 @@ def test_convolve_templates(): ) unit_overlap = unit_overlap > 0 unit_overlap = np.repeat(unit_overlap, jitter_factor, axis=0) - sparsity = wobble.Sparsity(visible_channels, unit_overlap) + sparsity = wobble.WobbleSparsity(visible_channels, unit_overlap) # Act: run convolve_templates pairwise_convolution = wobble.convolve_templates( diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 85043d0d12..5e160a6a5a 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -1,127 +1,67 @@ +""" +This module is deprecated and will be removed in 0.102.0. + +All plotting for the previous GTStudy is now centralized in spikeinterface.benchmark.benchmark_plot_tools. +Please note that GTStudy has been replaced by SorterStudy, which is based on the more generic BenchmarkStudy. +""" + from __future__ import annotations -import numpy as np +from .base import BaseWidget -from .base import BaseWidget, to_attr +import warnings class StudyRunTimesWidget(BaseWidget): """ - Plot sorter run times for a GroundTruthStudy - + Plot sorter run times for a SorterStudy. Parameters ---------- - study : GroundTruthStudy + study : SorterStudy A study object. case_keys : list or None A selection of cases to plot, if None, then all. """ - def __init__( - self, - study, - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors() + def __init__(self, study, case_keys=None, backend=None, **backend_kwargs): + warnings.warn( + "plot_study_run_times is deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." )
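Every widget in this file now follows the same shim pattern: keep the public entry point, emit a warning, and delegate to the benchmark module. As a sketch of that pattern in functional form (the real entry points are widget classes; the standalone `plot_study_run_times` function below is illustrative only):

import warnings

def plot_study_run_times(study, case_keys=None):
    # legacy entry point kept only for backward compatibility
    warnings.warn("plot_study_run_times is deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.")
    from spikeinterface.benchmark.benchmark_plot_tools import plot_run_times
    return plot_run_times(study, case_keys=case_keys)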
- + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) + from spikeinterface.benchmark.benchmark_plot_tools import plot_run_times - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + plot_run_times(data_plot["study"], case_keys=data_plot["case_keys"]) - for i, key in enumerate(dp.case_keys): - label = dp.study.cases[key]["label"] - rt = dp.run_times.loc[key] - self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key]) - self.ax.set_ylabel("run time (s)") - self.ax.legend() -# TODO : plot optionally average on some levels using group by class StudyUnitCountsWidget(BaseWidget): """ Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" - Parameters ---------- - study : GroundTruthStudy + study : SorterStudy A study object. case_keys : list or None A selection of cases to plot, if None, then all. """ - def __init__( - self, - study, - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, - count_units=study.get_count_units(case_keys=case_keys), - case_keys=case_keys, + def __init__(self, study, case_keys=None, backend=None, **backend_kwargs): + warnings.warn( + "plot_study_unit_counts is deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." ) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors + from spikeinterface.benchmark.benchmark_plot_tools import plot_unit_counts - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - columns = dp.count_units.columns.tolist() - columns.remove("num_gt") - columns.remove("num_sorter") - - ncol = len(columns) - - colors = get_some_colors(columns, color_engine="auto", map_name="hot") - colors["num_well_detected"] = "green" - - xticklabels = [] - for i, key in enumerate(dp.case_keys): - for c, col in enumerate(columns): - x = i + 1 + c / (ncol + 1) - y = dp.count_units.loc[key, col] - if not "well_detected" in col: - y = -y - - if i == 0: - label = col.replace("num_", "").replace("_", " ").title() - else: - label = None - - self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) - - xticklabels.append(dp.study.cases[key]["label"]) - - self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1) - self.ax.set_xticklabels(xticklabels) - self.ax.legend() + plot_unit_counts(data_plot["study"], case_keys=data_plot["case_keys"]) class StudyPerformances(BaseWidget): @@ -154,78 +94,26 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn( + "plot_study_performances is deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
+ ) plot_data = dict( study=study, - perfs=study.get_performance_by_unit(case_keys=case_keys), mode=mode, performance_names=performance_names, case_keys=case_keys, ) - - self.colors = study.get_colors() - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors - - import pandas as pd - import seaborn as sns - - dp = to_attr(data_plot) - perfs = dp.perfs - study = dp.study - - if dp.mode in ("ordered", "snr"): - backend_kwargs["num_axes"] = len(dp.performance_names) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - if dp.mode == "ordered": - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - for key in dp.case_keys: - label = study.cases[key]["label"] - val = perfs.xs(key).loc[:, performance_name].values - val = np.sort(val)[::-1] - ax.plot(val, label=label, c=self.colors[key]) - ax.set_title(performance_name) - if count == len(dp.performance_names) - 1: - ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) - - elif dp.mode == "snr": - metric_name = dp.mode - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - - max_metric = 0 - for key in dp.case_keys: - x = study.get_metrics(key).loc[:, metric_name].values - y = perfs.xs(key).loc[:, performance_name].values - label = study.cases[key]["label"] - ax.scatter(x, y, s=10, label=label, color=self.colors[key]) - max_metric = max(max_metric, np.max(x)) - ax.set_title(performance_name) - ax.set_xlim(0, max_metric * 1.05) - ax.set_ylim(0, 1.05) - if count == 0: - ax.legend(loc="lower right") - - elif dp.mode == "swarm": - levels = perfs.index.names - df = pd.melt( - perfs.reset_index(), - id_vars=levels, - var_name="Metric", - value_name="Score", - value_vars=dp.performance_names, - ) - df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) - sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) + from spikeinterface.benchmark.benchmark_plot_tools import plot_performances + + plot_performances( + data_plot["study"], + mode=data_plot["mode"], + performance_names=data_plot["performance_names"], + case_keys=data_plot["case_keys"], + ) class StudyAgreementMatrix(BaseWidget): @@ -251,9 +139,9 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn( + "plot_study_agreement_matrix is deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
+ ) plot_data = dict( study=study, case_keys=case_keys, @@ -263,36 +151,9 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .comparison import AgreementMatrixWidget - - dp = to_attr(data_plot) - study = dp.study - - backend_kwargs["num_axes"] = len(dp.case_keys) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - for count, key in enumerate(dp.case_keys): - ax = self.axes.flatten()[count] - comp = study.comparisons[key] - unit_ticks = len(comp.sorting1.unit_ids) <= 16 - count_text = len(comp.sorting1.unit_ids) <= 16 - - AgreementMatrixWidget( - comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax - ) - label = study.cases[key]["label"] - ax.set_xlabel(label) - - if count > 0: - ax.set_ylabel(None) - ax.set_yticks([]) - ax.set_xticks([]) + from spikeinterface.benchmark.benchmark_plot_tools import plot_agreement_matrix - # ax0 = self.axes.flatten()[0] - # for ax in self.axes.flatten()[1:]: - # ax.sharey(ax0) + plot_agreement_matrix(data_plot["study"], ordered=data_plot["ordered"], case_keys=data_plot["case_keys"]) class StudySummary(BaseWidget): @@ -320,25 +181,26 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - plot_data = dict( - study=study, - case_keys=case_keys, + warnings.warn( + "plot_study_summary is deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." ) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - study = data_plot["study"] case_keys = data_plot["case_keys"] - StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs) - StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs) - StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + from spikeinterface.benchmark.benchmark_plot_tools import ( + plot_agreement_matrix, + plot_performances, + plot_unit_counts, + plot_run_times, + ) + + plot_performances(study=study, case_keys=case_keys, mode="ordered") + plot_performances(study=study, case_keys=case_keys, mode="snr") + plot_agreement_matrix(study=study, case_keys=case_keys) + plot_run_times(study=study, case_keys=case_keys) + plot_unit_counts(study=study, case_keys=case_keys)
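With these shims in place, the legacy gtstudy widgets keep working but warn and forward to spikeinterface.benchmark.benchmark_plot_tools. New code should call the helpers there directly; a minimal usage sketch, assuming a SorterStudy has already been created in a study folder (the path below is illustrative):

from spikeinterface.benchmark import SorterStudy
from spikeinterface.benchmark.benchmark_plot_tools import (
    plot_agreement_matrix,
    plot_performances,
    plot_run_times,
    plot_unit_counts,
)

# load an existing study from its folder (hypothetical path)
study = SorterStudy("my_study_folder")

plot_run_times(study)
plot_unit_counts(study)
plot_performances(study, mode="ordered")
plot_agreement_matrix(study)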