diff --git a/doc/conf.py b/doc/conf.py index 15cb65d46a..13d1ef4e65 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -118,14 +118,15 @@ 'examples_dirs': ['../examples/modules_gallery'], 'gallery_dirs': ['modules_gallery', ], # path where to save gallery generated examples 'subsection_order': ExplicitOrder([ - '../examples/modules_gallery/core/', - '../examples/modules_gallery/extractors/', + '../examples/modules_gallery/core', + '../examples/modules_gallery/extractors', '../examples/modules_gallery/qualitymetrics', '../examples/modules_gallery/comparison', '../examples/modules_gallery/widgets', ]), 'within_subsection_order': FileNameSortKey, 'ignore_pattern': '/generate_', + 'nested_sections': False, } intersphinx_mapping = { diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index dabad818f9..da94cf549c 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -7,3 +7,4 @@ How to guides get_started analyse_neuropixels handle_drift + load_matlab_data diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst new file mode 100644 index 0000000000..e12d83810a --- /dev/null +++ b/doc/how_to/load_matlab_data.rst @@ -0,0 +1,101 @@ +Exporting MATLAB Data to Binary & Loading in SpikeInterface +=========================================================== + +In this tutorial, we will walk through the process of exporting data from MATLAB in a binary format and subsequently loading it using SpikeInterface in Python. + +Exporting Data from MATLAB +-------------------------- + +Begin by ensuring your data structure is correct. Organize your data matrix so that the first dimension corresponds to samples/time and the second to channels. +Here, we present a MATLAB code that creates a random dataset and writes it to a binary file as an illustration. + +.. code-block:: matlab + + % Define the size of your data + numSamples = 1000; + numChannels = 384; + + % Generate random data as an example + data = rand(numSamples, numChannels); + + % Write the data to a binary file + fileID = fopen('your_data_as_a_binary.bin', 'wb'); + fwrite(fileID, data, 'double'); + fclose(fileID); + +.. note:: + + In your own script, replace the random data generation with your actual dataset. + +Loading Data in SpikeInterface +------------------------------ + +After executing the above MATLAB code, a binary file named :code:`your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path. + +Use the following Python script to load the binary data into SpikeInterface: + +.. code-block:: python + + import spikeinterface as si + from pathlib import Path + + # Define file path + # For Linux or macOS: + file_path = Path("/The/Path/To/Your/Data/your_data_as_a_binary.bin") + # For Windows: + # file_path = Path(r"c:\path\to\your\data\your_data_as_a_binary.bin") + + # Confirm file existence + assert file_path.is_file(), f"Error: {file_path} is not a valid file. Please check the path." + + # Define recording parameters + sampling_frequency = 30_000.0 # Adjust according to your MATLAB dataset + num_channels = 384 # Adjust according to your MATLAB dataset + dtype = "float64" # MATLAB's double corresponds to Python's float64 + + # Load data using SpikeInterface + recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, + num_channels=num_channels, dtype=dtype) + + # Confirm that the data was loaded correctly by comparing the data shapes and see they match the MATLAB data + print(recording.get_num_frames(), recording.get_num_channels()) + +Follow the steps above to seamlessly import your MATLAB data into SpikeInterface. Once loaded, you can harness the full power of SpikeInterface for data processing, including filtering, spike sorting, and more. + +Common Pitfalls & Tips +---------------------- + +1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use :code:`time_axis=1` in :code:`si.read_binary()`. +2. **File Path**: Always double-check the Python file path. +3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's `double` is equivalent to Numpy's `float64`. +4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface. +5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's `Numpy for MATLAB Users `_ guide. + +Using gains and offsets for integer data +---------------------------------------- + +Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units, you can apply a gain and an offset. +In SpikeInterface, you can use the :code:`gain_to_uV` and :code:`offset_to_uV` parameters, since traces are handled in microvolts (uV). Both parameters can be integrated into the :code:`read_binary` function. +If your data in MATLAB is stored as :code:`int16`, and you know the gain and offset, you can use the following code to load the data: + +.. code-block:: python + + sampling_frequency = 30_000.0 # Adjust according to your MATLAB dataset + num_channels = 384 # Adjust according to your MATLAB dataset + dtype_int = 'int16' # Adjust according to your MATLAB dataset + gain_to_uV = 0.195 # Adjust according to your MATLAB dataset + offset_to_uV = 0 # Adjust according to your MATLAB dataset + + recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, + num_channels=num_channels, dtype=dtype_int, + gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV) + + recording.get_traces() # Return traces in original units [type: int] + recording.get_traces(return_scaled=True) # Return traces in micro volts (uV) [type: float] + + +This will equip your recording object with capabilities to convert the data to float values in uV using the :code:`get_traces()` method with the :code:`return_scaled` parameter set to :code:`True`. + +.. note:: + + The gain and offset parameters are usually format dependent and you will need to find out the correct values for your data format. You can load your data without gain and offset but then the traces will be in integer values and not in uV. diff --git a/doc/images/plot_traces_ephyviewer.png b/doc/images/plot_traces_ephyviewer.png new file mode 100644 index 0000000000..9d926725a4 Binary files /dev/null and b/doc/images/plot_traces_ephyviewer.png differ diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index b452307e3c..76ab7855c6 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -248,21 +248,19 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for We also have a high level class to compare many sorters against ground truth: :py:func:`~spiekinterface.comparison.GroundTruthStudy()` -A study is a systematic performance comparison of several ground truth recordings with several sorters. +A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases +like the different parameter sets. -The study class proposes high-level tool functions to run many ground truth comparisons with many sorters +The study class proposes high-level tool functions to run many ground truth comparisons with many "cases" on many recordings and then collect and aggregate results in an easy way. The all mechanism is based on an intrinsic organization into a "study_folder" with several subfolder: - * raw_files : contain a copy of recordings in binary format - * sorter_folders : contains outputs of sorters - * ground_truth : contains a copy of sorting ground truth in npz format - * sortings: contains light copy of all sorting in npz format - * tables: some tables in csv format - -In order to run and rerun the computation all gt_sorting and recordings are copied to a fast and universal format: -binary (for recordings) and npz (for sortings). + * datasets: contains ground truth datasets + * sorters : contains outputs of sorters + * sortings: contains light copy of all sorting + * metrics: contains metrics + * ... .. code-block:: python @@ -274,28 +272,51 @@ binary (for recordings) and npz (for sortings). import spikeinterface.widgets as sw from spikeinterface.comparison import GroundTruthStudy - # Setup study folder - rec0, gt_sorting0 = se.toy_example(num_channels=4, duration=10, seed=10, num_segments=1) - rec1, gt_sorting1 = se.toy_example(num_channels=4, duration=10, seed=0, num_segments=1) - gt_dict = { - 'rec0': (rec0, gt_sorting0), - 'rec1': (rec1, gt_sorting1), + + # generate 2 simulated datasets (could be also mearec files) + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) + + datasets = { + "toy0": (rec0, gt_sorting0), + "toy1": (rec1, gt_sorting1), } - study_folder = 'a_study_folder' - study = GroundTruthStudy.create(study_folder, gt_dict) - # all sorters for all recordings in one function. - sorter_list = ['herdingspikes', 'tridesclous', ] - study.run_sorters(sorter_list, mode_if_folder_exists="keep") + # define some "cases" here we want to tests tridesclous2 on 2 datasets and spykingcircus on one dataset + # so it is a two level study (sorter_name, dataset) + # this could be more complicated like (sorter_name, dataset, params) + cases = { + ("tdc2", "toy0"): { + "label": "tridesclous2 on tetrode0", + "dataset": "toy0", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + }, + ("tdc2", "toy1"): { + "label": "tridesclous2 on tetrode1", + "dataset": "toy1", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + }, + + ("sc", "toy0"): { + "label": "spykingcircus2 on tetrode0", + "dataset": "toy0", + "run_sorter_params": { + "sorter_name": "spykingcircus", + "docker_image": True + }, + }, + } + # this initilize a folder + study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases, + levels=["sorter_name", "dataset"]) - # You can re-run **run_study_sorters** as many times as you want. - # By default **mode='keep'** so only uncomputed sorters are re-run. - # For instance, just remove the "sorter_folders/rec1/herdingspikes" to re-run - # only one sorter on one recording. - # - # Then we copy the spike sorting outputs into a separate subfolder. - # This allow us to remove the "large" sorter_folders. - study.copy_sortings() + + # all cases in one function + study.run_sorters() # Collect comparisons #   @@ -306,11 +327,11 @@ binary (for recordings) and npz (for sortings). # Note: use exhaustive_gt=True when you know exactly how many # units in ground truth (for synthetic datasets) + # run all comparisons and loop over the results study.run_comparisons(exhaustive_gt=True) - - for (rec_name, sorter_name), comp in study.comparisons.items(): + for key, comp in study.comparisons.items(): print('*' * 10) - print(rec_name, sorter_name) + print(key) # raw counting of tp/fp/... print(comp.count_score) # summary @@ -323,26 +344,27 @@ binary (for recordings) and npz (for sortings). # Collect synthetic dataframes and display # As shown previously, the performance is returned as a pandas dataframe. - # The :py:func:`~spikeinterface.comparison.aggregate_performances_table()` function, + # The :py:func:`~spikeinterface.comparison.get_performance_by_unit()` function, # gathers all the outputs in the study folder and merges them in a single dataframe. + # Same idea for :py:func:`~spikeinterface.comparison.get_count_units()` - dataframes = study.aggregate_dataframes() + # this is a dataframe + perfs = study.get_performance_by_unit() - # Pandas dataframes can be nicely displayed as tables in the notebook. - print(dataframes.keys()) + # this is a dataframe + unit_counts = study.get_count_units() # we can also access run times - print(dataframes['run_times']) + run_times = study.get_run_times() + print(run_times) # Easy plot with seaborn - run_times = dataframes['run_times'] fig1, ax1 = plt.subplots() sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1) ax1.set_title('Run times') ############################################################################## - perfs = dataframes['perf_by_unit'] fig2, ax2 = plt.subplots() sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2) ax2.set_title('Recall') diff --git a/doc/modules/core.rst b/doc/modules/core.rst index fdc4d71fe7..976a82a4a3 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -547,8 +547,7 @@ workflow. In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`, :py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and :py:class:`~spikeinterface.core.NumpySnippets`. These object behave exactly like normal SpikeInterface objects, -but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported. -In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`). +but they are not bound to a file. Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is very similar to Similar to :py:class:`~spikeinterface.core.NumpySorting` but with an unerlying SharedMemory which is usefull for diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 8c7c0a2cc3..447d83db52 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -25,9 +25,11 @@ For more details about each metric and it's availability and use within SpikeInt :glob: qualitymetrics/amplitude_cutoff + qualitymetrics/amplitude_cv qualitymetrics/amplitude_median qualitymetrics/d_prime qualitymetrics/drift + qualitymetrics/firing_range qualitymetrics/firing_rate qualitymetrics/isi_violations qualitymetrics/isolation_distance diff --git a/doc/modules/qualitymetrics/amplitude_cutoff.rst b/doc/modules/qualitymetrics/amplitude_cutoff.rst index 9f747f8d40..a1e4d85d01 100644 --- a/doc/modules/qualitymetrics/amplitude_cutoff.rst +++ b/doc/modules/qualitymetrics/amplitude_cutoff.rst @@ -21,12 +21,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitudes from all spikes - fraction_missing = qm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg") - # fraction_missing is a dict containing the units' IDs as keys, + fraction_missing = sqm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg") + # fraction_missing is a dict containing the unit IDs as keys, # and their estimated fraction of missing spikes as values. Reference diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/qualitymetrics/amplitude_cv.rst new file mode 100644 index 0000000000..13117b607c --- /dev/null +++ b/doc/modules/qualitymetrics/amplitude_cv.rst @@ -0,0 +1,55 @@ +Amplitude CV (:code:`amplitude_cv_median`, :code:`amplitude_cv_range`) +====================================================================== + + +Calculation +----------- + +The amplitude CV (coefficient of variation) is a measure of the amplitude variability. +It is computed as the ratio between the standard deviation and the amplitude mean. +To obtain a better estimate of this measure, it is first computed separately for several temporal bins. +Out of these values, the median and the range (percentile distance, by default between the +5th and 95th percentiles) are computed. + +The computation requires either spike amplitudes (see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes()`) +or amplitude scalings (see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings()`) to be pre-computed. + + +Expectation and use +------------------- + +The amplitude CV median is expected to be relatively low for well-isolated units, indicating a "stereotypical" spike shape. + +The amplitude CV range can be high in the presence of noise contamination, due to amplitude outliers like in +the example below. + +.. image:: amplitudes.png + :width: 600 + + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + # Make recording, sorting and wvf_extractor object for your data. + # It is required to run `compute_spike_amplitudes(wvf_extractor)` or + # `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN) + amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(wvf_extractor) + # amplitude_cv_median and amplitude_cv_range are dicts containing the unit ids as keys, + # and their amplitude_cv metrics as values. + + + +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_cv_metrics + + +Literature +---------- + +Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino. diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/qualitymetrics/amplitude_median.rst index ffc45d1cf6..3ac52560e8 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/qualitymetrics/amplitude_median.rst @@ -20,12 +20,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitude values from all spikes. - amplitude_medians = qm.compute_amplitude_medians(wvf_extractor) - # amplitude_medians is a dict containing the units' IDs as keys, + amplitude_medians = sqm.compute_amplitude_medians(wvf_extractor) + # amplitude_medians is a dict containing the unit IDs as keys, # and their estimated amplitude medians as values. Reference diff --git a/doc/modules/qualitymetrics/amplitudes.png b/doc/modules/qualitymetrics/amplitudes.png new file mode 100644 index 0000000000..0ee4dd1eda Binary files /dev/null and b/doc/modules/qualitymetrics/amplitudes.png differ diff --git a/doc/modules/qualitymetrics/d_prime.rst b/doc/modules/qualitymetrics/d_prime.rst index abb8c1dc74..e3bd61c580 100644 --- a/doc/modules/qualitymetrics/d_prime.rst +++ b/doc/modules/qualitymetrics/d_prime.rst @@ -32,9 +32,9 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm - d_prime = qm.lda_metrics(all_pcs, all_labels, 0) + d_prime = sqm.lda_metrics(all_pcs, all_labels, 0) Reference diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/qualitymetrics/drift.rst index 0a852f80af..ae52f7f883 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/qualitymetrics/drift.rst @@ -40,11 +40,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm + # Make recording, sorting and wvf_extractor object for your data. # It is required to run `compute_spike_locations(wvf_extractor)` # (if missing, values will be NaN) - drift_ptps, drift_stds, drift_mads = qm.compute_drift_metrics(wvf_extractor, peak_sign="neg") + drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(wvf_extractor, peak_sign="neg") # drift_ptps, drift_stds, and drift_mads are dict containing the units' ID as keys, # and their metrics as values. diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/qualitymetrics/firing_range.rst new file mode 100644 index 0000000000..925539e9c6 --- /dev/null +++ b/doc/modules/qualitymetrics/firing_range.rst @@ -0,0 +1,40 @@ +Firing range (:code:`firing_range`) +=================================== + + +Calculation +----------- + +The firing range indicates the dispersion of the firing rate of a unit across the recording. It is computed by +taking the difference between the 95th percentile's firing rate and the 5th percentile's firing rate computed over short time bins (e.g. 10 s). + + + +Expectation and use +------------------- + +Very high levels of firing ranges, outside of a physiological range, might indicate noise contamination. + + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + # Make recording, sorting and wvf_extractor object for your data. + firing_range = sqm.compute_firing_ranges(wvf_extractor) + # firing_range is a dict containing the unit IDs as keys, + # and their firing firing_range as values (in Hz). + +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_firing_ranges + + +Literature +---------- + +Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino. diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/qualitymetrics/firing_rate.rst index eddef3e48f..c0e15d7c2e 100644 --- a/doc/modules/qualitymetrics/firing_rate.rst +++ b/doc/modules/qualitymetrics/firing_rate.rst @@ -37,11 +37,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - firing_rate = qm.compute_firing_rates(wvf_extractor) - # firing_rate is a dict containing the units' IDs as keys, + firing_rate = sqm.compute_firing_rates(wvf_extractor) + # firing_rate is a dict containing the unit IDs as keys, # and their firing rates across segments as values (in Hz). References diff --git a/doc/modules/qualitymetrics/isi_violations.rst b/doc/modules/qualitymetrics/isi_violations.rst index 947e7d4938..725d9b0fd6 100644 --- a/doc/modules/qualitymetrics/isi_violations.rst +++ b/doc/modules/qualitymetrics/isi_violations.rst @@ -77,11 +77,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - isi_violations_ratio, isi_violations_count = qm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0) + isi_violations_ratio, isi_violations_count = sqm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0) References ---------- diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/qualitymetrics/presence_ratio.rst index e4de2248bd..5a420c8ccf 100644 --- a/doc/modules/qualitymetrics/presence_ratio.rst +++ b/doc/modules/qualitymetrics/presence_ratio.rst @@ -23,12 +23,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - presence_ratio = qm.compute_presence_ratios(wvf_extractor) - # presence_ratio is a dict containing the units' IDs as keys + presence_ratio = sqm.compute_presence_ratios(wvf_extractor) + # presence_ratio is a dict containing the unit IDs as keys # and their presence ratio (between 0 and 1) as values. Links to original implementations diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/qualitymetrics/sliding_rp_violations.rst index 843242c1e8..de68c3a92f 100644 --- a/doc/modules/qualitymetrics/sliding_rp_violations.rst +++ b/doc/modules/qualitymetrics/sliding_rp_violations.rst @@ -27,11 +27,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - contamination = qm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25) + contamination = sqm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25) References ---------- diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/qualitymetrics/snr.rst index 288ab60515..b88d3291be 100644 --- a/doc/modules/qualitymetrics/snr.rst +++ b/doc/modules/qualitymetrics/snr.rst @@ -41,12 +41,12 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - SNRs = qm.compute_snrs(wvf_extractor) - # SNRs is a dict containing the units' IDs as keys and their SNRs as values. + SNRs = sqm.compute_snrs(wvf_extractor) + # SNRs is a dict containing the unit IDs as keys and their SNRs as values. Links to original implementations --------------------------------- diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 2f566bf8a7..0750940199 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -27,9 +27,9 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 86c541dfd0..8565e94fce 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -14,6 +14,9 @@ Since version 0.95.0, the :py:mod:`spikeinterface.widgets` module supports multi * | :code:`sortingview`: web-based and interactive rendering using the `sortingview `_ | and `FIGURL `_ packages. +Version 0.100.0, also come with this new backend: +* | :code:`ephyviewer`: interactive Qt based using the `ephyviewer `_ package + Installing backends ------------------- @@ -85,6 +88,28 @@ Finally, if you wish to set up another cloud provider, follow the instruction fr `kachery-cloud `_ package ("Using your own storage bucket"). +ephyviewer +^^^^^^^^^^ + +This backend is Qt based with PyQt5, PyQt6 or PySide6 support. Qt is sometimes tedious to install. + + +For a pip-based installation, run: + +.. code-block:: bash + + pip install PySide6 ephyviewer + + +Anaconda users will have a better experience with this: + +.. code-block:: bash + + conda install pyqt=5 + pip install ephyviewer + + + Usage ----- @@ -215,6 +240,21 @@ For example, here is how to combine the timeseries and sorting summary generated print(url) +ephyviewer +^^^^^^^^^^ + + +The :code:`ephyviewer` backend is currently only available for the :py:func:`~spikeinterface.widgets.plot_traces()` function. + + +.. code-block:: python + + plot_traces(recording, backend="ephyviewer", mode="line", show_channel_ids=True) + + +.. image:: ../images/plot_traces_ephyviewer.png + + Available plotting functions ---------------------------- @@ -229,7 +269,7 @@ Available plotting functions * :py:func:`~spikeinterface.widgets.plot_spikes_on_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`) * :py:func:`~spikeinterface.widgets.plot_template_metrics` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_template_similarity` (backends: ::code:`matplotlib`, :code:`sortingview`) -* :py:func:`~spikeinterface.widgets.plot_timeseries` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) +* :py:func:`~spikeinterface.widgets.plot_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`, :code:`ephyviewer`) * :py:func:`~spikeinterface.widgets.plot_unit_depths` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_unit_locations` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_unit_summary` (backends: :code:`matplotlib`) diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py index a390bb7689..bff85dde4a 100644 --- a/src/spikeinterface/comparison/__init__.py +++ b/src/spikeinterface/comparison/__init__.py @@ -28,12 +28,11 @@ compare_multiple_templates, MultiTemplateComparison, ) -from .collisioncomparison import CollisionGTComparison -from .correlogramcomparison import CorrelogramGTComparison + from .groundtruthstudy import GroundTruthStudy -from .collisionstudy import CollisionGTStudy -from .correlogramstudy import CorrelogramGTStudy -from .studytools import aggregate_performances_table +from .collision import CollisionGTComparison, CollisionGTStudy +from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy + from .hybrid import ( HybridSpikesRecording, HybridUnitsRecording, diff --git a/src/spikeinterface/comparison/collisioncomparison.py b/src/spikeinterface/comparison/collision.py similarity index 64% rename from src/spikeinterface/comparison/collisioncomparison.py rename to src/spikeinterface/comparison/collision.py index 3b279717b7..dd04b2c72d 100644 --- a/src/spikeinterface/comparison/collisioncomparison.py +++ b/src/spikeinterface/comparison/collision.py @@ -1,13 +1,15 @@ -import numpy as np - from .paircomparisons import GroundTruthComparison +from .groundtruthstudy import GroundTruthStudy from .comparisontools import make_collision_events +import numpy as np + class CollisionGTComparison(GroundTruthComparison): """ - This class is an extension of GroundTruthComparison by focusing - to benchmark spike in collision + This class is an extension of GroundTruthComparison by focusing to benchmark spike in collision. + + This class needs maintenance and need a bit of refactoring. collision_lag: float @@ -156,3 +158,73 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good pair_names = pair_names[order] return similarities, recall_scores, pair_names + + +class CollisionGTStudy(GroundTruthStudy): + def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): + _kwargs = dict() + _kwargs.update(kwargs) + _kwargs["exhaustive_gt"] = exhaustive_gt + _kwargs["collision_lag"] = collision_lag + _kwargs["nbins"] = nbins + GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) + self.exhaustive_gt = exhaustive_gt + self.collision_lag = collision_lag + + def get_lags(self, key): + comp = self.comparisons[key] + fs = comp.sorting1.get_sampling_frequency() + lags = comp.bins / fs * 1000.0 + return lags + + def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): + import sklearn + + if case_keys is None: + case_keys = self.cases.keys() + + self.all_similarities = {} + self.all_recall_scores = {} + self.good_only = good_only + + for key in case_keys: + templates = self.get_templates(key) + flat_templates = templates.reshape(templates.shape[0], -1) + similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + comp = self.comparisons[key] + similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( + similarity, good_only=good_only, min_accuracy=min_accuracy + ) + self.all_similarities[key] = similarities + self.all_recall_scores[key] = recall_scores + + def get_mean_over_similarity_range(self, similarity_range, key): + idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) + all_similarities = self.all_similarities[key][idx] + all_recall_scores = self.all_recall_scores[key][idx] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_recall_scores = all_recall_scores[order, :] + + mean_recall_scores = np.nanmean(all_recall_scores, axis=0) + + return mean_recall_scores + + def get_lag_profile_over_similarity_bins(self, similarity_bins, key): + all_similarities = self.all_similarities[key] + all_recall_scores = self.all_recall_scores[key] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_recall_scores = all_recall_scores[order, :] + + result = {} + + for i in range(similarity_bins.size - 1): + cmin, cmax = similarity_bins[i], similarity_bins[i + 1] + amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) + mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) + result[(cmin, cmax)] = mean_recall_scores + + return result diff --git a/src/spikeinterface/comparison/collisionstudy.py b/src/spikeinterface/comparison/collisionstudy.py deleted file mode 100644 index 34a556e8b9..0000000000 --- a/src/spikeinterface/comparison/collisionstudy.py +++ /dev/null @@ -1,88 +0,0 @@ -from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting -from .collisioncomparison import CollisionGTComparison - -import numpy as np - - -class CollisionGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CollisionGTComparison( - gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins - ) - self.comparisons[(rec_name, sorter_name)] = comp - self.exhaustive_gt = exhaustive_gt - self.collision_lag = collision_lag - - def get_lags(self): - fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency() - lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000 - return lags - - def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9): - if not hasattr(self, "_good_only") or self._good_only != good_only: - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_similarities = [] - all_recall_scores = [] - - for rec_name in self.rec_names: - if (rec_name, sorter_name) in self.comparisons.keys(): - comp = self.comparisons[(rec_name, sorter_name)] - similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy - ) - - all_similarities.append(similarities) - all_recall_scores.append(recall_scores) - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0) - - def get_mean_over_similarity_range(self, similarity_range, sorter_name): - idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & ( - self.all_similarities[sorter_name] <= similarity_range[1] - ) - all_similarities = self.all_similarities[sorter_name][idx] - all_recall_scores = self.all_recall_scores[sorter_name][idx] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = np.nanmean(all_recall_scores, axis=0) - - return mean_recall_scores - - def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_recall_scores = self.all_recall_scores[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) - result[(cmin, cmax)] = mean_recall_scores - - return result diff --git a/src/spikeinterface/comparison/correlogramcomparison.py b/src/spikeinterface/comparison/correlogram.py similarity index 64% rename from src/spikeinterface/comparison/correlogramcomparison.py rename to src/spikeinterface/comparison/correlogram.py index 80e881a152..aaffef1887 100644 --- a/src/spikeinterface/comparison/correlogramcomparison.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -1,16 +1,17 @@ -import numpy as np from .paircomparisons import GroundTruthComparison +from .groundtruthstudy import GroundTruthStudy from spikeinterface.postprocessing import compute_correlograms +import numpy as np + + class CorrelogramGTComparison(GroundTruthComparison): """ This class is an extension of GroundTruthComparison by focusing - to benchmark correlation reconstruction - + to benchmark correlation reconstruction. - collision_lag: float - Collision lag in ms. + This class needs maintenance and need a bit of refactoring. """ @@ -105,6 +106,62 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): order = np.argsort(similarities) similarities = similarities[order] - errors = errors[order, :] + errors = errors[order] return similarities, errors + + +class CorrelogramGTStudy(GroundTruthStudy): + def run_comparisons( + self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs + ): + _kwargs = dict() + _kwargs.update(kwargs) + _kwargs["exhaustive_gt"] = exhaustive_gt + _kwargs["window_ms"] = window_ms + _kwargs["bin_ms"] = bin_ms + _kwargs["well_detected_score"] = well_detected_score + GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) + self.exhaustive_gt = exhaustive_gt + + @property + def time_bins(self): + for key, value in self.comparisons.items(): + return value.time_bins + + def precompute_scores_by_similarities(self, case_keys=None, good_only=True): + import sklearn.metrics + + if case_keys is None: + case_keys = self.cases.keys() + + self.all_similarities = {} + self.all_errors = {} + + for key in case_keys: + templates = self.get_templates(key) + flat_templates = templates.reshape(templates.shape[0], -1) + similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + comp = self.comparisons[key] + similarities, errors = comp.compute_correlogram_by_similarity(similarity) + + self.all_similarities[key] = similarities + self.all_errors[key] = errors + + def get_error_profile_over_similarity_bins(self, similarity_bins, key): + all_similarities = self.all_similarities[key] + all_errors = self.all_errors[key] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_errors = all_errors[order, :] + + result = {} + + for i in range(similarity_bins.size - 1): + cmin, cmax = similarity_bins[i], similarity_bins[i + 1] + amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) + mean_errors = np.nanmean(all_errors[amin:amax], axis=0) + result[(cmin, cmax)] = mean_errors + + return result diff --git a/src/spikeinterface/comparison/correlogramstudy.py b/src/spikeinterface/comparison/correlogramstudy.py deleted file mode 100644 index fb00c08157..0000000000 --- a/src/spikeinterface/comparison/correlogramstudy.py +++ /dev/null @@ -1,76 +0,0 @@ -from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting -from .correlogramcomparison import CorrelogramGTComparison - -import numpy as np - - -class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CorrelogramGTComparison( - gt_sorting, - sorting, - exhaustive_gt=exhaustive_gt, - window_ms=window_ms, - bin_ms=bin_ms, - well_detected_score=well_detected_score, - ) - self.comparisons[(rec_name, sorter_name)] = comp - - self.exhaustive_gt = exhaustive_gt - - @property - def time_bins(self): - for key, value in self.comparisons.items(): - return value.time_bins - - def precompute_scores_by_similarities(self, good_only=True): - if not hasattr(self, "_computed"): - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_errors = {} - self._computed = True - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_errors = [] - all_similarities = [] - for rec_name in self.rec_names: - try: - comp = self.comparisons[(rec_name, sorter_name)] - similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name]) - all_similarities.append(similarities) - all_errors.append(errors) - except Exception: - pass - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0) - - def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_errors = self.all_errors[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_errors = np.nanmean(all_errors[amin:amax], axis=0) - result[(cmin, cmax)] = mean_errors - - return result diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 7b146f07bc..d43727cb44 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -1,327 +1,403 @@ from pathlib import Path import shutil +import os +import json +import pickle import numpy as np -from spikeinterface.core import load_extractor -from spikeinterface.extractors import NpzSortingExtractor -from spikeinterface.sorters import sorter_dict, run_sorters +from spikeinterface.core import load_extractor, extract_waveforms, load_waveforms +from spikeinterface.core.core_tools import SIJsonEncoder + +from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder from spikeinterface import WaveformExtractor from spikeinterface.qualitymetrics import compute_quality_metrics -from .paircomparisons import compare_sorter_to_ground_truth - -from .studytools import ( - setup_comparison_study, - get_rec_names, - get_recordings, - iter_working_folder, - iter_computed_names, - iter_computed_sorting, - collect_run_times, -) +from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison -class GroundTruthStudy: - def __init__(self, study_folder=None): - import pandas as pd +# TODO later : save comparison in folders when comparison object will be able to serialize - self.study_folder = Path(study_folder) - self._is_scanned = False - self.computed_names = None - self.rec_names = None - self.sorter_names = None - self.scan_folder() +# This is to separate names when the key are tuples when saving folders +_key_separator = " ## " - self.comparisons = None - self.exhaustive_gt = None - def __repr__(self): - t = "Ground truth study\n" - t += " " + str(self.study_folder) + "\n" - t += " recordings: {} {}\n".format(len(self.rec_names), self.rec_names) - if len(self.sorter_names): - t += " sorters: {} {}\n".format(len(self.sorter_names), self.sorter_names) +class GroundTruthStudy: + """ + This class is an helper function to run any comparison on several "cases" for many ground-truth dataset. - return t + "cases" refer to: + * several sorters for comparisons + * same sorter with differents parameters + * any combination of these (and more) - def scan_folder(self): - self.rec_names = get_rec_names(self.study_folder) - # scan computed names - self.computed_names = list(iter_computed_names(self.study_folder)) # list of pair (rec_name, sorter_name) - self.sorter_names = np.unique([e for _, e in iter_computed_names(self.study_folder)]).tolist() - self._is_scanned = True + For increased flexibility, cases keys can be a tuple so that we can vary complexity along several + "levels" or "axis" (paremeters or sorters). + In this case, the result dataframes will have `MultiIndex` to handle the different levels. - @classmethod - def create(cls, study_folder, gt_dict, **job_kwargs): - setup_comparison_study(study_folder, gt_dict, **job_kwargs) - return cls(study_folder) + A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see + :py:fun:`~spikeinterface.core.generate.generate_ground_truth_recording()`). - def run_sorters(self, sorter_list, mode_if_folder_exists="keep", remove_sorter_folders=False, **kwargs): - sorter_folders = self.study_folder / "sorter_folders" - recording_dict = get_recordings(self.study_folder) - - run_sorters( - sorter_list, - recording_dict, - sorter_folders, - with_output=False, - mode_if_folder_exists=mode_if_folder_exists, - **kwargs, - ) - - # results are copied so the heavy sorter_folders can be removed - self.copy_sortings() - - if remove_sorter_folders: - shutil.rmtree(self.study_folder / "sorter_folders") - - def _check_rec_name(self, rec_name): - if not self._is_scanned: - self.scan_folder() - if len(self.rec_names) > 1 and rec_name is None: - raise Exception("Pass 'rec_name' parameter to select which recording to use.") - elif len(self.rec_names) == 1: - rec_name = self.rec_names[0] - else: - rec_name = self.rec_names[self.rec_names.index(rec_name)] - return rec_name - - def get_ground_truth(self, rec_name=None): - rec_name = self._check_rec_name(rec_name) - sorting = load_extractor(self.study_folder / "ground_truth" / rec_name) - return sorting - - def get_recording(self, rec_name=None): - rec_name = self._check_rec_name(rec_name) - rec = load_extractor(self.study_folder / "raw_files" / rec_name) - return rec - - def get_sorting(self, sort_name, rec_name=None): - rec_name = self._check_rec_name(rec_name) - - selected_sorting = None - if sort_name in self.sorter_names: - for r_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - if sort_name == sorter_name and r_name == rec_name: - selected_sorting = sorting - return selected_sorting - - def copy_sortings(self): - sorter_folders = self.study_folder / "sorter_folders" - sorting_folders = self.study_folder / "sortings" - log_olders = self.study_folder / "sortings" / "run_log" - - log_olders.mkdir(parents=True, exist_ok=True) - - for rec_name, sorter_name, output_folder in iter_working_folder(sorter_folders): - SorterClass = sorter_dict[sorter_name] - fname = rec_name + "[#]" + sorter_name - npz_filename = sorting_folders / (fname + ".npz") - - try: - sorting = SorterClass.get_result_from_folder(output_folder) - NpzSortingExtractor.write_sorting(sorting, npz_filename) - except: - if npz_filename.is_file(): - npz_filename.unlink() - if (output_folder / "spikeinterface_log.json").is_file(): - shutil.copyfile( - output_folder / "spikeinterface_log.json", sorting_folders / "run_log" / (fname + ".json") - ) + This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. + Note that the underlying folder structure is not backward compatible! + """ - self.scan_folder() + def __init__(self, study_folder): + self.folder = Path(study_folder) - def run_comparisons(self, exhaustive_gt=False, **kwargs): + self.datasets = {} + self.cases = {} + self.sortings = {} self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt, **kwargs) - self.comparisons[(rec_name, sorter_name)] = sc - self.exhaustive_gt = exhaustive_gt - def aggregate_run_times(self): - return collect_run_times(self.study_folder) - - def aggregate_performance_by_unit(self): - assert self.comparisons is not None, "run_comparisons first" + self.scan_folder() - perf_by_unit = [] - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - comp = self.comparisons[(rec_name, sorter_name)] + @classmethod + def create(cls, study_folder, datasets={}, cases={}, levels=None): + # check that cases keys are homogeneous + key0 = list(cases.keys())[0] + if isinstance(key0, str): + assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(key0, tuple): + assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" + num_levels = len(key0) + assert all( + len(key) == num_levels for key in cases.keys() + ), "Keys for cases are not homogeneous, tuple negth differ" + if levels is None: + levels = [f"level{i}" for i in range(num_levels)] + else: + levels = list(levels) + assert len(levels) == num_levels + else: + raise ValueError("Keys for cases must str or tuple") - perf = comp.get_performance(method="by_unit", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - perf_by_unit.append(perf) + study_folder = Path(study_folder) + study_folder.mkdir(exist_ok=False, parents=True) - import pandas as pd + (study_folder / "datasets").mkdir() + (study_folder / "datasets" / "recordings").mkdir() + (study_folder / "datasets" / "gt_sortings").mkdir() + (study_folder / "sorters").mkdir() + (study_folder / "sortings").mkdir() + (study_folder / "sortings" / "run_logs").mkdir() + (study_folder / "metrics").mkdir() - perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(["rec_name", "sorter_name", "gt_unit_id"]) + for key, (rec, gt_sorting) in datasets.items(): + assert "/" not in key, "'/' cannot be in the key name!" + assert "\\" not in key, "'\\' cannot be in the key name!" - return perf_by_unit + # recordings are pickled + rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") - def aggregate_count_units(self, well_detected_score=None, redundant_score=None, overmerged_score=None): - assert self.comparisons is not None, "run_comparisons first" + # sortings are pickled + saved as NumpyFolderSorting + gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") + gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") - import pandas as pd + info = {} + info["levels"] = levels + (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - index = pd.MultiIndex.from_tuples(self.computed_names, names=["rec_name", "sorter_name"]) + # cases is dumped to a pickle file, json is not possible because of the tuple key + (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) - count_units = pd.DataFrame( - index=index, - columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"], - dtype=int, - ) + return cls(study_folder) - if self.exhaustive_gt: - count_units["num_false_positive"] = pd.Series(dtype=int) - count_units["num_bad"] = pd.Series(dtype=int) + def scan_folder(self): + if not (self.folder / "datasets").exists(): + raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = self.comparisons[(rec_name, sorter_name)] + with open(self.folder / "info.json", "r") as f: + self.info = json.load(f) - count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units( - well_detected_score - ) - if self.exhaustive_gt: - count_units.loc[(rec_name, sorter_name), "num_overmerged"] = comp.count_overmerged_units( - overmerged_score - ) - count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units(redundant_score) - count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units( - redundant_score - ) - count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() + self.levels = self.info["levels"] - return count_units + for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): + key = rec_file.stem + rec = load_extractor(rec_file) + gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) + self.datasets[key] = (rec, gt_sorting) - def aggregate_dataframes(self, copy_into_folder=True, **karg_thresh): - dataframes = {} - dataframes["run_times"] = self.aggregate_run_times().reset_index() - perfs = self.aggregate_performance_by_unit() + with open(self.folder / "cases.pickle", "rb") as f: + self.cases = pickle.load(f) - dataframes["perf_by_unit"] = perfs.reset_index() - dataframes["count_units"] = self.aggregate_count_units(**karg_thresh).reset_index() + self.comparisons = {k: None for k in self.cases} - if copy_into_folder: - tables_folder = self.study_folder / "tables" - tables_folder.mkdir(parents=True, exist_ok=True) + self.sortings = {} + for key in self.cases: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + if sorting_folder.exists(): + sorting = load_extractor(sorting_folder) + else: + sorting = None + self.sortings[key] = sorting - for name, df in dataframes.items(): - df.to_csv(str(tables_folder / (name + ".csv")), sep="\t", index=False) - - return dataframes + def __repr__(self): + t = f"{self.__class__.__name__} {self.folder.stem} \n" + t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" + t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" + num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None]) + t += f" computed: {num_computed}\n" - def get_waveform_extractor(self, rec_name, sorter_name=None): - rec = self.get_recording(rec_name) + return t - if sorter_name is None: - name = "GroundTruth" - sorting = self.get_ground_truth(rec_name) + def key_to_str(self, key): + if isinstance(key, str): + return key + elif isinstance(key, tuple): + return _key_separator.join(key) else: - assert sorter_name in self.sorter_names - name = sorter_name - sorting = self.get_sorting(sorter_name, rec_name) - - waveform_folder = self.study_folder / "waveforms" / f"waveforms_{name}_{rec_name}" + raise ValueError("Keys for cases must str or tuple") + + def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): + if case_keys is None: + case_keys = self.cases.keys() + + job_list = [] + for key in case_keys: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + sorting_exists = sorting_folder.exists() + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + sorter_folder_exists = sorting_folder.exists() + + if keep: + if sorting_exists: + continue + if sorter_folder_exists: + # the sorter folder exists but havent been copied to sortings folder + sorting = read_sorter_folder(sorter_folder, raise_error=False) + if sorting is not None: + # save and skip + self.copy_sortings(case_keys=[key]) + continue + + if sorting_exists: + # delete older sorting + log before running sorters + shutil.rmtree(sorting_exists) + log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" + if log_file.exists(): + log_file.unlink() + + params = self.cases[key]["run_sorter_params"].copy() + # this ensure that sorter_name is given + recording, _ = self.datasets[self.cases[key]["dataset"]] + sorter_name = params.pop("sorter_name") + job = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=sorter_folder, + ) + job.update(params) + # the verbose is overwritten and global to all run_sorters + job["verbose"] = verbose + job["with_output"] = False + job_list.append(job) + + run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) + + # TODO later create a list in laucher for engine blocking and non-blocking + if engine not in ("slurm",): + self.copy_sortings(case_keys) + + def copy_sortings(self, case_keys=None, force=True): + if case_keys is None: + case_keys = self.cases.keys() + + for key in case_keys: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" + + if (sorter_folder / "spikeinterface_log.json").exists(): + sorting = read_sorter_folder( + sorter_folder, raise_error=False, register_recording=False, sorting_info=False + ) + else: + sorting = None + + if sorting is not None: + if sorting_folder.exists(): + if force: + # delete folder + log + shutil.rmtree(sorting_folder) + if log_file.exists(): + log_file.unlink() + else: + continue + + sorting = sorting.save(format="numpy_folder", folder=sorting_folder) + self.sortings[key] = sorting + + # copy logs + shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file) + + def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): + if case_keys is None: + case_keys = self.cases.keys() + + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + _, gt_sorting = self.datasets[dataset_key] + sorting = self.sortings[key] + if sorting is None: + self.comparisons[key] = None + continue + comp = comparison_class(gt_sorting, sorting, **kwargs) + self.comparisons[key] = comp + + def get_run_times(self, case_keys=None): + import pandas as pd - if waveform_folder.is_dir(): - we = WaveformExtractor.load(waveform_folder) - else: - we = WaveformExtractor.create(rec, sorting, waveform_folder) + if case_keys is None: + case_keys = self.cases.keys() + + log_folder = self.folder / "sortings" / "run_logs" + + run_times = {} + for key in case_keys: + log_file = log_folder / f"{self.key_to_str(key)}.json" + with open(log_file, mode="r") as logfile: + log = json.load(logfile) + run_time = log.get("run_time", None) + run_times[key] = run_time + + return pd.Series(run_times, name="run_time") + + def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): + if case_keys is None: + case_keys = self.cases.keys() + + base_folder = self.folder / "waveforms" + base_folder.mkdir(exist_ok=True) + + dataset_keys = [self.cases[key]["dataset"] for key in case_keys] + dataset_keys = set(dataset_keys) + for dataset_key in dataset_keys: + # the waveforms depend on the dataset key + wf_folder = base_folder / self.key_to_str(dataset_key) + recording, gt_sorting = self.datasets[dataset_key] + we = extract_waveforms(recording, gt_sorting, folder=wf_folder) + + def get_waveform_extractor(self, key): + # some recording are not dumpable to json and the waveforms extactor need it! + # so we load it with and put after + # this should be fixed in PR 2027 so remove this after + + dataset_key = self.cases[key]["dataset"] + wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) + we = load_waveforms(wf_folder, with_recording=False) + recording, _ = self.datasets[dataset_key] + we.set_recording(recording) return we - def compute_waveforms( - self, - rec_name, - sorter_name=None, - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - n_jobs=-1, - total_memory="1G", - ): - we = self.get_waveform_extractor(rec_name, sorter_name) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory) - - def get_templates(self, rec_name, sorter_name=None, mode="median"): - """ - Get template for a given recording. - - If sorter_name=None then template are from the ground truth. - - """ - we = self.get_waveform_extractor(rec_name, sorter_name=sorter_name) + def get_templates(self, key, mode="average"): + we = self.get_waveform_extractor(key) templates = we.get_all_templates(mode=mode) return templates - def compute_metrics( - self, - rec_name, - metric_names=["snr"], - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - n_jobs=-1, - total_memory="1G", - ): - we = self.get_waveform_extractor(rec_name) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory) - - # metrics - metrics = compute_quality_metrics(we, metric_names=metric_names) - folder = self.study_folder / "metrics" - folder.mkdir(exist_ok=True) - filename = folder / f"metrics _{rec_name}.txt" - metrics.to_csv(filename, sep="\t", index=True) + def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): + if case_keys is None: + case_keys = self.cases.keys() + + done = [] + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + if dataset_key in done: + # some case can share the same waveform extractor + continue + done.append(dataset_key) + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if filename.exists(): + if force: + os.remove(filename) + else: + continue + we = self.get_waveform_extractor(key) + metrics = compute_quality_metrics(we, metric_names=metric_names) + metrics.to_csv(filename, sep="\t", index=True) + + def get_metrics(self, key): + import pandas as pd + + dataset_key = self.cases[key]["dataset"] + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if not filename.exists(): + return + metrics = pd.read_csv(filename, sep="\t", index_col=0) + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + metrics.index = gt_sorting.unit_ids return metrics - def get_metrics(self, rec_name=None, **metric_kwargs): - """ - Load or compute units metrics for a given recording. - """ - rec_name = self._check_rec_name(rec_name) - metrics_folder = self.study_folder / "metrics" - metrics_folder.mkdir(parents=True, exist_ok=True) + def get_units_snr(self, key): + """ """ + return self.get_metrics(key)["snr"] - filename = self.study_folder / "metrics" / f"metrics _{rec_name}.txt" + def get_performance_by_unit(self, case_keys=None): import pandas as pd - if filename.is_file(): - metrics = pd.read_csv(filename, sep="\t", index_col=0) - gt_sorting = self.get_ground_truth(rec_name) - metrics.index = gt_sorting.unit_ids + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" + + perf = comp.get_performance(method="by_unit", output="pandas") + if isinstance(key, str): + perf[self.levels] = key + elif isinstance(key, tuple): + for col, k in zip(self.levels, key): + perf[col] = k + + perf = perf.reset_index() + perf_by_unit.append(perf) + + perf_by_unit = pd.concat(perf_by_unit) + perf_by_unit = perf_by_unit.set_index(self.levels) + return perf_by_unit + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) else: - metrics = self.compute_metrics(rec_name, **metric_kwargs) + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - metrics.index.name = "unit_id" - # add rec name columns - metrics["rec_name"] = rec_name + columns = ["num_gt", "num_sorter", "num_well_detected"] + comp = self.comparisons[case_keys[0]] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) - return metrics + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" - def get_units_snr(self, rec_name=None, **metric_kwargs): - """ """ - metric = self.get_metrics(rec_name=rec_name, **metric_kwargs) - return metric["snr"] - - def concat_all_snr(self): - metrics = [] - for rec_name in self.rec_names: - df = self.get_metrics(rec_name) - df = df.reset_index() - metrics.append(df) - metrics = pd.concat(metrics) - metrics = metrics.set_index(["rec_name", "unit_id"]) - return metrics["snr"] + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index af410255b9..e0c98cd772 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording): The refractory period of the injected spike train (in ms). injected_sorting_folder: str | Path | None If given, the injected sorting is saved to this folder. - It must be specified if injected_sorting is None or not dumpable. + It must be specified if injected_sorting is None or not serialisable to file. Returns ------- @@ -84,7 +84,8 @@ def __init__( ) # save injected sorting if necessary self.injected_sorting = injected_sorting - if not self.injected_sorting.check_if_json_serializable(): + if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) @@ -137,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording): this refractory period. injected_sorting_folder: str | Path | None If given, the injected sorting is saved to this folder. - It must be specified if injected_sorting is None or not dumpable. + It must be specified if injected_sorting is None or not serializable to file. Returns ------- @@ -180,7 +181,8 @@ def __init__( self.injected_sorting = injected_sorting # save injected sorting if necessary - if not self.injected_sorting.check_if_json_serializable(): + if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 9e02fd5b2d..f44e14c4c4 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -1,6 +1,7 @@ from pathlib import Path import json import pickle +import warnings import numpy as np @@ -180,9 +181,16 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou return sorting def save_to_folder(self, save_folder): + warnings.warn( + "save_to_folder() is deprecated. " + "You should save and load the multi sorting comparison object using pickle." + "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", + DeprecationWarning, + stacklevel=2, + ) for sorting in self.object_list: - assert ( - sorting.check_if_json_serializable() + assert sorting.check_serializablility( + "json" ), "MultiSortingComparison.save_to_folder() need json serializable sortings" save_folder = Path(save_folder) @@ -205,6 +213,13 @@ def save_to_folder(self, save_folder): @staticmethod def load_from_folder(folder_path): + warnings.warn( + "load_from_folder() is deprecated. " + "You should save and load the multi sorting comparison object using pickle." + "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", + DeprecationWarning, + stacklevel=2, + ) folder_path = Path(folder_path) with (folder_path / "kwargs.json").open() as f: kwargs = json.load(f) @@ -244,7 +259,8 @@ def __init__( BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) - self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = True if len(unit_ids) > 0: for k in ("agreement_number", "avg_agreement", "unit_ids"): diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py deleted file mode 100644 index 26d2c1ad6f..0000000000 --- a/src/spikeinterface/comparison/studytools.py +++ /dev/null @@ -1,349 +0,0 @@ -""" -High level tools to run many ground-truth comparison with -many sorter on many recordings and then collect and aggregate results -in an easy way. - -The all mechanism is based on an intrinsic organization -into a "study_folder" with several subfolder: - * raw_files : contain a copy in binary format of recordings - * sorter_folders : contains output of sorters - * ground_truth : contains a copy of sorting ground in npz format - * sortings: contains light copy of all sorting in npz format - * tables: some table in cvs format -""" - -from pathlib import Path -import shutil -import json -import os - - -from spikeinterface.core import load_extractor -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.extractors import NpzSortingExtractor -from spikeinterface.sorters import sorter_dict -from spikeinterface.sorters.basesorter import is_log_ok - - -from .comparisontools import _perf_keys -from .paircomparisons import compare_sorter_to_ground_truth - - -# This is deprecated and will be removed -def iter_working_folder(working_folder): - working_folder = Path(working_folder) - for rec_folder in working_folder.iterdir(): - if not rec_folder.is_dir(): - continue - for output_folder in rec_folder.iterdir(): - if (output_folder / "spikeinterface_job.json").is_file(): - with open(output_folder / "spikeinterface_job.json", "r") as f: - job_dict = json.load(f) - rec_name = job_dict["rec_name"] - sorter_name = job_dict["sorter_name"] - yield rec_name, sorter_name, output_folder - else: - rec_name = rec_folder.name - sorter_name = output_folder.name - if not output_folder.is_dir(): - continue - if not is_log_ok(output_folder): - continue - yield rec_name, sorter_name, output_folder - - -# This is deprecated and will be removed -def iter_sorting_output(working_folder): - """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" - for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): - SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) - yield rec_name, sorter_name, sorting - - -def setup_comparison_study(study_folder, gt_dict, **job_kwargs): - """ - Based on a dict of (recording, sorting) create the study folder. - - Parameters - ---------- - study_folder: str - The study folder. - gt_dict : a dict of tuple (recording, sorting_gt) - Dict of tuple that contain recording and sorting ground truth - """ - job_kwargs = fix_job_kwargs(job_kwargs) - study_folder = Path(study_folder) - assert not study_folder.is_dir(), "'study_folder' already exists. Please remove it" - - study_folder.mkdir(parents=True, exist_ok=True) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - log_folder.mkdir(parents=True, exist_ok=True) - tables_folder = study_folder / "tables" - tables_folder.mkdir(parents=True, exist_ok=True) - - for rec_name, (recording, sorting_gt) in gt_dict.items(): - # write recording using save with binary - folder = study_folder / "ground_truth" / rec_name - sorting_gt.save(folder=folder, format="numpy_folder") - folder = study_folder / "raw_files" / rec_name - recording.save(folder=folder, format="binary", **job_kwargs) - - # make an index of recording names - with open(study_folder / "names.txt", mode="w", encoding="utf8") as f: - for rec_name in gt_dict: - f.write(rec_name + "\n") - - -def get_rec_names(study_folder): - """ - Get list of keys of recordings. - Read from the 'names.txt' file in study folder. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - rec_names: list - List of names. - """ - study_folder = Path(study_folder) - with open(study_folder / "names.txt", mode="r", encoding="utf8") as f: - rec_names = f.read()[:-1].split("\n") - return rec_names - - -def get_recordings(study_folder): - """ - Get ground recording as a dict. - - They are read from the 'raw_files' folder with binary format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - recording_dict: dict - Dict of recording. - """ - study_folder = Path(study_folder) - - rec_names = get_rec_names(study_folder) - recording_dict = {} - for rec_name in rec_names: - rec = load_extractor(study_folder / "raw_files" / rec_name) - recording_dict[rec_name] = rec - - return recording_dict - - -def get_ground_truths(study_folder): - """ - Get ground truth sorting extractor as a dict. - - They are read from the 'ground_truth' folder with npz format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - ground_truths: dict - Dict of sorting_gt. - """ - study_folder = Path(study_folder) - rec_names = get_rec_names(study_folder) - ground_truths = {} - for rec_name in rec_names: - sorting = load_extractor(study_folder / "ground_truth" / rec_name) - ground_truths[rec_name] = sorting - return ground_truths - - -def iter_computed_names(study_folder): - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - yield rec_name, sorter_name - - -def iter_computed_sorting(study_folder): - """ - Iter over sorting files. - """ - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - sorting = NpzSortingExtractor(sorting_folder / filename) - yield rec_name, sorter_name, sorting - - -def collect_run_times(study_folder): - """ - Collect run times in a working folder and store it in CVS files. - - The output is list of (rec_name, sorter_name, run_time) - """ - import pandas as pd - - study_folder = Path(study_folder) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - tables_folder = study_folder / "tables" - - tables_folder.mkdir(parents=True, exist_ok=True) - - run_times = [] - for filename in os.listdir(log_folder): - if filename.endswith(".json") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".json", "").split("[#]") - with open(log_folder / filename, encoding="utf8", mode="r") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - run_times.append((rec_name, sorter_name, run_time)) - - run_times = pd.DataFrame(run_times, columns=["rec_name", "sorter_name", "run_time"]) - run_times = run_times.set_index(["rec_name", "sorter_name"]) - - return run_times - - -def aggregate_sorting_comparison(study_folder, exhaustive_gt=False): - """ - Loop over output folder in a tree to collect sorting output and run - ground_truth_comparison on them. - - Parameters - ---------- - study_folder: str - The study folder. - exhaustive_gt: bool (default True) - Tell if the ground true is "exhaustive" or not. In other world if the - GT have all possible units. It allows more performance measurement. - For instance, MEArec simulated dataset have exhaustive_gt=True - - Returns - ---------- - comparisons: a dict of SortingComparison - - """ - - study_folder = Path(study_folder) - - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - comparisons = {} - for (rec_name, sorter_name), sorting in results.items(): - gt_sorting = ground_truths[rec_name] - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt) - comparisons[(rec_name, sorter_name)] = sc - - return comparisons - - -def aggregate_performances_table(study_folder, exhaustive_gt=False, **karg_thresh): - """ - Aggregate some results into dataframe to have a "study" overview on all recordingXsorter. - - Tables are: - * run_times: run times per recordingXsorter - * perf_pooled_with_sum: GroundTruthComparison.see get_performance - * perf_pooled_with_average: GroundTruthComparison.see get_performance - * count_units: given some threshold count how many units : 'well_detected', 'redundant', 'false_postive_units, 'bad' - - Parameters - ---------- - study_folder: str - The study folder. - karg_thresh: dict - Threshold parameters used for the "count_units" table. - - Returns - ------- - dataframes: a dict of DataFrame - Return several useful DataFrame to compare all results. - Note that count_units depend on karg_thresh. - """ - import pandas as pd - - study_folder = Path(study_folder) - sorter_folders = study_folder / "sorter_folders" - tables_folder = study_folder / "tables" - - comparisons = aggregate_sorting_comparison(study_folder, exhaustive_gt=exhaustive_gt) - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - study_folder = Path(study_folder) - - dataframes = {} - - # get run times: - run_times = pd.read_csv(str(tables_folder / "run_times.csv"), sep="\t") - run_times.columns = ["rec_name", "sorter_name", "run_time"] - run_times = run_times.set_index( - [ - "rec_name", - "sorter_name", - ] - ) - dataframes["run_times"] = run_times - - perf_pooled_with_sum = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_sum"] = perf_pooled_with_sum - - perf_pooled_with_average = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_average"] = perf_pooled_with_average - - count_units = pd.DataFrame( - index=run_times.index, columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant"] - ) - dataframes["count_units"] = count_units - if exhaustive_gt: - count_units["num_false_positive"] = None - count_units["num_bad"] = None - - perf_by_spiketrain = [] - - for (rec_name, sorter_name), comp in comparisons.items(): - gt_sorting = ground_truths[rec_name] - sorting = results[(rec_name, sorter_name)] - - perf = comp.get_performance(method="pooled_with_sum", output="pandas") - perf_pooled_with_sum.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="pooled_with_average", output="pandas") - perf_pooled_with_average.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="by_spiketrain", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - - perf_by_spiketrain.append(perf) - - count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units(**karg_thresh) - count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units() - if exhaustive_gt: - count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units() - count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() - - perf_by_spiketrain = pd.concat(perf_by_spiketrain) - perf_by_spiketrain = perf_by_spiketrain.set_index(["rec_name", "sorter_name", "gt_unit_id"]) - dataframes["perf_by_spiketrain"] = perf_by_spiketrain - - return dataframes diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 70f8a63c8c..91c8c640e0 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -1,19 +1,11 @@ -import importlib import shutil import pytest from pathlib import Path -from spikeinterface.extractors import toy_example -from spikeinterface.sorters import installed_sorters +from spikeinterface import generate_ground_truth_recording +from spikeinterface.preprocessing import bandpass_filter from spikeinterface.comparison import GroundTruthStudy -try: - import tridesclous - - HAVE_TDC = True -except ImportError: - HAVE_TDC = False - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "comparison" @@ -27,61 +19,85 @@ def setup_module(): if study_folder.is_dir(): shutil.rmtree(study_folder) - _setup_comparison_study() + create_a_study(study_folder) + + +def simple_preprocess(rec): + return bandpass_filter(rec) -def _setup_comparison_study(): - rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1) - rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1) +def create_a_study(study_folder): + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=91) - gt_dict = { + datasets = { "toy_tetrode": (rec0, gt_sorting0), "toy_probe32": (rec1, gt_sorting1), + "toy_probe32_preprocess": (simple_preprocess(rec1), gt_sorting1), } - study = GroundTruthStudy.create(study_folder, gt_dict) + # cases can also be generated via simple loops + cases = { + # + ("tdc2", "no-preprocess", "tetrode"): { + "label": "tridesclous2 without preprocessing and standard params", + "dataset": "toy_tetrode", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + "comparison_params": {}, + }, + # + ("tdc2", "with-preprocess", "probe32"): { + "label": "tridesclous2 with preprocessing standar params", + "dataset": "toy_probe32_preprocess", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + "comparison_params": {}, + }, + # we comment this at the moement because SC2 is quite slow for testing + # ("sc2", "no-preprocess", "tetrode"): { + # "label": "spykingcircus2 without preprocessing standar params", + # "dataset": "toy_tetrode", + # "run_sorter_params": { + # "sorter_name": "spykingcircus2", + # }, + # "comparison_params": { + # }, + # }, + } -@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -def test_run_study_sorters(): - study = GroundTruthStudy(study_folder) - sorter_list = [ - "tridesclous", - ] - print( - f"\n#################################\nINSTALLED SORTERS\n#################################\n" - f"{installed_sorters()}" + study = GroundTruthStudy.create( + study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] ) - study.run_sorters(sorter_list) + # print(study) -@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -def test_extract_sortings(): +def test_GroundTruthStudy(): study = GroundTruthStudy(study_folder) + print(study) - study.copy_sortings() - - for rec_name in study.rec_names: - gt_sorting = study.get_ground_truth(rec_name) - - for rec_name in study.rec_names: - metrics = study.get_metrics(rec_name=rec_name) + study.run_sorters(verbose=True) - snr = study.get_units_snr(rec_name=rec_name) + print(study.sortings) - study.copy_sortings() + print(study.comparisons) + study.run_comparisons() + print(study.comparisons) - run_times = study.aggregate_run_times() + study.extract_waveforms_gt(n_jobs=-1) - study.run_comparisons(exhaustive_gt=True) + study.compute_metrics() - perf = study.aggregate_performance_by_unit() + for key in study.cases: + metrics = study.get_metrics(key) + print(metrics) - count_units = study.aggregate_count_units() - dataframes = study.aggregate_dataframes() - print(dataframes) + study.get_performance_by_unit() + study.get_count_units() if __name__ == "__main__": - # setup_module() - # test_run_study_sorters() - test_extract_sortings() + setup_module() + test_GroundTruthStudy() diff --git a/src/spikeinterface/comparison/tests/test_studytools.py b/src/spikeinterface/comparison/tests/test_studytools.py deleted file mode 100644 index dbc39d5e1d..0000000000 --- a/src/spikeinterface/comparison/tests/test_studytools.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import shutil -from pathlib import Path - -import pytest - -from spikeinterface.extractors import toy_example -from spikeinterface.comparison.studytools import ( - setup_comparison_study, - iter_computed_names, - iter_computed_sorting, - get_rec_names, - get_ground_truths, - get_recordings, -) - -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "comparison" -else: - cache_folder = Path("cache_folder") / "comparison" - - -study_folder = cache_folder / "test_studytools" - - -def setup_module(): - if study_folder.is_dir(): - shutil.rmtree(study_folder) - - -def test_setup_comparison_study(): - rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1) - rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1) - - gt_dict = { - "toy_tetrode": (rec0, gt_sorting0), - "toy_probe32": (rec1, gt_sorting1), - } - setup_comparison_study(study_folder, gt_dict) - - -def test_get_ground_truths(): - names = get_rec_names(study_folder) - d = get_ground_truths(study_folder) - d = get_recordings(study_folder) - - -def test_loops(): - names = list(iter_computed_names(study_folder)) - for rec_name, sorter_name, sorting in iter_computed_sorting(study_folder): - print(rec_name, sorter_name) - print(sorting) - - -if __name__ == "__main__": - setup_module() - test_setup_comparison_study() - test_get_ground_truths() - test_loops() diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 87c0805630..1430e8fb45 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -57,8 +57,7 @@ def __init__(self, main_ids: Sequence) -> None: # * number of units for sorting self._properties = {} - self._is_dumpable = True - self._is_json_serializable = True + self._serializablility = {"memory": True, "json": True, "pickle": True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -425,14 +424,15 @@ def from_dict(dictionary: dict, base_folder: Optional[Union[Path, str]] = None) extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ - if dictionary["relative_paths"]: + # for pickle dump relative_path was not in the dict, this ensure compatibility + if dictionary.get("relative_paths", False): assert base_folder is not None, "When relative_paths=True, need to provide base_folder" dictionary = _make_paths_absolute(dictionary, base_folder) extractor = _load_extractor_from_dict(dictionary) folder_metadata = dictionary.get("folder_metadata", None) if folder_metadata is not None: folder_metadata = Path(folder_metadata) - if dictionary["relative_paths"]: + if dictionary.get("relative_paths", False): folder_metadata = base_folder / folder_metadata extractor.load_metadata_from_folder(folder_metadata) return extractor @@ -471,24 +471,33 @@ def clone(self) -> "BaseExtractor": clone = BaseExtractor.from_dict(d) return clone - def check_if_dumpable(self): - """Check if the object is dumpable, including nested objects. + def check_serializablility(self, type): + kwargs = self._kwargs + for value in kwargs.values(): + # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + if isinstance(value, BaseExtractor): + if not value.check_serializablility(type=type): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + return self._serializablility[type] + + def check_if_memory_serializable(self): + """ + Check if the object is serializable to memory with pickle, including nested objects. Returns ------- bool - True if the object is dumpable, False otherwise. + True if the object is memory serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_dumpable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_dumpable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_dumpable() for k, v in value.items()]) - return self._is_dumpable + return self.check_serializablility("memory") def check_if_json_serializable(self): """ @@ -499,16 +508,13 @@ def check_if_json_serializable(self): bool True if the object is json serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_json_serializable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_json_serializable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_json_serializable() for k, v in value.items()]) - return self._is_json_serializable + # we keep this for backward compatilibity or not ???? + # is this needed ??? I think no. + return self.check_serializablility("json") + + def check_if_pickle_serializable(self): + # is this needed ??? I think no. + return self.check_serializablility("pickle") @staticmethod def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: @@ -557,7 +563,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) + self.dump_to_pickle(file_path, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -576,7 +582,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_json_serializable(), "The extractor is not json serializable" + assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict if relative_to: @@ -616,12 +622,13 @@ def dump_to_pickle( folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_dumpable(), "The extractor is not dumpable" + assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, folder_metadata=folder_metadata, + relative_to=None, recursive=False, ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) @@ -653,8 +660,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") - if "warning" in d and "not dumpable" in d["warning"]: - print("The extractor was not dumpable") + if "warning" in d: + print("The extractor was not serializable to file") return None extractor = BaseExtractor.from_dict(d, base_folder=base_folder) return extractor @@ -814,10 +821,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - if self.check_if_json_serializable(): + if self.check_serializablility("json"): self.dump(provenance_file) else: - provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") + provenance_file.write_text( + json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8" + ) self.save_metadata_to_folder(folder) @@ -911,7 +920,7 @@ def save_to_zarr( zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options) - if self.check_if_dumpable(): + if self.check_if_json_serializable(): zarr_root.attrs["provenance"] = check_json(self.to_dict()) else: zarr_root.attrs["provenance"] = None diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 07837bcef7..eeb1e8af60 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1056,6 +1056,8 @@ def __init__( dtype = parent_recording.dtype if parent_recording is not None else templates.dtype BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + # Important : self._serializablility is not change here because it will depend on the sorting parents itself. + n_units = len(sorting.unit_ids) assert len(templates) == n_units self.spike_vector = sorting.to_spike_vector() @@ -1431,5 +1433,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) return recording, sorting diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c0ee77d2fd..84ee502c14 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -167,11 +167,11 @@ def ensure_n_jobs(recording, n_jobs=1): print(f"Python {sys.version} does not support parallel processing") n_jobs = 1 - if not recording.check_if_dumpable(): + if not recording.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not dumpable and can't be processed in parallel. " - "You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1." + "Recording is not serializable to memory and can't be processed in parallel. " + "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." ) return n_jobs diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index d5663156c7..3d7ec6cd1a 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -64,7 +64,8 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" t_starts = [float(t_start) for t_start in t_starts] - self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for i, traces in enumerate(traces_list): if t_starts is None: @@ -126,8 +127,10 @@ def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = True - self._is_json_serializable = False + self._serializablility["memory"] = True + self._serializablility["json"] = False + # theorically this should be False but for simplicity make generators simples we still need this. + self._serializablility["pickle"] = True if spikes.size == 0: nseg = 1 @@ -357,8 +360,10 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting" BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = True - self._is_json_serializable = False + + self._serializablility["memory"] = True + self._serializablility["json"] = False + self._serializablility["pickle"] = False self.shm = SharedMemory(shm_name, create=False) self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) @@ -516,8 +521,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore dtype=dtype, ) - self._is_dumpable = False - self._is_json_serializable = False + self._serializablility["memory"] = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for snippets, spikesframes in zip(snippets_list, spikesframes_list): snp_segment = NumpySnippetsSegment(snippets, spikesframes) diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 1ff31127f4..879700cc15 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -181,9 +181,10 @@ def __init__(self, oldapi_recording_extractor): dtype=oldapi_recording_extractor.get_dtype(return_scaled=False), ) - # set _is_dumpable to False to use dumping mechanism of old extractor - self._is_dumpable = False - self._is_json_serializable = False + # set to False to use dumping mechanism of old extractor + self._serializablility["memory"] = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False self.annotate(is_filtered=oldapi_recording_extractor.is_filtered) @@ -268,8 +269,9 @@ def __init__(self, oldapi_sorting_extractor): sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor) self.add_sorting_segment(sorting_segment) - self._is_dumpable = False - self._is_json_serializable = False + self._serializablility["memory"] = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False # add old properties copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 4c3680b021..8c5c62d568 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .recording_tools import get_channel_distances, get_noise_levels @@ -33,7 +35,9 @@ class ChannelSparsity: """ - Handle channel sparsity for a set of units. + Handle channel sparsity for a set of units. That is, for every unit, + it indicates which channels are used to represent the waveform and the rest + of the non-represented channels are assumed to be zero. Internally, sparsity is stored as a boolean mask. @@ -92,13 +96,17 @@ def __init__(self, mask, unit_ids, channel_ids): assert self.mask.shape[0] == self.unit_ids.shape[0] assert self.mask.shape[1] == self.channel_ids.shape[0] - # some precomputed dict + # Those are computed at first call self._unit_id_to_channel_ids = None self._unit_id_to_channel_indices = None + self.num_channels = self.channel_ids.size + self.num_units = self.unit_ids.size + self.max_num_active_channels = self.mask.sum(axis=1).max() + def __repr__(self): - ratio = np.mean(self.mask) - txt = f"ChannelSparsity - units: {self.unit_ids.size} - channels: {self.channel_ids.size} - ratio: {ratio:0.2f}" + density = np.mean(self.mask) + txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}" return txt @property @@ -119,6 +127,85 @@ def unit_id_to_channel_indices(self): self._unit_id_to_channel_indices[unit_id] = channel_inds return self._unit_id_to_channel_indices + def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: + """ + Sparsify the waveforms according to a unit_id corresponding sparsity. + + + Given a unit_id, this method selects only the active channels for + that unit and removes the rest. + + Parameters + ---------- + waveforms : np.array + Dense waveforms with shape (num_waveforms, num_samples, num_channels) or a + single dense waveform (template) with shape (num_samples, num_channels). + unit_id : str + The unit_id for which to sparsify the waveform. + + Returns + ------- + sparsified_waveforms : np.array + Sparse waveforms with shape (num_waveforms, num_samples, num_active_channels) + or a single sparsified waveform (template) with shape (num_samples, num_active_channels). + """ + + assert_msg = ( + "Waveforms must be dense to sparsify them. " + f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}" + ) + assert self.are_waveforms_dense(waveforms=waveforms), assert_msg + + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + sparsified_waveforms = waveforms[..., non_zero_indices] + + return sparsified_waveforms + + def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: + """ + Densify sparse waveforms that were sparisified according to a unit's channel sparsity. + + Given a unit_id its sparsified waveform, this method places the waveform back + into its original form within a dense array. + + Parameters + ---------- + waveforms : np.array + The sparsified waveforms array of shape (num_waveforms, num_samples, num_active_channels) or a single + sparse waveform (template) with shape (num_samples, num_active_channels). + unit_id : str + The unit_id that was used to sparsify the waveform. + + Returns + ------- + densified_waveforms : np.array + The densified waveforms array of shape (num_waveforms, num_samples, num_channels) or a single dense + waveform (template) with shape (num_samples, num_channels). + + """ + + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + + assert_msg = ( + "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " + f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels." + ) + assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg + + densified_shape = waveforms.shape[:-1] + (self.num_channels,) + densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) + densified_waveforms[..., non_zero_indices] = waveforms + + return densified_waveforms + + def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: + return waveforms.shape[-1] == self.num_channels + + def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool: + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + return waveforms.shape[-1] == num_active_channels + @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): """ @@ -144,16 +231,16 @@ def to_dict(self): ) @classmethod - def from_dict(cls, d): + def from_dict(cls, dictionary: dict): unit_id_to_channel_ids_corrected = {} - for unit_id in d["unit_ids"]: - if unit_id in d["unit_id_to_channel_ids"]: - unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][unit_id] + for unit_id in dictionary["unit_ids"]: + if unit_id in dictionary["unit_id_to_channel_ids"]: + unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][unit_id] else: - unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][str(unit_id)] - d["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected + unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][str(unit_id)] + dictionary["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected - return cls.from_unit_id_to_channel_ids(**d) + return cls.from_unit_id_to_channel_ids(**dictionary) ## Some convinient function to compute sparsity from several strategy @classmethod diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index ea1a9cf0d2..a944be3da0 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -31,39 +31,39 @@ def make_nested_extractors(extractor): ) -def test_check_if_dumpable(): +def test_check_if_memory_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) - # make a list of dumpable objects - extractors_dumpable = make_nested_extractors(test_extractor) - for extractor in extractors_dumpable: - assert extractor.check_if_dumpable() + # make a list of memory serializable objects + extractors_mem_serializable = make_nested_extractors(test_extractor) + for extractor in extractors_mem_serializable: + assert extractor.check_if_memory_serializable() - # make not dumpable - test_extractor._is_dumpable = False - extractors_not_dumpable = make_nested_extractors(test_extractor) - for extractor in extractors_not_dumpable: - assert not extractor.check_if_dumpable() + # make not not memory serilizable + test_extractor._serializablility["memory"] = False + extractors_not_mem_serializable = make_nested_extractors(test_extractor) + for extractor in extractors_not_mem_serializable: + assert not extractor.check_if_memory_serializable() -def test_check_if_json_serializable(): +def test_check_if_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) - # make a list of dumpable objects - test_extractor._is_json_serializable = True + # make a list of json serializable objects + test_extractor._serializablility["json"] = True extractors_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_json_serializable: print(extractor) - assert extractor.check_if_json_serializable() + assert extractor.check_serializablility("json") - # make not dumpable - test_extractor._is_json_serializable = False + # make of not json serializable objects + test_extractor._serializablility["json"] = False extractors_not_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_not_json_serializable: print(extractor) - assert not extractor.check_if_json_serializable() + assert not extractor.check_serializablility("json") if __name__ == "__main__": - test_check_if_dumpable() - test_check_if_json_serializable() + test_check_if_memory_serializable() + test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index a3cd0caa92..223b2a8a3a 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -142,7 +142,6 @@ def test_write_memory_recording(): recording = NoiseGeneratorRecording( num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" ) - # make dumpable recording = recording.save() # write with loop diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 7d7af6025b..a904e4dd32 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -36,7 +36,7 @@ def test_ensure_n_jobs(): n_jobs = ensure_n_jobs(recording, n_jobs=1) assert n_jobs == 1 - # dumpable + # check serializable n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1) assert n_jobs > 1 @@ -45,7 +45,7 @@ def test_ensure_chunk_size(): recording = generate_recording(num_channels=2) dtype = recording.get_dtype() assert dtype == "float32" - # make dumpable + # make serializable recording = recording.save() chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2) @@ -90,7 +90,7 @@ def init_func(arg1, arg2, arg3): def test_ChunkRecordingExecutor(): recording = generate_recording(num_channels=2) - # make dumpable + # make serializable recording = recording.save() init_args = "a", 120, "yep" diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index 473648c5ec..1c491bd7a6 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -142,9 +142,11 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract self.extractor_list = extractor_list self.extractor_dict = extractor_dict + BaseExtractor.__init__(self, main_ids=["1", "2"]) # this already the case by default - self._is_dumpable = True - self._is_json_serializable = True + self._serializablility["memory"] = True + self._serializablility["json"] = True + self._serializablility["pickle"] = True self._kwargs = { "attribute": attribute, @@ -195,3 +197,8 @@ def test_encoding_numpy_scalars_within_nested_extractors_list(nested_extractor_l def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_dict): json.dumps(nested_extractor_dict, cls=SIJsonEncoder) + + +if __name__ == "__main__": + nested_extractor = nested_extractor() + test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 75182bf532..ac114ac161 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -55,5 +55,93 @@ def test_ChannelSparsity(): assert np.array_equal(sparsity.mask, sparsity4.mask) +def test_sparsify_waveforms(): + seed = 0 + rng = np.random.default_rng(seed=seed) + + num_units = 3 + num_samples = 5 + num_channels = 4 + + is_mask_valid = False + while not is_mask_valid: + sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool") + is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0) + + unit_ids = np.arange(num_units) + channel_ids = np.arange(num_channels) + sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids) + + for unit_id in unit_ids: + waveforms_dense = rng.random(size=(num_units, num_samples, num_channels)) + + # Test are_waveforms_dense + assert sparsity.are_waveforms_dense(waveforms_dense) + + # Test sparsify + waveforms_sparse = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id) + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) + + # Test round-trip (note that this is loosy) + unit_id = unit_ids[unit_id] + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) + assert np.array_equal(waveforms_dense[..., non_zero_indices], waveforms_dense2[..., non_zero_indices]) + + # Test sparsify with one waveform (template) + template_dense = waveforms_dense.mean(axis=0) + template_sparse = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id) + assert template_sparse.shape == (num_samples, num_active_channels) + + # Test round trip with template + template_dense2 = sparsity.densify_waveforms(template_sparse, unit_id=unit_id) + assert np.array_equal(template_dense[..., non_zero_indices], template_dense2[:, non_zero_indices]) + + +def test_densify_waveforms(): + seed = 0 + rng = np.random.default_rng(seed=seed) + + num_units = 3 + num_samples = 5 + num_channels = 4 + + is_mask_valid = False + while not is_mask_valid: + sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool") + is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0) + + unit_ids = np.arange(num_units) + channel_ids = np.arange(num_channels) + sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids) + + for unit_id in unit_ids: + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + waveforms_sparse = rng.random(size=(num_units, num_samples, num_active_channels)) + + # Test are waveforms sparse + assert sparsity.are_waveforms_sparse(waveforms_sparse, unit_id=unit_id) + + # Test densify + waveforms_dense = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) + assert waveforms_dense.shape == (num_units, num_samples, num_channels) + + # Test round-trip + waveforms_sparse2 = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id) + assert np.array_equal(waveforms_sparse, waveforms_sparse2) + + # Test densify with one waveform (template) + template_sparse = waveforms_sparse.mean(axis=0) + template_dense = sparsity.densify_waveforms(template_sparse, unit_id=unit_id) + assert template_dense.shape == (num_samples, num_channels) + + # Test round trip with template + template_sparse2 = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id) + assert np.array_equal(template_sparse, template_sparse2) + + if __name__ == "__main__": test_ChannelSparsity() diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 107ef5f180..2bbf5e9b0f 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -6,7 +6,13 @@ import zarr -from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity +from spikeinterface.core import ( + generate_recording, + generate_sorting, + NumpySorting, + ChannelSparsity, + generate_ground_truth_recording, +) from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms from spikeinterface.core.waveform_extractor import precompute_sparsity @@ -309,7 +315,7 @@ def test_recordingless(): recording = recording.save(folder=cache_folder / "recording1") sorting = sorting.save(folder=cache_folder / "sorting1") - # recording and sorting are not dumpable + # recording and sorting are not serializable wf_folder = cache_folder / "wf_recordingless" # save with relative paths @@ -510,10 +516,44 @@ def test_compute_sparsity(): print(sparsity) +def test_non_json_object(): + recording, sorting = generate_ground_truth_recording( + durations=[30, 40], + sampling_frequency=30000.0, + num_channels=32, + num_units=5, + ) + + # recording is not save to keep it in memory + sorting = sorting.save() + + wf_folder = cache_folder / "test_waveform_extractor" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + + we = extract_waveforms( + recording, + sorting, + wf_folder, + mode="folder", + sparsity=None, + sparse=False, + ms_before=1.0, + ms_after=1.6, + max_spikes_per_unit=50, + n_jobs=4, + chunk_size=30000, + progress_bar=True, + ) + + # This used to fail because of json + we = load_waveforms(wf_folder) + + if __name__ == "__main__": - test_WaveformExtractor() + # test_WaveformExtractor() # test_extract_waveforms() - # test_sparsity() # test_portability() # test_recordingless() # test_compute_sparsity() + test_non_json_object() diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6881ab3ec5..2710ff1338 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -159,11 +159,20 @@ def load_from_folder( else: rec_attributes["probegroup"] = None else: - try: - recording = load_extractor(folder / "recording.json", base_folder=folder) - rec_attributes = None - except: + recording = None + if (folder / "recording.json").exists(): + try: + recording = load_extractor(folder / "recording.json", base_folder=folder) + except: + pass + elif (folder / "recording.pickle").exists(): + try: + recording = load_extractor(folder / "recording.pickle") + except: + pass + if recording is None: raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument") + rec_attributes = None if sorting is None: sorting = load_extractor(folder / "sorting.json", base_folder=folder) @@ -271,14 +280,22 @@ def create( else: relative_to = None - if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump(folder / "recording.json", relative_to=relative_to) - if sorting.check_if_json_serializable(): + elif recording.check_serializablility("pickle"): + # In this case we loose the relative_to!! + recording.dump(folder / "recording.pickle") + + if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif sorting.check_serializablility("pickle"): + # In this case we loose the relative_to!! + # TODO later the dump to pickle should dump the dictionary and so relative could be put back + sorting.dump(folder / "sorting.pickle") else: warn( - "Sorting object is not dumpable, which might result in downstream errors for " - "parallel processing. To make the sorting dumpable, use the `sorting.save()` function." + "Sorting object is not serializable to file, which might result in downstream errors for " + "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." ) # dump some attributes of the recording for the mode with_recording=False at next load @@ -879,14 +896,19 @@ def save( (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") if self.has_recording(): - if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) - if self.sorting.check_if_json_serializable(): + elif self.recording.check_serializablility("pickle"): + self.recording.dump(folder / "recording.pickle") + + if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif self.sorting.check_serializablility("pickle"): + self.sorting.dump(folder / "sorting.pickle") else: warn( - "Sorting object is not dumpable, which might result in downstream errors for " - "parallel processing. To make the sorting dumpable, use the `sorting.save()` function." + "Sorting object is not serializable to file, which might result in downstream errors for " + "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." ) # dump some attributes of the recording for the mode with_recording=False at next load @@ -931,16 +953,16 @@ def save( # write metadata zarr_root.attrs["params"] = check_json(self._params) if self.has_recording(): - if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["recording"] = check_json(rec_dict) - if self.sorting.check_if_json_serializable(): + if self.sorting.check_serializablility("json"): sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["sorting"] = check_json(sort_dict) else: warn( - "Sorting object is not dumpable, which might result in downstream errors for " - "parallel processing. To make the sorting dumpable, use the `sorting.save()` function." + "Sorting object is not json serializable, which might result in downstream errors for " + "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." ) recording_info = zarr_root.create_group("recording_info") recording_info.attrs["recording_attributes"] = check_json(rec_attributes) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index c92861a8bf..ebc810b953 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -35,6 +35,7 @@ def export_to_phy( template_mode: str = "median", dtype: Optional[npt.DTypeLike] = None, verbose: bool = True, + use_relative_path: bool = False, **job_kwargs, ): """ @@ -64,6 +65,9 @@ def export_to_phy( Dtype to save binary data verbose: bool If True, output is verbose + use_relative_path : bool, default: False + If True and `copy_binary=True` saves the binary file `dat_path` in the `params.py` relative to `output_folder` (ie `dat_path=r'recording.dat'`). If `copy_binary=False`, then uses a path relative to the `output_folder` + If False, uses an absolute path in the `params.py` (ie `dat_path=r'path/to/the/recording.dat'`) {} """ @@ -94,7 +98,7 @@ def export_to_phy( used_sparsity = sparsity else: used_sparsity = ChannelSparsity.create_dense(waveform_extractor) - # convinient sparsity dict for the 3 cases to retrieve channl_inds + # convenient sparsity dict for the 3 cases to retrieve channl_inds sparse_dict = used_sparsity.unit_id_to_channel_indices empty_flag = False @@ -106,7 +110,7 @@ def export_to_phy( empty_flag = True unit_ids = non_empty_units if empty_flag: - warnings.warn("Empty units have been removed when being exported to Phy") + warnings.warn("Empty units have been removed while exporting to Phy") if len(unit_ids) == 0: raise Exception("No non-empty units in the sorting result, can't save to Phy.") @@ -149,7 +153,15 @@ def export_to_phy( # write params.py with (output_folder / "params.py").open("w") as f: - f.write(f"dat_path = r'{str(rec_path)}'\n") + if use_relative_path: + if copy_binary: + f.write(f"dat_path = r'recording.dat'\n") + elif rec_path == "None": + f.write(f"dat_path = {rec_path}\n") + else: + f.write(f"dat_path = r'{str(Path(rec_path).relative_to(output_folder))}'\n") + else: + f.write(f"dat_path = r'{str(rec_path)}'\n") f.write(f"n_channels_dat = {num_chans}\n") f.write(f"dtype = '{dtype_str}'\n") f.write(f"offset = 0\n") diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 3dde998ca1..bd56208ebe 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -6,13 +6,6 @@ from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts from spikeinterface.core.core_tools import define_function_from_class -try: - import mtscomp - - HAVE_MTSCOMP = True -except: - HAVE_MTSCOMP = False - class CompressedBinaryIblExtractor(BaseRecording): """Load IBL data as an extractor object. @@ -42,7 +35,6 @@ class CompressedBinaryIblExtractor(BaseRecording): """ extractor_name = "CompressedBinaryIbl" - installed = HAVE_MTSCOMP mode = "folder" installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" name = "cbin_ibl" @@ -51,7 +43,10 @@ def __init__(self, folder_path, load_sync_channel=False, stream_name="ap"): # this work only for future neo from neo.rawio.spikeglxrawio import read_meta_file, extract_stream_info - assert HAVE_MTSCOMP + try: + import mtscomp + except: + raise ImportError(self.installation_mesg) folder_path = Path(folder_path) # check bands diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index b40b998103..31241a4147 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -40,7 +40,6 @@ def __init__( sampling_frequency: float | None = None, session_info_file_path: str | Path | None = None, spikes_matfile_path: str | Path | None = None, - session_info_matfile_path: str | Path | None = None, ): try: from pymatreader import read_mat @@ -67,26 +66,6 @@ def __init__( ) file_path = spikes_matfile_path if file_path is None else file_path - if session_info_matfile_path is not None: - # Raise an error if the warning period has expired - deprecation_issued = datetime.datetime(2023, 4, 1) - deprecation_deadline = deprecation_issued + datetime.timedelta(days=180) - if datetime.datetime.now() > deprecation_deadline: - raise ValueError( - "The session_info_matfile_path argument is no longer supported in. Use session_info_file_path instead." - ) - - # Otherwise, issue a DeprecationWarning - else: - warnings.warn( - "The session_info_matfile_path argument is deprecated and will be removed in six months. " - "Use session_info_file_path instead.", - DeprecationWarning, - ) - session_info_file_path = ( - session_info_matfile_path if session_info_file_path is None else session_info_file_path - ) - self.spikes_cellinfo_path = Path(file_path) self.session_path = self.spikes_cellinfo_path.parent self.session_id = self.spikes_cellinfo_path.stem.split(".")[0] diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index a771dc47b1..cd2b6fb941 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -22,6 +22,19 @@ from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts +def drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs): + # Temporary function until neo version 0.13.0 is released + from packaging.version import Version + from importlib.metadata import version as lib_version + + neo_version = lib_version("neo") + # The possibility of ignoring timestamps errors is not present in neo <= 0.12.0 + if Version(neo_version) <= Version("0.12.0"): + neo_kwargs.pop("ignore_timestamps_errors") + + return neo_kwargs + + class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): """ Class for reading data saved by the Open Ephys GUI. @@ -45,14 +58,24 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks (experiments), specify the block index you want to load. all_annotations: bool (default False) Load exhaustively all annotation from neo. + ignore_timestamps_errors: bool (default False) + Ignore the discontinuous timestamps errors in neo. """ mode = "folder" NeoRawIOClass = "OpenEphysRawIO" name = "openephyslegacy" - def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): - neo_kwargs = self.map_to_neo_kwargs(folder_path) + def __init__( + self, + folder_path, + stream_id=None, + stream_name=None, + block_index=None, + all_annotations=False, + ignore_timestamps_errors=False, + ): + neo_kwargs = self.map_to_neo_kwargs(folder_path, ignore_timestamps_errors) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, @@ -64,8 +87,9 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod - def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + def map_to_neo_kwargs(cls, folder_path, ignore_timestamps_errors=False): + neo_kwargs = {"dirname": str(folder_path), "ignore_timestamps_errors": ignore_timestamps_errors} + neo_kwargs = drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs) return neo_kwargs diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index c91aed644d..05aee160f5 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import Optional from pathlib import Path import numpy as np @@ -13,10 +16,14 @@ class BasePhyKilosortSortingExtractor(BaseSorting): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py) - exclude_cluster_groups: list or str, optional + exclude_cluster_groups: list or str, default: None Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). keep_good_only : bool, default: True Whether to only keep good units. + remove_empty_units : bool, default: True + If True, empty units are removed from the sorting extractor. + load_all_cluster_properties : bool, default: True + If True, all cluster properties are loaded from the tsv/csv files. """ extractor_name = "BasePhyKilosortSorting" @@ -29,11 +36,11 @@ class BasePhyKilosortSortingExtractor(BaseSorting): def __init__( self, - folder_path, - exclude_cluster_groups=None, - keep_good_only=False, - remove_empty_units=False, - load_all_cluster_properties=True, + folder_path: Path | str, + exclude_cluster_groups: Optional[list[str] | str] = None, + keep_good_only: bool = False, + remove_empty_units: bool = False, + load_all_cluster_properties: bool = True, ): try: import pandas as pd @@ -195,20 +202,33 @@ class PhySortingExtractor(BasePhyKilosortSortingExtractor): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py). - exclude_cluster_groups: list or str, optional + exclude_cluster_groups: list or str, default: None Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). + load_all_cluster_properties : bool, default: True + If True, all cluster properties are loaded from the tsv/csv files. Returns ------- extractor : PhySortingExtractor - The loaded data. + The loaded Sorting object. """ extractor_name = "PhySorting" name = "phy" - def __init__(self, folder_path, exclude_cluster_groups=None): - BasePhyKilosortSortingExtractor.__init__(self, folder_path, exclude_cluster_groups, keep_good_only=False) + def __init__( + self, + folder_path: Path | str, + exclude_cluster_groups: Optional[list[str] | str] = None, + load_all_cluster_properties: bool = True, + ): + BasePhyKilosortSortingExtractor.__init__( + self, + folder_path, + exclude_cluster_groups, + keep_good_only=False, + load_all_cluster_properties=load_all_cluster_properties, + ) self._kwargs = { "folder_path": str(Path(folder_path).absolute()), @@ -223,8 +243,6 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py). - exclude_cluster_groups: list or str, optional - Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). keep_good_only : bool, default: True Whether to only keep good units. If True, only Kilosort-labeled 'good' units are returned. @@ -234,13 +252,13 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): Returns ------- extractor : KiloSortSortingExtractor - The loaded data. + The loaded Sorting object. """ extractor_name = "KiloSortSorting" name = "kilosort" - def __init__(self, folder_path, keep_good_only=False, remove_empty_units=True): + def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove_empty_units: bool = True): BasePhyKilosortSortingExtractor.__init__( self, folder_path, diff --git a/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py b/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py index 35de8a23e2..c4c8d0c993 100644 --- a/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py +++ b/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py @@ -26,7 +26,7 @@ class CellExplorerSortingTest(SortingCommonTestSuite, unittest.TestCase): ( "cellexplorer/dataset_2/20170504_396um_0um_merge.spikes.cellinfo.mat", { - "session_info_matfile_path": local_folder + "session_info_file_path": local_folder / "cellexplorer/dataset_2/20170504_396um_0um_merge.sessionInfo.mat" }, ), diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 38cb714d59..ccd2121174 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -73,12 +73,6 @@ def _run(self, **job_kwargs): func = _spike_amplitudes_chunk init_func = _init_worker_spike_amplitudes n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) - if n_jobs != 1: - # TODO: avoid dumping sorting and use spike vector and peak pipeline instead - assert sorting.check_if_dumpable(), ( - "The sorting object is not dumpable and cannot be processed in parallel. You can use the " - "`sorting.save()` function to make it dumpable" - ) init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled) processor = ChunkRecordingExecutor( recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index e2ef6e6794..6ab1a9afce 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,7 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json") np.save(folder / "peaks.npy", peaks) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 8dd5f857f6..e9726a16da 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -499,9 +499,8 @@ def compute_sliding_rp_violations( ) -def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): - """ - Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of +def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): + """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of "synchrony_size" spikes at the exact same sample index. Parameters @@ -510,6 +509,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k The waveform extractor object. synchrony_sizes : list or tuple, default: (2, 4, 8) The synchrony sizes to compute. + unit_ids : list or None, default: None + List of unit ids to compute the synchrony metrics. If None, all units are used. Returns ------- @@ -522,16 +523,20 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k Based on concepts described in [Gruen]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - assert np.all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1" + assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit() sorting = waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) + if unit_ids is None: + unit_ids = sorting.unit_ids + # Pre-allocate synchrony counts synchrony_counts = {} for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) + all_unit_ids = list(sorting.unit_ids) for segment_index in range(sorting.get_num_segments()): spikes_in_segment = spikes[segment_index] @@ -539,7 +544,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True) # add counts for this segment - for unit_index in np.arange(len(sorting.unit_ids)): + for unit_id in unit_ids: + unit_index = all_unit_ids.index(unit_id) spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] # some segments/units might have no spikes if len(spikes_per_unit) == 0: @@ -551,8 +557,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # add counts for this segment synchrony_metrics_dict = { f"sync_spike_{synchrony_size}": { - unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id] - for unit_index, unit_id in enumerate(sorting.unit_ids) + unit_id: synchrony_counts[synchrony_size][all_unit_ids.index(unit_id)] / spike_counts[unit_id] + for unit_id in unit_ids } for synchrony_size in synchrony_sizes } @@ -563,7 +569,172 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k return synchrony_metrics -_default_params["synchrony_metrics"] = dict(synchrony_sizes=(0, 2, 4)) +_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) + + +def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): + """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution + computed in non-overlapping time bins. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + bin_size_s : float, default: 5 + The size of the bin in seconds. + percentiles : tuple, default: (5, 95) + The percentiles to compute. + unit_ids : list or None + List of unit ids to compute the firing range. If None, all units are used. + + Returns + ------- + firing_ranges : dict + The firing range for each unit. + + Notes + ----- + Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino. + """ + sampling_frequency = waveform_extractor.sampling_frequency + bin_size_samples = int(bin_size_s * sampling_frequency) + sorting = waveform_extractor.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + + # for each segment, we compute the firing rate histogram and we concatenate them + firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} + for segment_index in range(waveform_extractor.get_num_segments()): + num_samples = waveform_extractor.get_num_samples(segment_index) + edges = np.arange(0, num_samples + 1, bin_size_samples) + + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_counts, _ = np.histogram(spike_times, bins=edges) + firing_rates = spike_counts / bin_size_s + firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + + # finally we compute the percentiles + firing_ranges = {} + for unit_id in unit_ids: + firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( + firing_rate_histograms[unit_id], percentiles[0] + ) + + return firing_ranges + + +_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95)) + + +def compute_amplitude_cv_metrics( + waveform_extractor, + average_num_spikes_per_bin=50, + percentiles=(5, 95), + min_num_bins=10, + amplitude_extension="spike_amplitudes", + unit_ids=None, +): + """Calculate coefficient of variation of spike amplitudes within defined temporal bins. + From the distribution of coefficient of variations, both the median and the "range" (the distance between the + percentiles defined by `percentiles` parameter) are returned. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + average_num_spikes_per_bin : int, default: 50 + The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate + of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is + 100, then the temporal bin size will be 100/10 Hz = 10 s. + min_num_bins : int, default: 10 + The minimum number of bins to compute the median and range. If the number of bins is less than this then + the median and range are set to NaN. + amplitude_extension : str, default: 'spike_amplitudes' + The name of the extension to load the amplitudes from. 'spike_amplitudes' or 'amplitude_scalings'. + unit_ids : list or None + List of unit ids to compute the amplitude spread. If None, all units are used. + + Returns + ------- + amplitude_cv_median : dict + The median of the CV + amplitude_cv_range : dict + The range of the CV, computed as the distance between the percentiles. + + Notes + ----- + Designed by Simon Musall and Alessio Buccino. + """ + res = namedtuple("amplitude_cv", ["amplitude_cv_median", "amplitude_cv_range"]) + assert amplitude_extension in ( + "spike_amplitudes", + "amplitude_scalings", + ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" + sorting = waveform_extractor.sorting + total_duration = waveform_extractor.get_total_duration() + spikes = sorting.to_spike_vector() + num_spikes = sorting.count_num_spikes_per_unit() + if unit_ids is None: + unit_ids = sorting.unit_ids + + if waveform_extractor.is_extension(amplitude_extension): + sac = waveform_extractor.load_extension(amplitude_extension) + amps = sac.get_data(outputs="concatenated") + if amplitude_extension == "spike_amplitudes": + amps = np.concatenate(amps) + else: + warnings.warn("") + empty_dict = {unit_id: np.nan for unit_id in unit_ids} + return empty_dict + + # precompute segment slice + segment_slices = [] + for segment_index in range(waveform_extractor.get_num_segments()): + i0 = np.searchsorted(spikes["segment_index"], segment_index) + i1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append(slice(i0, i1)) + + all_unit_ids = list(sorting.unit_ids) + amplitude_cv_medians, amplitude_cv_ranges = {}, {} + for unit_id in unit_ids: + firing_rate = num_spikes[unit_id] / total_duration + temporal_bin_size_samples = int( + (average_num_spikes_per_bin / firing_rate) * waveform_extractor.sampling_frequency + ) + + amp_spreads = [] + # bins and amplitude means are computed for each segment + for segment_index in range(waveform_extractor.get_num_segments()): + sample_bin_edges = np.arange( + 0, waveform_extractor.get_num_samples(segment_index) + 1, temporal_bin_size_samples + ) + spikes_in_segment = spikes[segment_slices[segment_index]] + amps_in_segment = amps[segment_slices[segment_index]] + unit_mask = spikes_in_segment["unit_index"] == all_unit_ids.index(unit_id) + spike_indices_unit = spikes_in_segment["sample_index"][unit_mask] + amps_unit = amps_in_segment[unit_mask] + amp_mean = np.abs(np.mean(amps_unit)) + for t0, t1 in zip(sample_bin_edges[:-1], sample_bin_edges[1:]): + i0 = np.searchsorted(spike_indices_unit, t0) + i1 = np.searchsorted(spike_indices_unit, t1) + amp_spreads.append(np.std(amps_unit[i0:i1]) / amp_mean) + + if len(amp_spreads) < min_num_bins: + amplitude_cv_medians[unit_id] = np.nan + amplitude_cv_ranges[unit_id] = np.nan + else: + amplitude_cv_medians[unit_id] = np.median(amp_spreads) + amplitude_cv_ranges[unit_id] = np.percentile(amp_spreads, percentiles[1]) - np.percentile( + amp_spreads, percentiles[0] + ) + + return res(amplitude_cv_medians, amplitude_cv_ranges) + + +_default_params["amplitude_cv"] = dict( + average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes" +) def compute_amplitude_cutoffs( diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 90dbb47a3a..97f14ec6f4 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -12,6 +12,8 @@ compute_amplitude_medians, compute_drift_metrics, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) from .pca_metrics import ( @@ -40,6 +42,8 @@ "sliding_rp_violation": compute_sliding_rp_violations, "amplitude_cutoff": compute_amplitude_cutoffs, "amplitude_median": compute_amplitude_medians, + "amplitude_cv": compute_amplitude_cv_metrics, "synchrony": compute_synchrony_metrics, + "firing_range": compute_firing_ranges, "drift": compute_drift_metrics, } diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index d927d64c4f..2d63a06b17 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -12,6 +12,7 @@ compute_principal_components, compute_spike_locations, compute_spike_amplitudes, + compute_amplitude_scalings, ) from spikeinterface.qualitymetrics import ( @@ -31,6 +32,8 @@ compute_drift_metrics, compute_amplitude_medians, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) @@ -212,6 +215,12 @@ def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) +def test_calculate_firing_range(waveform_extractor_simple): + we = waveform_extractor_simple + firing_ranges = compute_firing_ranges(we) + print(firing_ranges) + + def test_calculate_amplitude_cutoff(waveform_extractor_simple): we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) @@ -234,6 +243,24 @@ def test_calculate_amplitude_median(waveform_extractor_simple): # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) +def test_calculate_amplitude_cv_metrics(waveform_extractor_simple): + we = waveform_extractor_simple + spike_amps = compute_spike_amplitudes(we) + amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20) + print(amp_cv_median) + print(amp_cv_range) + + amps_scalings = compute_amplitude_scalings(we) + amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( + we, + average_num_spikes_per_bin=20, + amplitude_extension="amplitude_scalings", + min_num_bins=5, + ) + print(amp_cv_median_scalings) + print(amp_cv_range_scalings) + + def test_calculate_snrs(waveform_extractor_simple): we = waveform_extractor_simple snrs = compute_snrs(we) @@ -358,4 +385,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple): # test_calculate_isi_violations(we) # test_calculate_sliding_rp_violations(we) # test_calculate_drift_metrics(we) - test_synchrony_metrics(we) + # test_synchrony_metrics(we) + test_calculate_firing_range(we) + test_calculate_amplitude_cv_metrics(we) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index c7581ba1e1..a956f8c811 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -137,8 +137,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) rec_file = output_folder / "spikeinterface_recording.json" - if recording.check_if_json_serializable(): - recording.dump_to_json(rec_file, relative_to=output_folder) + if recording.check_serializablility("json"): + recording.dump(rec_file, relative_to=output_folder) + elif recording.check_serializablility("pickle"): + recording.dump(output_folder / "spikeinterface_recording.pickle") else: d = {"warning": "The recording is not serializable to json"} rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") @@ -185,6 +187,26 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose): return params + @classmethod + def load_recording_from_folder(cls, output_folder, with_warnings=False): + json_file = output_folder / "spikeinterface_recording.json" + pickle_file = output_folder / "spikeinterface_recording.pickle" + + if json_file.exists(): + with (json_file).open("r", encoding="utf8") as f: + recording_dict = json.load(f) + if "warning" in recording_dict.keys() and with_warnings: + warnings.warn( + "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." + ) + recording = None + else: + recording = load_extractor(json_file, base_folder=output_folder) + elif pickle_file.exists(): + recording = load_extractor(pickle_file) + + return recording + @classmethod def _dump_params(cls, recording, output_folder, sorter_params, verbose): with (output_folder / "spikeinterface_params.json").open(mode="w", encoding="utf8") as f: @@ -271,7 +293,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose): return run_time @classmethod - def get_result_from_folder(cls, output_folder): + def get_result_from_folder(cls, output_folder, register_recording=True, sorting_info=True): output_folder = Path(output_folder) sorter_output_folder = output_folder / "sorter_output" # check errors in log file @@ -294,27 +316,25 @@ def get_result_from_folder(cls, output_folder): # back-compatibility sorting = cls._get_result_from_folder(output_folder) - # register recording to Sorting object - # check if not json serializable - with (output_folder / "spikeinterface_recording.json").open("r", encoding="utf8") as f: - recording_dict = json.load(f) - if "warning" in recording_dict.keys(): - warnings.warn( - "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." - ) - else: - recording = load_extractor(output_folder / "spikeinterface_recording.json", base_folder=output_folder) + if register_recording: + # register recording to Sorting object + recording = cls.load_recording_from_folder(output_folder, with_warnings=False) if recording is not None: - # can be None when not dumpable sorting.register_recording(recording) - # set sorting info to Sorting object - with open(output_folder / "spikeinterface_recording.json", "r") as f: - rec_dict = json.load(f) - with open(output_folder / "spikeinterface_params.json", "r") as f: - params_dict = json.load(f) - with open(output_folder / "spikeinterface_log.json", "r") as f: - log_dict = json.load(f) - sorting.set_sorting_info(rec_dict, params_dict, log_dict) + + if sorting_info: + # set sorting info to Sorting object + if (output_folder / "spikeinterface_recording.json").exists(): + with open(output_folder / "spikeinterface_recording.json", "r") as f: + rec_dict = json.load(f) + else: + rec_dict = None + + with open(output_folder / "spikeinterface_params.json", "r") as f: + params_dict = json.load(f) + with open(output_folder / "spikeinterface_log.json", "r") as f: + log_dict = json.load(f) + sorting.set_sorting_info(rec_dict, params_dict, log_dict) return sorting diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index a8d702ebe9..5180e6f1cc 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -147,9 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: new_api = False - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) p = params diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 69f97fd11c..f6f0b3eaeb 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -89,9 +89,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort4 - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index df6d276bf5..a88c59d688 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -115,9 +115,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - recording: BaseRecording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index 2a41d793d5..1962d56206 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -148,9 +148,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): # saved by setup recording diff --git a/src/spikeinterface/sorters/internal/si_based.py b/src/spikeinterface/sorters/internal/si_based.py index 1496ffbbd1..989fab1258 100644 --- a/src/spikeinterface/sorters/internal/si_based.py +++ b/src/spikeinterface/sorters/internal/si_based.py @@ -1,4 +1,4 @@ -from spikeinterface.core import load_extractor +from spikeinterface.core import load_extractor, NumpyRecording from spikeinterface.sorters import BaseSorter @@ -14,7 +14,6 @@ def is_installed(cls): @classmethod def _setup_recording(cls, recording, output_folder, params, verbose): - # nothing to do here because the spikeinterface_recording.json is here anyway pass @classmethod diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index db3d88f116..710c4f76f4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -20,7 +20,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, @@ -52,9 +52,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs["verbose"] = verbose job_kwargs["progress_bar"] = verbose - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + sampling_rate = recording.get_sampling_frequency() num_channels = recording.get_num_channels() @@ -152,7 +151,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( - recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params + recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params ) if verbose: diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 42f51d3a77..ed327e0f3c 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -49,9 +49,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): import hdbscan - recording_raw = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) num_chans = recording_raw.get_num_channels() sampling_frequency = recording_raw.get_sampling_frequency() diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index f32a468a22..704f6843f2 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -66,7 +66,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal engine_kwargs: dict return_output: bool, dfault False - Return a sorting or None. + Return a sortings or None. + This also overwrite kwargs in in run_sorter(with_sorting=True/False) Returns ------- @@ -88,8 +89,12 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal "processpoolexecutor", ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True." out = [] + for kwargs in job_list: + kwargs["with_output"] = True else: out = None + for kwargs in job_list: + kwargs["with_output"] = False if engine == "loop": # simple loop in main process diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 6e6ccc0358..bd5667b15f 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -624,10 +624,20 @@ def run_sorter_container( ) -def read_sorter_folder(output_folder, raise_error=True): +def read_sorter_folder(output_folder, register_recording=True, sorting_info=True, raise_error=True): """ Load a sorting object from a spike sorting output folder. The 'output_folder' must contain a valid 'spikeinterface_log.json' file + + + Parameters + ---------- + output_folder: Pth or str + The sorter folder + register_recording: bool, default: True + Attach recording (when json or pickle) to the sorting + sorting_info: bool, default: True + Attach sorting info to the sorting. """ output_folder = Path(output_folder) log_file = output_folder / "spikeinterface_log.json" @@ -647,7 +657,9 @@ def read_sorter_folder(output_folder, raise_error=True): sorter_name = log["sorter_name"] SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) + sorting = SorterClass.get_result_from_folder( + output_folder, register_recording=register_recording, sorting_info=sorting_info + ) return sorting diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index 14c938f8ba..a5e29c8fd9 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -178,7 +178,7 @@ def test_run_sorters_with_list(): if working_folder.is_dir(): shutil.rmtree(working_folder) - # make dumpable + # make serializable rec0 = load_extractor(cache_folder / "toy_rec_0") rec1 = load_extractor(cache_folder / "toy_rec_1") diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b87bbc7cee..28a1a63065 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -539,6 +539,7 @@ def remove_duplicates_via_matching( method_kwargs={}, job_kwargs={}, tmp_folder=None, + method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels @@ -546,7 +547,6 @@ def remove_duplicates_via_matching( from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms from spikeinterface.core import get_global_tmp_folder - from spikeinterface.sortingcomponents.matching.circus import get_scipy_shape import string, random, shutil, os from pathlib import Path @@ -591,19 +591,12 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - dummy_filter = np.empty((num_chans, duration), dtype=np.float32) - dummy_traces = np.empty((num_chans, chunk_size), dtype=np.float32) - - fshape, axes = get_scipy_shape(dummy_filter, dummy_traces, axes=1) - method_kwargs.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], "omp_min_sps": 0.1, - "templates": None, - "overlaps": None, } ) @@ -618,16 +611,31 @@ def remove_duplicates_via_matching( method_kwargs.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs - ) - method_kwargs.update( - { - "overlaps": computed["overlaps"], - "templates": computed["templates"], - "norms": computed["norms"], - "sparsities": computed["sparsities"], - } + sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs ) + if method == "circus-omp-svd": + method_kwargs.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "temporal": computed["temporal"], + "spatial": computed["spatial"], + "singular": computed["singular"], + "units_overlaps": computed["units_overlaps"], + "unit_overlaps_indices": computed["unit_overlaps_indices"], + "sparsity_mask": computed["sparsity_mask"], + } + ) + elif method == "circus-omp": + method_kwargs.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "sparsities": computed["sparsities"], + } + ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) if np.sum(valid) > 0: if np.sum(valid) == 1: diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index be8ecd6702..864548e7d4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -18,7 +18,9 @@ from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms -from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature +from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser +from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature +from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever class RandomProjectionClustering: @@ -34,17 +36,17 @@ class RandomProjectionClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, + "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, - "max_spikes_per_unit": 200, "selection_method": "closest_to_centroid", - "nb_projections": {"ptp": 8, "energy": 2}, - "ms_before": 1.5, - "ms_after": 1.5, + "nb_projections": 10, + "ms_before": 1, + "ms_after": 1, "random_seed": 42, - "shared_memory": False, - "min_values": {"ptp": 0, "energy": 0}, + "smoothing_kwargs": {"window_length_ms": 1}, + "shared_memory": True, "tmp_folder": None, - "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @classmethod @@ -74,50 +76,60 @@ def main_function(cls, recording, peaks, params): np.random.seed(d["random_seed"]) - features_params = {} - features_list = [] - - noise_snippets = None - - for proj_type in ["ptp", "energy"]: - if d["nb_projections"][proj_type] > 0: - features_list += [f"random_projections_{proj_type}"] - - if d["min_values"][proj_type] == "auto": - if noise_snippets is None: - num_segments = recording.get_num_segments() - num_chunks = 3 * d["max_spikes_per_unit"] // num_segments - noise_snippets = get_random_data_chunks( - recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42 - ) - noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans) - - if proj_type == "energy": - data = np.linalg.norm(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif proj_type == "ptp": - data = np.ptp(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif d["min_values"][proj_type] > 0: - min_values = d["min_values"][proj_type] - else: - min_values = None - - projections = np.random.randn(num_chans, d["nb_projections"][proj_type]) - features_params[f"random_projections_{proj_type}"] = { - "radius_um": params["radius_um"], - "projections": projections, - "min_values": min_values, - } - - features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]).absolute() + + ### Then we extract the SVD features + node0 = PeakRetriever(recording, peaks) + node1 = ExtractDenseWaveforms( + recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"] ) - if len(features_data) > 1: - hdbscan_data = np.hstack((features_data[0], features_data[1])) - else: - hdbscan_data = features_data[0] + node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) + + projections = np.random.randn(num_chans, d["nb_projections"]) + projections -= projections.mean(0) + projections /= projections.std(0) + + nbefore = int(params["ms_before"] * fs / 1000) + nafter = int(params["ms_after"] * fs / 1000) + nsamples = nbefore + nafter + + import scipy + + x = np.random.randn(100, nsamples, num_chans).astype(np.float32) + x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) + + ptps = np.ptp(x, axis=1) + a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) + ydata = np.cumsum(a) / a.sum() + xdata = b[1:] + + from scipy.optimize import curve_fit + + def sigmoid(x, L, x0, k, b): + y = L / (1 + np.exp(-k * (x - x0))) + b + return y + + p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess + popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) + + node3 = RandomProjectionsFeature( + recording, + parents=[node0, node2], + return_output=True, + projections=projections, + radius_um=params["radius_um"], + ) + + pipeline_nodes = [node0, node1, node2, node3] + + hdbscan_data = run_node_pipeline( + recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features" + ) import sklearn @@ -132,7 +144,7 @@ def main_function(cls, recording, peaks, params): all_indices = np.arange(0, peak_labels.size) - max_spikes = params["max_spikes_per_unit"] + max_spikes = params["waveforms"]["max_spikes_per_unit"] selection_method = params["selection_method"] for unit_ind in labels: diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index bd82ffa0a6..b534c2356d 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -184,41 +184,44 @@ def __init__( return_output=True, parents=None, projections=None, - radius_um=150.0, - min_values=None, + sigmoid=None, + radius_um=None, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.projections = projections - self.radius_um = radius_um - self.min_values = min_values - + self.sigmoid = sigmoid self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance < radius_um - - self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values)) - + self.radius_um = radius_um + self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um)) self._dtype = recording.get_dtype() def get_dtype(self): return self._dtype + def _sigmoid(self, x): + L, x0, k, b = self.sigmoid + y = L / (1 + np.exp(-k * (x - x0))) + b + return y + def compute(self, traces, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) + for main_chan in np.unique(peaks["channel_index"]): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] - wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1) + wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) - if self.min_values is not None: - wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4 + if self.sigmoid is not None: + wf_ptp *= self._sigmoid(wf_ptp) denom = np.sum(wf_ptp, axis=1) mask = denom != 0 - all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis]) + return all_projections diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index a19e7b71b5..358691cd25 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -33,9 +33,6 @@ from .main import BaseTemplateMatchingEngine -################# -# Circus peeler # -################# from scipy.fft._helper import _init_nd_shape_and_axes @@ -478,6 +475,366 @@ def main_function(cls, traces, d): return spikes +class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): + """ + Orthogonal Matching Pursuit inspired from Spyking Circus sorter + + https://elifesciences.org/articles/34518 + + This is an Orthogonal Template Matching algorithm. For speed and + memory optimization, templates are automatically sparsified. Signal + is convolved with the templates, and as long as some scalar products + are higher than a given threshold, we use a Cholesky decomposition + to compute the optimal amplitudes needed to reconstruct the signal. + + IMPORTANT NOTE: small chunks are more efficient for such Peeler, + consider using 100ms chunk + + Parameters + ---------- + amplitude: tuple + (Minimal, Maximal) amplitudes allowed for every template + omp_min_sps: float + Stopping criteria of the OMP algorithm, in percentage of the norm + noise_levels: array + The noise levels, for every channels. If None, they will be automatically + computed + random_chunk_kwargs: dict + Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. + ----- + """ + + _default_params = { + "amplitudes": [0.6, 2], + "omp_min_sps": 0.1, + "waveform_extractor": None, + "random_chunk_kwargs": {}, + "noise_levels": None, + "rank": 5, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, + "ignored_ids": [], + "vicinity": 0, + } + + @classmethod + def _prepare_templates(cls, d): + waveform_extractor = d["waveform_extractor"] + num_templates = len(d["waveform_extractor"].sorting.unit_ids) + + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + + d["sparsity_mask"] = sparsity + units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) + d["units_overlaps"] = units_overlaps > 0 + d["unit_overlaps_indices"] = {} + for i in range(num_templates): + (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) + + templates = waveform_extractor.get_all_templates(mode="median").copy() + + # First, we set masked channels to 0 + for count in range(num_templates): + templates[count][:, ~d["sparsity_mask"][count]] = 0 + + # Then we keep only the strongest components + rank = d["rank"] + temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) + d["temporal"] = temporal[:, :, :rank] + d["singular"] = singular[:, :rank] + d["spatial"] = spatial[:, :rank, :] + + # We reconstruct the approximated templates + templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) + + d["templates"] = {} + d["norms"] = np.zeros(num_templates, dtype=np.float32) + + # And get the norms, saving compressed templates for CC matrix + for count in range(num_templates): + template = templates[count][:, d["sparsity_mask"][count]] + d["norms"][count] = np.linalg.norm(template) + d["templates"][count] = template / d["norms"][count] + + d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] + d["temporal"] = np.flip(d["temporal"], axis=1) + + d["overlaps"] = [] + for i in range(num_templates): + num_overlaps = np.sum(d["units_overlaps"][i]) + overlapping_units = np.where(d["units_overlaps"][i])[0] + + # Reconstruct unit template from SVD Matrices + data = d["temporal"][i] * d["singular"][i][np.newaxis, :] + template_i = np.matmul(data, d["spatial"][i, :, :]) + template_i = np.flipud(template_i) + + unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) + + for count, j in enumerate(overlapping_units): + overlapped_channels = d["sparsity_mask"][j] + visible_i = template_i[:, overlapped_channels] + + spatial_filters = d["spatial"][j, :, overlapped_channels] + spatially_filtered_template = np.matmul(visible_i, spatial_filters) + visible_i = spatially_filtered_template * d["singular"][j] + + for rank in range(visible_i.shape[1]): + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") + + d["overlaps"].append(unit_overlaps) + + d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) + d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) + d["singular"] = d["singular"].T[:, :, np.newaxis] + return d + + @classmethod + def initialize_and_check_kwargs(cls, recording, kwargs): + d = cls._default_params.copy() + d.update(kwargs) + + # assert isinstance(d['waveform_extractor'], WaveformExtractor) + + for v in ["omp_min_sps"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = d["waveform_extractor"].nsamples + d["nbefore"] = d["waveform_extractor"].nbefore + d["nafter"] = d["waveform_extractor"].nafter + d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] + + if d["noise_levels"] is None: + print("CircusOMPPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + + if "templates" not in d: + d = cls._prepare_templates(d) + else: + for key in [ + "norms", + "temporal", + "spatial", + "singular", + "units_overlaps", + "sparsity_mask", + "unit_overlaps_indices", + ]: + assert d[key] is not None, "If templates are provided, %d should also be there" % key + + d["num_templates"] = len(d["templates"]) + d["ignored_ids"] = np.array(d["ignored_ids"]) + + d["unit_overlaps_tables"] = {} + for i in range(d["num_templates"]): + d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) + d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) + + omp_min_sps = d["omp_min_sps"] + # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) + + return d + + @classmethod + def serialize_method_kwargs(cls, kwargs): + kwargs = dict(kwargs) + # remove waveform_extractor + kwargs.pop("waveform_extractor") + return kwargs + + @classmethod + def unserialize_in_worker(cls, kwargs): + return kwargs + + @classmethod + def get_margin(cls, recording, kwargs): + margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + return margin + + @classmethod + def main_function(cls, traces, d): + templates = d["templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + num_samples = d["num_samples"] + overlaps = d["overlaps"] + norms = d["norms"] + nbefore = d["nbefore"] + nafter = d["nafter"] + omp_tol = np.finfo(np.float32).eps + num_samples = d["nafter"] + d["nbefore"] + neighbor_window = num_samples - 1 + min_amplitude, max_amplitude = d["amplitudes"] + ignored_ids = d["ignored_ids"] + stop_criteria = d["stop_criteria"][:, np.newaxis] + vicinity = d["vicinity"] + rank = d["rank"] + + num_timesteps = len(traces) + + num_peaks = num_timesteps - num_samples + 1 + conv_shape = (num_templates, num_peaks) + scalar_products = np.zeros(conv_shape, dtype=np.float32) + + # Filter using overlap-and-add convolution + if len(ignored_ids) > 0: + mask = ~np.isin(np.arange(num_templates), ignored_ids) + spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] + objective_by_rank = scipy.signal.oaconvolve( + scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid" + ) + scalar_products[mask] += np.sum(objective_by_rank, axis=0) + scalar_products[ignored_ids] = -np.inf + else: + spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"] + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") + scalar_products += np.sum(objective_by_rank, axis=0) + + num_spikes = 0 + + spikes = np.empty(scalar_products.size, dtype=spike_dtype) + idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + + M = np.zeros((num_templates, num_templates), dtype=np.float32) + + all_selections = np.empty((2, scalar_products.size), dtype=np.int32) + final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) + num_selection = 0 + + full_sps = scalar_products.copy() + + neighbors = {} + cached_overlaps = {} + + is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) + + while np.any(is_valid): + best_amplitude_ind = scalar_products[is_valid].argmax() + best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + + if num_selection > 0: + delta_t = selection[1] - peak_index + idx = np.where((delta_t < num_samples) & (delta_t > -num_samples))[0] + myline = neighbor_window + delta_t[idx] + myindices = selection[0, idx] + + local_overlaps = overlaps[best_cluster_ind] + overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] + table = d["unit_overlaps_tables"][best_cluster_ind] + + if num_selection == M.shape[0]: + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z[:num_selection, :num_selection] = M + M = Z + + mask = np.isin(myindices, overlapping_templates) + a, b = myindices[mask], myline[mask] + M[num_selection, idx[mask]] = local_overlaps[table[a], b] + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 + else: + M[0, 0] = 1 + + all_selections[:, num_selection] = [best_cluster_ind, peak_index] + num_selection += 1 + + selection = all_selections[:, :num_selection] + res_sps = full_sps[selection[0], selection[1]] + + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + + diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] + modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] + final_amplitudes[selection[0], selection[1]] = all_amplitudes + + for i in modified: + tmp_best, tmp_peak = selection[:, i] + diff_amp = diff_amplitudes[i] * norms[tmp_best] + + local_overlaps = overlaps[tmp_best] + overlapping_templates = d["units_overlaps"][tmp_best] + + if not tmp_peak in neighbors.keys(): + idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] + tdx = [neighbor_window + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak - 1] + neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + + idx = neighbors[tmp_peak]["idx"] + tdx = neighbors[tmp_peak]["tdx"] + + to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] + scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add + + is_valid = scalar_products > stop_criteria + + is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) + valid_indices = np.where(is_valid) + + num_spikes = len(valid_indices[0]) + spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["channel_index"][:num_spikes] = 0 + spikes["cluster_index"][:num_spikes] = valid_indices[0] + spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + + spikes = spikes[:num_spikes] + order = np.argsort(spikes["sample_index"]) + spikes = spikes[order] + + return spikes + + class CircusPeeler(BaseTemplateMatchingEngine): """ diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index bedc04a9d5..d982943126 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,6 +1,6 @@ from .naive import NaiveMatching from .tdc import TridesclousPeeler -from .circus import CircusPeeler, CircusOMPPeeler +from .circus import CircusPeeler, CircusOMPPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { @@ -8,5 +8,6 @@ "tridesclous": TridesclousPeeler, "circus": CircusPeeler, "circus-omp": CircusOMPPeeler, + "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, } diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c0dcd7ea6e..c10c78cbfc 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,8 +1,3 @@ -# basics -# from .timeseries import plot_timeseries, TracesWidget -from .rasters import plot_rasters, RasterWidget -from .probemap import plot_probe_map, ProbeMapWidget - # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget @@ -15,9 +10,6 @@ # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# comparison related -from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget -from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, @@ -41,22 +33,6 @@ from .sortingperformance import plot_sorting_performance -# ground truth study (=comparison over sorter) -from .gtstudy import ( - StudyComparisonRunTimesWidget, - plot_gt_study_run_times, - StudyComparisonUnitCountsWidget, - StudyComparisonUnitCountsAveragesWidget, - plot_gt_study_unit_counts, - plot_gt_study_unit_counts_averages, - plot_gt_study_performances, - plot_gt_study_performances_averages, - StudyComparisonPerformancesWidget, - StudyComparisonPerformancesAveragesWidget, - plot_gt_study_performances_by_template_similarity, - StudyComparisonPerformancesByTemplateSimilarity, -) - # ground truth comparions (=comparison over sorter) from .gtcomparison import ( plot_gt_performances, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index 6d981e1fd4..d25f1ea97b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -1,7 +1,6 @@ import numpy as np from .basewidget import BaseWidget -from spikeinterface.comparison.collisioncomparison import CollisionGTComparison class ComparisonCollisionPairByPairWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py deleted file mode 100644 index 573221f528..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py +++ /dev/null @@ -1,574 +0,0 @@ -""" -Various widgets on top of GroundTruthStudy to summary results: - * run times - * performances - * count units -""" -import numpy as np - - -from .basewidget import BaseWidget - - -class StudyComparisonRunTimesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - color: - - - """ - - def __init__(self, study, color="#F7DC6F", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.color = color - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - all_run_times = study.aggregate_run_times() - av_run_times = all_run_times.reset_index().groupby("sorter_name")["run_time"].mean() - - if len(study.rec_names) == 1: - # no errors bars - yerr = None - else: - # errors bars across recording - yerr = all_run_times.reset_index().groupby("sorter_name")["run_time"].std() - - sorter_names = av_run_times.index - - x = np.arange(sorter_names.size) + 1 - ax.bar(x, av_run_times.values, width=0.8, color=self.color, yerr=yerr) - ax.set_ylabel("run times (s)") - ax.set_xticks(x) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_xlim(0, sorter_names.size + 1) - - -def plot_gt_study_run_times(*args, **kwargs): - W = StudyComparisonRunTimesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_run_times.__doc__ = StudyComparisonRunTimesWidget.__doc__ - - -class StudyComparisonUnitCountsAveragesWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - count_units = study.aggregate_count_units() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - ncol = len(columns) - - df = count_units.reset_index() - - m = df.groupby("sorter_name")[columns].mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = df.groupby("sorter_name")[columns].std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values, yerr=yerr, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - ax.legend() - if self.log_scale: - ax.set_yscale("log") - - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -class StudyComparisonUnitCountsWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self): - study = self.study - ax = self.ax - - import seaborn as sns - - study = self.study - - count_units = study.aggregate_count_units() - count_units = count_units.reset_index() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - - ncol = len(columns) - cmap = plt.get_cmap(self.cmap_name, 4) - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = count_units.loc[count_units["rec_name"] == rec_name, :] - m = df.groupby("sorter_name")[columns].mean() - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - ax.bar(x, m[col].values, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - if r == 0: - ax.legend() - - if self.log_scale: - ax.set_yscale("log") - - if r == len(study.rec_names) - 1: - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -def plot_gt_study_unit_counts_averages(*args, **kwargs): - W = StudyComparisonUnitCountsAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts_averages.__doc__ = StudyComparisonUnitCountsAveragesWidget.__doc__ - - -def plot_gt_study_unit_counts(*args, **kwargs): - W = StudyComparisonUnitCountsWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts.__doc__ = StudyComparisonUnitCountsWidget.__doc__ - - -class StudyComparisonPerformancesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, palette="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.palette = palette - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self, average=False): - import seaborn as sns - - study = self.study - - sns.set_palette(sns.color_palette(self.palette)) - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = perf_by_units.loc[perf_by_units["rec_name"] == rec_name, :] - df = pd.melt( - df, - id_vars="sorter_name", - var_name="Metric", - value_name="Score", - value_vars=("accuracy", "precision", "recall"), - ).sort_values("sorter_name") - sns.swarmplot( - data=df, x="sorter_name", y="Score", hue="Metric", dodge=True, s=3, ax=ax - ) # order=sorter_list, - # ~ ax.set_xticklabels(sorter_names_short, rotation=30, ha='center') - # ~ ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5) - - ax.set_ylim(0, 1.05) - ax.set_ylabel(f"Perfs for {rec_name}") - if r < len(study.rec_names) - 1: - ax.set_xlabel("") - ax.set(xticklabels=[]) - - -class StudyComparisonTemplateSimilarityWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesAveragesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesByTemplateSimilarity(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None, ylim=(0.6, 1), show_legend=True): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.show_legend = show_legend - self.ylim = ylim - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import sklearn - - cmap = plt.get_cmap(self.cmap_name, len(self.study.sorter_names)) - colors = [cmap(i) for i in range(len(self.study.sorter_names))] - - flat_templates_gt = {} - for rec_name in self.study.rec_names: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_GroundTruth_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name) - - templates = self.study.get_templates(rec_name) - flat_templates_gt[rec_name] = templates.reshape(templates.shape[0], -1) - - all_results = {} - - for sorter_name in self.study.sorter_names: - all_results[sorter_name] = {"similarity": [], "accuracy": []} - - for rec_name in self.study.rec_names: - try: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_{sorter_name}_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name, sorter_name) - templates = self.study.get_templates(rec_name, sorter_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix = sklearn.metrics.pairwise.cosine_similarity( - flat_templates_gt[rec_name], flat_templates - ) - - comp = self.study.comparisons[(rec_name, sorter_name)] - - for i, u1 in enumerate(comp.sorting1.unit_ids): - u2 = comp.best_match_12[u1] - if u2 != -1: - all_results[sorter_name]["similarity"] += [ - similarity_matrix[comp.sorting1.id_to_index(u1), comp.sorting2.id_to_index(u2)] - ] - all_results[sorter_name]["accuracy"] += [comp.agreement_scores.at[u1, u2]] - except Exception: - pass - - all_results[sorter_name]["similarity"] = np.array(all_results[sorter_name]["similarity"]) - all_results[sorter_name]["accuracy"] = np.array(all_results[sorter_name]["accuracy"]) - - from matplotlib.patches import Ellipse - - similarity_means = [all_results[sorter_name]["similarity"].mean() for sorter_name in self.study.sorter_names] - similarity_stds = [all_results[sorter_name]["similarity"].std() for sorter_name in self.study.sorter_names] - - accuracy_means = [all_results[sorter_name]["accuracy"].mean() for sorter_name in self.study.sorter_names] - accuracy_stds = [all_results[sorter_name]["accuracy"].std() for sorter_name in self.study.sorter_names] - - scount = 0 - for x, y, i, j in zip(similarity_means, accuracy_means, similarity_stds, accuracy_stds): - e = Ellipse((x, y), i, j) - e.set_alpha(0.2) - e.set_facecolor(colors[scount]) - self.ax.add_artist(e) - self.ax.scatter([x], [y], c=colors[scount], label=self.study.sorter_names[scount]) - scount += 1 - - self.ax.set_ylabel("accuracy") - self.ax.set_xlabel("cosine similarity") - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - if self.show_legend: - self.ax.legend() - - -def plot_gt_study_performances(*args, **kwargs): - W = StudyComparisonPerformancesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances.__doc__ = StudyComparisonPerformancesWidget.__doc__ - - -def plot_gt_study_performances_averages(*args, **kwargs): - W = StudyComparisonPerformancesAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances_averages.__doc__ = StudyComparisonPerformancesAveragesWidget.__doc__ - - -def plot_gt_study_performances_by_template_similarity(*args, **kwargs): - W = StudyComparisonPerformancesByTemplateSimilarity(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances_by_template_similarity.__doc__ = StudyComparisonPerformancesByTemplateSimilarity.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py deleted file mode 100644 index 6e6578a4c4..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ProbeMapWidget(BaseWidget): - """ - Plot the probe of a recording. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - channel_ids: list - The channel ids to display - with_channel_ids: bool False default - Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__(self, recording, channel_ids=None, with_channel_ids=False, figure=None, ax=None, **plot_probe_kwargs): - import matplotlib.pylab as plt - from probeinterface.plotting import plot_probe, get_auto_lims - - BaseWidget.__init__(self, figure, ax) - - if channel_ids is not None: - recording = recording.channel_slice(channel_ids) - self._recording = recording - self._probegroup = recording.get_probegroup() - self.with_channel_ids = with_channel_ids - self._plot_probe_kwargs = plot_probe_kwargs - - def plot(self): - self._do_plot() - - def _do_plot(self): - from probeinterface.plotting import get_auto_lims - - xlims, ylims, zlims = get_auto_lims(self._probegroup.probes[0]) - for i, probe in enumerate(self._probegroup.probes): - xlims2, ylims2, _ = get_auto_lims(probe) - xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) - ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) - - self._plot_probe_kwargs["title"] = False - pos = 0 - text_on_contact = None - for i, probe in enumerate(self._probegroup.probes): - n = probe.get_contact_count() - if self.with_channel_ids: - text_on_contact = self._recording.channel_ids[pos : pos + n] - pos += n - from probeinterface.plotting import plot_probe - - plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **self._plot_probe_kwargs) - - self.ax.set_xlim(*xlims) - self.ax.set_ylim(*ylims) - - -def plot_probe_map(*args, **kwargs): - W = ProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_probe_map.__doc__ = ProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py deleted file mode 100644 index d05373103e..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class RasterWidget(BaseWidget): - """ - Plots spike train rasters. - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - segment_index: None or int - The segment index. - unit_ids: list - List of unit ids - time_range: list - List with start time and end time - color: matplotlib color - The color to be used - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: RasterWidget - The output widget - """ - - def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._sorting = sorting - - if segment_index is None: - nseg = sorting.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - else: - segment_index = 0 - self.segment_index = segment_index - - self._unit_ids = unit_ids - self._figure = None - self._sampling_frequency = sorting.get_sampling_frequency() - self._color = color - self._max_frame = 0 - for unit_id in self._sorting.get_unit_ids(): - spike_train = self._sorting.get_unit_spike_train(unit_id, segment_index=self.segment_index) - if len(spike_train) > 0: - curr_max_frame = np.max(spike_train) - if curr_max_frame > self._max_frame: - self._max_frame = curr_max_frame - self._visible_trange = time_range - if self._visible_trange is None: - self._visible_trange = [0, self._max_frame] - else: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - self._visible_trange = [int(t * self._sampling_frequency) for t in time_range] - - self._visible_trange = self._fix_trange(self._visible_trange) - self.name = "Raster" - - def plot(self): - self._do_plot() - - def _do_plot(self): - units_ids = self._unit_ids - if units_ids is None: - units_ids = self._sorting.get_unit_ids() - import matplotlib.pyplot as plt - - with plt.rc_context({"axes.edgecolor": "gray"}): - for u_i, unit_id in enumerate(units_ids): - spiketrain = self._sorting.get_unit_spike_train( - unit_id, - start_frame=self._visible_trange[0], - end_frame=self._visible_trange[1], - segment_index=self.segment_index, - ) - spiketimes = spiketrain / float(self._sampling_frequency) - self.ax.plot( - spiketimes, - u_i * np.ones_like(spiketimes), - marker="|", - mew=1, - markersize=3, - ls="", - color=self._color, - ) - visible_start_frame = self._visible_trange[0] / self._sampling_frequency - visible_end_frame = self._visible_trange[1] / self._sampling_frequency - self.ax.set_yticks(np.arange(len(units_ids))) - self.ax.set_yticklabels(units_ids) - self.ax.set_xlim(visible_start_frame, visible_end_frame) - self.ax.set_xlabel("time (s)") - - def _fix_trange(self, trange): - if trange[1] > self._max_frame: - # trange[0] += max_t - trange[1] - trange[1] = self._max_frame - if trange[0] < 0: - # trange[1] += -trange[0] - trange[0] = 0 - # trange[0] = np.maximum(0, trange[0]) - # trange[1] = np.minimum(max_t, trange[1]) - return trange - - -def plot_rasters(*args, **kwargs): - W = RasterWidget(*args, **kwargs) - W.plot() - return W - - -plot_rasters.__doc__ = RasterWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 5004765251..39eb80e2e5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -43,44 +43,6 @@ def setUp(self): def tearDown(self): pass - # def test_timeseries(self): - # sw.plot_timeseries(self._rec, mode='auto') - # sw.plot_timeseries(self._rec, mode='line', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True, order_channel_by_depth=True) - - def test_rasters(self): - sw.plot_rasters(self._sorting) - - def test_plot_probe_map(self): - sw.plot_probe_map(self._rec) - sw.plot_probe_map(self._rec, with_channel_ids=True) - - # TODO - # def test_spectrum(self): - # sw.plot_spectrum(self._rec) - - # TODO - # def test_spectrogram(self): - # sw.plot_spectrogram(self._rec, channel=0) - - # def test_unitwaveforms(self): - # w = sw.plot_unit_waveforms(self._we) - # unit_ids = self._sorting.unit_ids[:6] - # sw.plot_unit_waveforms(self._we, max_channels=5, unit_ids=unit_ids) - # sw.plot_unit_waveforms(self._we, radius_um=60, unit_ids=unit_ids) - - # def test_plot_unit_waveform_density_map(self): - # unit_ids = self._sorting.unit_ids[:3] - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=4) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=50) - # - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=25, same_axis=True) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=2, same_axis=True) - - # def test_unittemplates(self): - # sw.plot_unit_templates(self._we) - def test_plot_unit_probe_map(self): sw.plot_unit_probe_map(self._we, with_channel_ids=True) sw.plot_unit_probe_map(self._we, animated=True) @@ -120,12 +82,6 @@ def test_plot_peak_activity_map(self): sw.plot_peak_activity_map(self._rec, with_channel_ids=True) sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) - def test_confusion(self): - sw.plot_confusion_matrix(self._gt_comp, count_text=True) - - def test_agreement(self): - sw.plot_agreement_matrix(self._gt_comp, count_text=True) - def test_multicomp_graph(self): msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) sw.plot_multicomp_graph(msc, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False) @@ -150,8 +106,6 @@ def test_sorting_performance(self): mytest.setUp() # ~ mytest.test_timeseries() - # ~ mytest.test_rasters() - mytest.test_plot_probe_map() # ~ mytest.test_unitwaveforms() # ~ mytest.test_plot_unit_waveform_density_map() # mytest.test_unittemplates() @@ -169,8 +123,6 @@ def test_sorting_performance(self): # ~ mytest.test_plot_drift_over_time() # ~ mytest.test_plot_peak_activity_map() - # mytest.test_confusion() - # mytest.test_agreement() # ~ mytest.test_multicomp_graph() #  mytest.test_sorting_performance() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py deleted file mode 100644 index ab6fa2ace5..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ /dev/null @@ -1,233 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from matplotlib.ticker import MaxNLocator -from .basewidget import BaseWidget - -import scipy.spatial - - -class TracesWidget(BaseWidget): - """ - Plots recording timeseries. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - segment_index: None or int - The segment index. - channel_ids: list - The channel ids to display. - order_channel_by_depth: boolean - Reorder channel by depth. - time_range: list - List with start time and end time - mode: 'line' or 'map' or 'auto' - 2 possible mode: - * 'line' : classical for low channel count - * 'map' : for high channel count use color heat map - * 'auto' : auto switch depending the channel count <32ch - cmap: str default 'RdBu' - matplotlib colormap used in mode 'map' - show_channel_ids: bool - Set yticks with channel ids - color_groups: bool - If True groups are plotted with different colors - color: matplotlib color, default: None - The color used to draw the traces. - clim: None or tupple - When mode='map' this control color lims - with_colorbar: bool default True - When mode='map' add colorbar - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: TracesWidget - The output widget - """ - - def __init__( - self, - recording, - segment_index=None, - channel_ids=None, - order_channel_by_depth=False, - time_range=None, - mode="auto", - cmap="RdBu", - show_channel_ids=False, - color_groups=False, - color=None, - clim=None, - with_colorbar=True, - figure=None, - ax=None, - **plot_kwargs, - ): - BaseWidget.__init__(self, figure, ax) - self.recording = recording - self._sampling_frequency = recording.get_sampling_frequency() - self.visible_channel_ids = channel_ids - self._plot_kwargs = plot_kwargs - - if segment_index is None: - nseg = recording.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 - self.segment_index = segment_index - - if self.visible_channel_ids is None: - self.visible_channel_ids = recording.get_channel_ids() - - if order_channel_by_depth: - locations = self.recording.get_channel_locations() - channel_inds = self.recording.ids_to_indices(self.visible_channel_ids) - locations = locations[channel_inds, :] - origin = np.array([np.max(locations[:, 0]), np.min(locations[:, 1])])[None, :] - dist = scipy.spatial.distance.cdist(locations, origin, metric="euclidean") - dist = dist[:, 0] - self.order = np.argsort(dist) - else: - self.order = None - - if channel_ids is None: - channel_ids = recording.get_channel_ids() - - fs = recording.get_sampling_frequency() - if time_range is None: - time_range = (0, 1.0) - time_range = np.array(time_range) - - assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" - if mode == "auto": - if len(channel_ids) <= 64: - mode = "line" - else: - mode = "map" - self.mode = mode - self.cmap = cmap - - self.show_channel_ids = show_channel_ids - - self._frame_range = (time_range * fs).astype("int64") - a_max = self.recording.get_num_frames(segment_index=self.segment_index) - self._frame_range = np.clip(self._frame_range, 0, a_max) - self._time_range = [e / fs for e in self._frame_range] - - self.clim = clim - self.with_colorbar = with_colorbar - - self._initialize_stats() - - # self._vspacing = self._mean_channel_std * 20 - self._vspacing = self._max_channel_amp * 1.5 - - if recording.get_channel_groups() is None: - color_groups = False - - self._color_groups = color_groups - self._color = color - if color_groups: - self._colors = [] - self._group_color_map = {} - all_groups = recording.get_channel_groups() - groups = np.unique(all_groups) - N = len(groups) - import colorsys - - HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)] - self._colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)) - color_idx = 0 - for group in groups: - self._group_color_map[group] = color_idx - color_idx += 1 - self.name = "TimeSeries" - - def plot(self): - self._do_plot() - - def _do_plot(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - if self.order is not None: - chunk0 = chunk0[:, self.order] - self.visible_channel_ids = np.array(self.visible_channel_ids)[self.order] - - ax = self.ax - - n = len(self.visible_channel_ids) - - if self.mode == "line": - ax.set_xlim( - self._frame_range[0] / self._sampling_frequency, self._frame_range[1] / self._sampling_frequency - ) - ax.set_ylim(-self._vspacing, self._vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.get_yaxis().set_ticks([]) - ax.set_xlabel("time (s)") - - self._plots = {} - self._plot_offsets = {} - offset0 = self._vspacing * (n - 1) - times = np.arange(self._frame_range[0], self._frame_range[1]) / self._sampling_frequency - for im, m in enumerate(self.visible_channel_ids): - self._plot_offsets[m] = offset0 - if self._color_groups: - group = self.recording.get_channel_groups(channel_ids=[m])[0] - group_color_idx = self._group_color_map[group] - color = self._colors[group_color_idx] - else: - color = self._color - self._plots[m] = ax.plot(times, self._plot_offsets[m] + chunk0[:, im], color=color, **self._plot_kwargs) - offset0 = offset0 - self._vspacing - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) * self._vspacing) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - elif self.mode == "map": - extent = (self._time_range[0], self._time_range[1], 0, self.recording.get_num_channels()) - im = ax.imshow( - chunk0.T, interpolation="nearest", origin="upper", aspect="auto", extent=extent, cmap=self.cmap - ) - - if self.clim is None: - im.set_clim(-self._max_channel_amp, self._max_channel_amp) - else: - im.set_clim(*self.clim) - - if self.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) + 0.5) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - def _initialize_stats(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - - self._mean_channel_std = np.mean(np.std(chunk0, axis=0)) - self._max_channel_amp = np.max(np.max(np.abs(chunk0), axis=0)) - - -def plot_timeseries(*args, **kwargs): - W = TracesWidget(*args, **kwargs) - W.plot() - return W - - -plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py b/src/spikeinterface/widgets/agreement_matrix.py similarity index 53% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py rename to src/spikeinterface/widgets/agreement_matrix.py index 369746e99b..ec6ea1c87c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py +++ b/src/spikeinterface/widgets/agreement_matrix.py @@ -1,11 +1,13 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class AgreementMatrixWidget(BaseWidget): """ - Plots sorting comparison confusion matrix. + Plots sorting comparison agreement matrix. Parameters ---------- @@ -19,31 +21,34 @@ class AgreementMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created + """ - def __init__(self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__( + self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs + ): + plot_data = dict( + sorting_comparison=sorting_comparison, + ordered=ordered, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) - BaseWidget.__init__(self, figure, ax) - self._sc = sorting_comparison - self._ordered = ordered - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def plot(self): - self._do_plot() + comp = dp.sorting_comparison - def _do_plot(self): - # a dataframe - if self._ordered: - scores = self._sc.get_ordered_agreement_scores() + if dp.ordered: + scores = comp.get_ordered_agreement_scores() else: - scores = self._sc.agreement_scores + scores = comp.agreement_scores N1 = scores.shape[0] N2 = scores.shape[1] @@ -54,9 +59,9 @@ def _do_plot(self): # Using matshow here just because it sets the ticks up nicely. imshow is faster. self.ax.matshow(scores.values, cmap="Greens") - if self._count_text: + if dp.count_text: for i, u1 in enumerate(unit_ids1): - u2 = self._sc.best_match_12[u1] + u2 = comp.best_match_12[u1] if u2 != -1: j = np.where(unit_ids2 == u2)[0][0] @@ -68,24 +73,15 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(scores.index, fontsize=12) self.ax.set_xticklabels(scores.columns, fontsize=12) - self.ax.set_xlabel(self._sc.name_list[1], fontsize=20) - self.ax.set_ylabel(self._sc.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 - 0.5) self.ax.set_ylim( N1 - 0.5, -0.5, ) - - -def plot_agreement_matrix(*args, **kwargs): - W = AgreementMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_agreement_matrix.__doc__ = AgreementMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index dea46b8f51..4ed83fcca9 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -39,12 +39,14 @@ def set_default_plotter_backend(backend): "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", }, + "ephyviewer": {}, } default_backend_kwargs = { "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, + "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py b/src/spikeinterface/widgets/confusion_matrix.py similarity index 62% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py rename to src/spikeinterface/widgets/confusion_matrix.py index 942b613fbf..8eb58f30b2 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py +++ b/src/spikeinterface/widgets/confusion_matrix.py @@ -1,6 +1,8 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class ConfusionMatrixWidget(BaseWidget): @@ -15,40 +17,35 @@ class ConfusionMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ConfusionMatrixWidget - The output widget + """ - def __init__(self, gt_comparison, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs): + plot_data = dict( + gt_comparison=gt_comparison, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - BaseWidget.__init__(self, figure, ax) - self._gtcomp = gt_comparison - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + dp = to_attr(data_plot) - def plot(self): - self._do_plot() + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def _do_plot(self): - # a dataframe - confusion_matrix = self._gtcomp.get_confusion_matrix() + comp = dp.gt_comparison + confusion_matrix = comp.get_confusion_matrix() N1 = confusion_matrix.shape[0] - 1 N2 = confusion_matrix.shape[1] - 1 # Using matshow here just because it sets the ticks up nicely. imshow is faster. self.ax.matshow(confusion_matrix.values, cmap="Greens") - if self._count_text: + if dp.count_text: for (i, j), z in np.ndenumerate(confusion_matrix.values): if z != 0: if z > np.max(confusion_matrix.values) / 2.0: @@ -65,27 +62,18 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) else: self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) - self.ax.set_xlabel(self._gtcomp.name_list[1], fontsize=20) - self.ax.set_ylabel(self._gtcomp.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 + 0.5) self.ax.set_ylim( N1 + 0.5, -0.5, ) - - -def plot_confusion_matrix(*args, **kwargs): - W = ConfusionMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_confusion_matrix.__doc__ = ConfusionMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py new file mode 100644 index 0000000000..6a27b78dec --- /dev/null +++ b/src/spikeinterface/widgets/gtstudy.py @@ -0,0 +1,253 @@ +import numpy as np + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + +from ..core import ChannelSparsity +from ..core.waveform_extractor import WaveformExtractor +from ..core.basesorting import BaseSorting + + +class StudyRunTimesWidget(BaseWidget): + """ + Plot sorter run times for a GroundTruthStudy + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + run_times=study.get_run_times(case_keys), + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for i, key in enumerate(dp.case_keys): + label = dp.study.cases[key]["label"] + rt = dp.run_times.loc[key] + self.ax.bar(i, rt, width=0.8, label=label) + + self.ax.legend() + + +# TODO : plot optionally average on some levels using group by +class StudyUnitCountsWidget(BaseWidget): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + count_units=study.get_count_units(case_keys=case_keys), + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + columns = dp.count_units.columns.tolist() + columns.remove("num_gt") + columns.remove("num_sorter") + + ncol = len(columns) + + colors = get_some_colors(columns, color_engine="auto", map_name="hot") + colors["num_well_detected"] = "green" + + xticklabels = [] + for i, key in enumerate(dp.case_keys): + for c, col in enumerate(columns): + x = i + 1 + c / (ncol + 1) + y = dp.count_units.loc[key, col] + if not "well_detected" in col: + y = -y + + if i == 0: + label = col.replace("num_", "").replace("_", " ").title() + else: + label = None + + self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) + + xticklabels.append(dp.study.cases[key]["label"]) + + self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1) + self.ax.set_xticklabels(xticklabels) + self.ax.legend() + + +# TODO : plot optionally average on some levels using group by +class StudyPerformances(BaseWidget): + """ + Plot performances over case for a study. + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + mode: str + Which mode in "swarm" + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + mode="swarm", + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + perfs=study.get_performance_by_unit(case_keys=case_keys), + mode=mode, + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + import pandas as pd + import seaborn as sns + + dp = to_attr(data_plot) + perfs = dp.perfs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + if dp.mode == "swarm": + levels = perfs.index.names + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=("accuracy", "precision", "recall"), + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) + + +class StudyPerformancesVsMetrics(BaseWidget): + """ + Plot performances vs a metrics (snr for instance) over case for a study. + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + mode: str + Which mode in "swarm" + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + metric_name="snr", + performance_name="accuracy", + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + metric_name=metric_name, + performance_name=performance_name, + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + dp = to_attr(data_plot) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + study = dp.study + perfs = study.get_performance_by_unit(case_keys=dp.case_keys) + + max_metric = 0 + for key in dp.case_keys: + x = study.get_metrics(key)[dp.metric_name].values + y = perfs.xs(key)[dp.performance_name].values + label = dp.study.cases[key]["label"] + self.ax.scatter(x, y, label=label) + max_metric = max(max_metric, np.max(x)) + + self.ax.legend() + self.ax.set_xlim(0, max_metric * 1.05) + self.ax.set_ylim(0, 1.05) diff --git a/src/spikeinterface/widgets/probe_map.py b/src/spikeinterface/widgets/probe_map.py new file mode 100644 index 0000000000..7fb74abd7c --- /dev/null +++ b/src/spikeinterface/widgets/probe_map.py @@ -0,0 +1,75 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs +from .utils import get_unit_colors + + +class ProbeMapWidget(BaseWidget): + """ + Plot the probe of a recording. + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object + channel_ids: list + The channel ids to display + with_channel_ids: bool False default + Add channel ids text on the probe + **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function + + """ + + def __init__( + self, recording, channel_ids=None, with_channel_ids=False, backend=None, **backend_or_plot_probe_kwargs + ): + # split backend_or_plot_probe_kwargs + backend_kwargs = dict() + plot_probe_kwargs = dict() + backend = self.check_backend(backend) + for k, v in backend_or_plot_probe_kwargs.items(): + if k in default_backend_kwargs[backend]: + backend_kwargs[k] = v + else: + plot_probe_kwargs[k] = v + + plot_data = dict( + recording=recording, + channel_ids=channel_ids, + with_channel_ids=with_channel_ids, + plot_probe_kwargs=plot_probe_kwargs, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import get_auto_lims, plot_probe + + dp = to_attr(data_plot) + + plot_probe_kwargs = dp.plot_probe_kwargs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + probegroup = dp.recording.get_probegroup() + + xlims, ylims, zlims = get_auto_lims(probegroup.probes[0]) + for i, probe in enumerate(probegroup.probes): + xlims2, ylims2, _ = get_auto_lims(probe) + xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) + ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) + + plot_probe_kwargs["title"] = False + pos = 0 + text_on_contact = None + for i, probe in enumerate(probegroup.probes): + n = probe.get_contact_count() + if dp.with_channel_ids: + text_on_contact = dp.recording.channel_ids[pos : pos + n] + pos += n + plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **plot_probe_kwargs) + + self.ax.set_xlim(*xlims) + self.ax.set_ylim(*ylims) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py new file mode 100644 index 0000000000..4a1d76279f --- /dev/null +++ b/src/spikeinterface/widgets/rasters.py @@ -0,0 +1,84 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs + + +class RasterWidget(BaseWidget): + """ + Plots spike train rasters. + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + segment_index: None or int + The segment index. + unit_ids: list + List of unit ids + time_range: list + List with start time and end time + color: matplotlib color + The color to be used + """ + + def __init__( + self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs + ): + if segment_index is None: + if sorting.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 + + if time_range is None: + frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]] + time_range = [f / sorting.sampling_frequency for f in frame_range] + else: + assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + frame_range = [int(t * sorting.sampling_frequency) for t in time_range] + + plot_data = dict( + sorting=sorting, + segment_index=segment_index, + unit_ids=unit_ids, + color=color, + frame_range=frame_range, + time_range=time_range, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + sorting = dp.sorting + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + units_ids = dp.unit_ids + if units_ids is None: + units_ids = sorting.unit_ids + + with plt.rc_context({"axes.edgecolor": "gray"}): + for unit_index, unit_id in enumerate(units_ids): + spiketrain = sorting.get_unit_spike_train( + unit_id, + start_frame=dp.frame_range[0], + end_frame=dp.frame_range[1], + segment_index=dp.segment_index, + ) + spiketimes = spiketrain / float(sorting.sampling_frequency) + self.ax.plot( + spiketimes, + unit_index * np.ones_like(spiketimes), + marker="|", + mew=1, + markersize=3, + ls="", + color=dp.color, + ) + self.ax.set_yticks(np.arange(len(units_ids))) + self.ax.set_yticklabels(units_ids) + self.ax.set_xlim(*dp.time_range) + self.ax.set_xlabel("time (s)") diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index a5f75ebf50..f44878927d 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -72,7 +72,7 @@ def setUpClass(cls): else: cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) - cls.skip_backends = ["ipywidgets"] + cls.skip_backends = ["ipywidgets", "ephyviewer"] if ON_GITHUB and not KACHERY_CLOUD_SET: cls.skip_backends.append("sortingview") @@ -324,6 +324,30 @@ def test_sorting_summary(self): sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + def test_plot_agreement_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_agreement_matrix(self.gt_comp) + + def test_plot_confusion_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_confusion_matrix(self.gt_comp) + + def test_plot_probe_map(self): + possible_backends = list(sw.ProbeMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_probe_map(self.recording, with_channel_ids=True, with_contact_id=True) + + def test_plot_rasters(self): + possible_backends = list(sw.RasterWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_rasters(self.sorting) + if __name__ == "__main__": # unittest.main() @@ -344,7 +368,11 @@ def test_sorting_summary(self): # mytest.test_unit_locations() # mytest.test_quality_metrics() # mytest.test_template_metrics() - mytest.test_amplitudes() + # mytest.test_amplitudes() + # mytest.test_plot_agreement_matrix() + # mytest.test_plot_confusion_matrix() + # mytest.test_plot_probe_map() + mytest.test_plot_rasters() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index e025f779c1..7bb2126744 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -524,6 +524,30 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **backend_kwargs) + def plot_ephyviewer(self, data_plot, **backend_kwargs): + import ephyviewer + from ..preprocessing import depth_order + + dp = to_attr(data_plot) + + app = ephyviewer.mkQApp() + win = ephyviewer.MainViewer(debug=False, show_auto_scale=True) + + for k, rec in dp.recordings.items(): + if dp.order_channel_by_depth: + rec = depth_order(rec, flip=True) + + sig_source = ephyviewer.SpikeInterfaceRecordingSource(recording=rec) + view = ephyviewer.TraceViewer(source=sig_source, name=k) + view.params["scale_mode"] = "by_channel" + if dp.show_channel_ids: + view.params["display_labels"] = True + view.auto_scale() + win.add_view(view) + + win.show() + app.exec() + def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): # function also used in ipywidgets plotter diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 9c89b3981e..ed77de6128 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -2,12 +2,16 @@ from .base import backend_kwargs_desc +from .agreement_matrix import AgreementMatrixWidget from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget +from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .motion import MotionWidget +from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget +from .rasters import RasterWidget from .sorting_summary import SortingSummaryWidget from .spike_locations import SpikeLocationsWidget from .spikes_on_traces import SpikesOnTracesWidget @@ -20,15 +24,20 @@ from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget +from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics widget_list = [ + AgreementMatrixWidget, AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, + ConfusionMatrixWidget, CrossCorrelogramsWidget, MotionWidget, + ProbeMapWidget, QualityMetricsWidget, + RasterWidget, SortingSummaryWidget, SpikeLocationsWidget, SpikesOnTracesWidget, @@ -41,6 +50,10 @@ UnitTemplatesWidget, UnitWaveformDensityMapWidget, UnitWaveformsWidget, + StudyRunTimesWidget, + StudyUnitCountsWidget, + StudyPerformances, + StudyPerformancesVsMetrics, ] @@ -76,12 +89,16 @@ # make function for all widgets +plot_agreement_matrix = AgreementMatrixWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget +plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_motion = MotionWidget +plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget +plot_rasters = RasterWidget plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget plot_spikes_on_traces = SpikesOnTracesWidget @@ -94,6 +111,10 @@ plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_waveforms = UnitWaveformsWidget +plot_study_run_times = StudyRunTimesWidget +plot_study_unit_counts = StudyUnitCountsWidget +plot_study_performances = StudyPerformances +plot_stufy_performances_vs_metrics = StudyPerformancesVsMetrics def plot_timeseries(*args, **kwargs):