diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 004fe31203..7241f60a8b 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -37,6 +37,11 @@ runs: - name: git-annex install run: | wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz + mkdir /home/runner/work/installation + mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/ + workdir=$(pwd) + cd /home/runner/work/installation tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH + cd $workdir shell: bash diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index 0e522e6baa..b3bf08954d 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -30,4 +30,4 @@ jobs: - name: Test Conda Environment Creation uses: conda-incubator/setup-miniconda@v2.2.0 with: - environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml + environment-file: ./installation_tips/full_spikeinterface_environment_${{ matrix.label }}.yml diff --git a/.github/workflows/test_containers_singularity_gpu.yml b/.github/workflows/test_containers_singularity_gpu.yml index e74fbeb4a5..d075f5a6ef 100644 --- a/.github/workflows/test_containers_singularity_gpu.yml +++ b/.github/workflows/test_containers_singularity_gpu.yml @@ -46,5 +46,6 @@ jobs: - name: Run test singularity containers with GPU env: REPO_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} + SPIKEINTERFACE_DEV_PATH: ${{ github.workspace }} run: | pytest -vv --capture=tee-sys -rA src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py diff --git a/.gitignore b/.gitignore index 3ee3cb8867..7838213bed 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,4 @@ test_folder/ # Mac OS .DS_Store +test_data.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07601cd208..7153a7dfc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer diff --git a/README.md b/README.md index 55f33d04b1..883dcdb944 100644 --- a/README.md +++ b/README.md @@ -59,15 +59,17 @@ With SpikeInterface, users can: - post-process sorted datasets. - compare and benchmark spike sorting outputs. - compute quality metrics to validate and curate spike sorting outputs. -- visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, in jupyter) -- export report and export to phy -- offer a powerful Qt-based viewer in separate package [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui) -- have some powerful sorting components to build your own sorter. +- visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, jupyter, ephyviewer) +- export a report and/or export to phy +- offer a powerful Qt-based viewer in a separate package [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui) +- have powerful sorting components to build your own sorter. ## Documentation -Detailed documentation for spikeinterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). +Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.98.2). + +Detailed documentation of the development version of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). Several tutorials to get started can be found in [spiketutorials](https://github.com/SpikeInterface/spiketutorials). @@ -77,9 +79,9 @@ and sorting components. You can also have a look at the [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui). -## How to install spikeinteface +## How to install spikeinterface -You can install the new `spikeinterface` version with pip: +You can install the latest version of `spikeinterface` version with pip: ```bash pip install spikeinterface[full] @@ -94,7 +96,7 @@ To install all interactive widget backends, you can use: ``` -To get the latest updates, you can install `spikeinterface` from sources: +To get the latest updates, you can install `spikeinterface` from source: ```bash git clone https://github.com/SpikeInterface/spikeinterface.git diff --git a/doc/api.rst b/doc/api.rst index 43f79386e6..97c956c2f6 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -19,6 +19,8 @@ spikeinterface.core .. autofunction:: extract_waveforms .. autofunction:: load_waveforms .. autofunction:: compute_sparsity + .. autoclass:: ChannelSparsity + :members: .. autoclass:: BinaryRecordingExtractor .. autoclass:: ZarrRecordingExtractor .. autoclass:: BinaryFolderRecording @@ -48,10 +50,6 @@ spikeinterface.core .. autofunction:: get_template_extremum_channel .. autofunction:: get_template_extremum_channel_peak_shift .. autofunction:: get_template_extremum_amplitude - -.. - .. autofunction:: read_binary - .. autofunction:: read_zarr .. autofunction:: append_recordings .. autofunction:: concatenate_recordings .. autofunction:: split_recording @@ -59,6 +57,8 @@ spikeinterface.core .. autofunction:: append_sortings .. autofunction:: split_sorting .. autofunction:: select_segment_sorting + .. autofunction:: read_binary + .. autofunction:: read_zarr Low-level ~~~~~~~~~ @@ -67,7 +67,6 @@ Low-level :noindex: .. autoclass:: BaseWaveformExtractorExtension - .. autoclass:: ChannelSparsity .. autoclass:: ChunkRecordingExecutor spikeinterface.extractors @@ -83,6 +82,7 @@ NEO-based .. autofunction:: read_alphaomega_event .. autofunction:: read_axona .. autofunction:: read_biocam + .. autofunction:: read_binary .. autofunction:: read_blackrock .. autofunction:: read_ced .. autofunction:: read_intan @@ -104,6 +104,7 @@ NEO-based .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx .. autofunction:: read_tdt + .. autofunction:: read_zarr Non-NEO-based @@ -216,8 +217,10 @@ spikeinterface.sorters .. autofunction:: print_sorter_versions .. autofunction:: get_sorter_description .. autofunction:: run_sorter + .. autofunction:: run_sorter_jobs .. autofunction:: run_sorters .. autofunction:: run_sorter_by_property + .. autofunction:: read_sorter_folder Low level ~~~~~~~~~ diff --git a/doc/conf.py b/doc/conf.py index 15cb65d46a..13d1ef4e65 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -118,14 +118,15 @@ 'examples_dirs': ['../examples/modules_gallery'], 'gallery_dirs': ['modules_gallery', ], # path where to save gallery generated examples 'subsection_order': ExplicitOrder([ - '../examples/modules_gallery/core/', - '../examples/modules_gallery/extractors/', + '../examples/modules_gallery/core', + '../examples/modules_gallery/extractors', '../examples/modules_gallery/qualitymetrics', '../examples/modules_gallery/comparison', '../examples/modules_gallery/widgets', ]), 'within_subsection_order': FileNameSortKey, 'ignore_pattern': '/generate_', + 'nested_sections': False, } intersphinx_mapping = { diff --git a/doc/development/development.rst b/doc/development/development.rst index f1371639c3..7656da11ab 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -14,7 +14,7 @@ There are various ways to contribute to SpikeInterface as a user or developer. S * Writing unit tests to expand code coverage and use case scenarios. * Reporting bugs and issues. -We use a forking workflow _ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: +We use a forking workflow ``_ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: * Fork the SpikeInterface repository. * Create a new branch (e.g., :code:`git switch -c my-contribution`). @@ -22,7 +22,7 @@ We use a forking workflow _ . +While we appreciate all the contributions please be mindful of the cost of reviewing pull requests ``_ . How to run tests locally @@ -201,7 +201,7 @@ Implement a new extractor SpikeInterface already supports over 30 file formats, but the acquisition system you use might not be among the supported formats list (***ref***). Most of the extractord rely on the `NEO `_ package to read information from files. -Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new `neo.rawio `_ class. +Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new :code:`neo.rawio.BaseRawIO` class (see `example `_). Once that is done, the new class can be easily wrapped into SpikeInterface as an extension of the :py:class:`~spikeinterface.extractors.neoextractors.neobaseextractors.NeoBaseRecordingExtractor` (for :py:class:`~spikeinterface.core.BaseRecording` objects) or diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index dabad818f9..da94cf549c 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -7,3 +7,4 @@ How to guides get_started analyse_neuropixels handle_drift + load_matlab_data diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst new file mode 100644 index 0000000000..54a66c0890 --- /dev/null +++ b/doc/how_to/load_matlab_data.rst @@ -0,0 +1,101 @@ +Exporting MATLAB Data to Binary & Loading in SpikeInterface +=========================================================== + +In this tutorial, we will walk through the process of exporting data from MATLAB in a binary format and subsequently loading it using SpikeInterface in Python. + +Exporting Data from MATLAB +-------------------------- + +Begin by ensuring your data structure is correct. Organize your data matrix so that the first dimension corresponds to samples/time and the second to channels. +Here, we present a MATLAB code that creates a random dataset and writes it to a binary file as an illustration. + +.. code-block:: matlab + + % Define the size of your data + numSamples = 1000; + numChannels = 384; + + % Generate random data as an example + data = rand(numSamples, numChannels); + + % Write the data to a binary file + fileID = fopen('your_data_as_a_binary.bin', 'wb'); + fwrite(fileID, data, 'double'); + fclose(fileID); + +.. note:: + + In your own script, replace the random data generation with your actual dataset. + +Loading Data in SpikeInterface +------------------------------ + +After executing the above MATLAB code, a binary file named :code:`your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path. + +Use the following Python script to load the binary data into SpikeInterface: + +.. code-block:: python + + import spikeinterface as si + from pathlib import Path + + # Define file path + # For Linux or macOS: + file_path = Path("/The/Path/To/Your/Data/your_data_as_a_binary.bin") + # For Windows: + # file_path = Path(r"c:\path\to\your\data\your_data_as_a_binary.bin") + + # Confirm file existence + assert file_path.is_file(), f"Error: {file_path} is not a valid file. Please check the path." + + # Define recording parameters + sampling_frequency = 30_000.0 # Adjust according to your MATLAB dataset + num_channels = 384 # Adjust according to your MATLAB dataset + dtype = "float64" # MATLAB's double corresponds to Python's float64 + + # Load data using SpikeInterface + recording = si.read_binary(file_paths=file_path, sampling_frequency=sampling_frequency, + num_channels=num_channels, dtype=dtype) + + # Confirm that the data was loaded correctly by comparing the data shapes and see they match the MATLAB data + print(recording.get_num_frames(), recording.get_num_channels()) + +Follow the steps above to seamlessly import your MATLAB data into SpikeInterface. Once loaded, you can harness the full power of SpikeInterface for data processing, including filtering, spike sorting, and more. + +Common Pitfalls & Tips +---------------------- + +1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use :code:`time_axis=1` in :code:`si.read_binary()`. +2. **File Path**: Always double-check the Python file path. +3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's `double` is equivalent to Numpy's `float64`. +4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface. +5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's `Numpy for MATLAB Users `_ guide. + +Using gains and offsets for integer data +---------------------------------------- + +Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units, you can apply a gain and an offset. +In SpikeInterface, you can use the :code:`gain_to_uV` and :code:`offset_to_uV` parameters, since traces are handled in microvolts (uV). Both parameters can be integrated into the :code:`read_binary` function. +If your data in MATLAB is stored as :code:`int16`, and you know the gain and offset, you can use the following code to load the data: + +.. code-block:: python + + sampling_frequency = 30_000.0 # Adjust according to your MATLAB dataset + num_channels = 384 # Adjust according to your MATLAB dataset + dtype_int = 'int16' # Adjust according to your MATLAB dataset + gain_to_uV = 0.195 # Adjust according to your MATLAB dataset + offset_to_uV = 0 # Adjust according to your MATLAB dataset + + recording = si.read_binary(file_paths=file_path, sampling_frequency=sampling_frequency, + num_channels=num_channels, dtype=dtype_int, + gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV) + + recording.get_traces() # Return traces in original units [type: int] + recording.get_traces(return_scaled=True) # Return traces in micro volts (uV) [type: float] + + +This will equip your recording object with capabilities to convert the data to float values in uV using the :code:`get_traces()` method with the :code:`return_scaled` parameter set to :code:`True`. + +.. note:: + + The gain and offset parameters are usually format dependent and you will need to find out the correct values for your data format. You can load your data without gain and offset but then the traces will be in integer values and not in uV. diff --git a/doc/images/plot_traces_ephyviewer.png b/doc/images/plot_traces_ephyviewer.png new file mode 100644 index 0000000000..9d926725a4 Binary files /dev/null and b/doc/images/plot_traces_ephyviewer.png differ diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index 3fda05848c..10a3185c5c 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -117,7 +117,7 @@ Kilosort2.5 git clone https://github.com/MouseLand/Kilosort # provide installation path by setting the KILOSORT2_5_PATH environment variable - # or using Kilosort2_5Sorter.set_kilosort2_path() + # or using Kilosort2_5Sorter.set_kilosort2_5_path() * See also for Matlab/CUDA: https://www.mathworks.com/help/parallel-computing/gpu-support-by-release.html diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index b452307e3c..76ab7855c6 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -248,21 +248,19 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for We also have a high level class to compare many sorters against ground truth: :py:func:`~spiekinterface.comparison.GroundTruthStudy()` -A study is a systematic performance comparison of several ground truth recordings with several sorters. +A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases +like the different parameter sets. -The study class proposes high-level tool functions to run many ground truth comparisons with many sorters +The study class proposes high-level tool functions to run many ground truth comparisons with many "cases" on many recordings and then collect and aggregate results in an easy way. The all mechanism is based on an intrinsic organization into a "study_folder" with several subfolder: - * raw_files : contain a copy of recordings in binary format - * sorter_folders : contains outputs of sorters - * ground_truth : contains a copy of sorting ground truth in npz format - * sortings: contains light copy of all sorting in npz format - * tables: some tables in csv format - -In order to run and rerun the computation all gt_sorting and recordings are copied to a fast and universal format: -binary (for recordings) and npz (for sortings). + * datasets: contains ground truth datasets + * sorters : contains outputs of sorters + * sortings: contains light copy of all sorting + * metrics: contains metrics + * ... .. code-block:: python @@ -274,28 +272,51 @@ binary (for recordings) and npz (for sortings). import spikeinterface.widgets as sw from spikeinterface.comparison import GroundTruthStudy - # Setup study folder - rec0, gt_sorting0 = se.toy_example(num_channels=4, duration=10, seed=10, num_segments=1) - rec1, gt_sorting1 = se.toy_example(num_channels=4, duration=10, seed=0, num_segments=1) - gt_dict = { - 'rec0': (rec0, gt_sorting0), - 'rec1': (rec1, gt_sorting1), + + # generate 2 simulated datasets (could be also mearec files) + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) + + datasets = { + "toy0": (rec0, gt_sorting0), + "toy1": (rec1, gt_sorting1), } - study_folder = 'a_study_folder' - study = GroundTruthStudy.create(study_folder, gt_dict) - # all sorters for all recordings in one function. - sorter_list = ['herdingspikes', 'tridesclous', ] - study.run_sorters(sorter_list, mode_if_folder_exists="keep") + # define some "cases" here we want to tests tridesclous2 on 2 datasets and spykingcircus on one dataset + # so it is a two level study (sorter_name, dataset) + # this could be more complicated like (sorter_name, dataset, params) + cases = { + ("tdc2", "toy0"): { + "label": "tridesclous2 on tetrode0", + "dataset": "toy0", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + }, + ("tdc2", "toy1"): { + "label": "tridesclous2 on tetrode1", + "dataset": "toy1", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + }, + + ("sc", "toy0"): { + "label": "spykingcircus2 on tetrode0", + "dataset": "toy0", + "run_sorter_params": { + "sorter_name": "spykingcircus", + "docker_image": True + }, + }, + } + # this initilize a folder + study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases, + levels=["sorter_name", "dataset"]) - # You can re-run **run_study_sorters** as many times as you want. - # By default **mode='keep'** so only uncomputed sorters are re-run. - # For instance, just remove the "sorter_folders/rec1/herdingspikes" to re-run - # only one sorter on one recording. - # - # Then we copy the spike sorting outputs into a separate subfolder. - # This allow us to remove the "large" sorter_folders. - study.copy_sortings() + + # all cases in one function + study.run_sorters() # Collect comparisons #   @@ -306,11 +327,11 @@ binary (for recordings) and npz (for sortings). # Note: use exhaustive_gt=True when you know exactly how many # units in ground truth (for synthetic datasets) + # run all comparisons and loop over the results study.run_comparisons(exhaustive_gt=True) - - for (rec_name, sorter_name), comp in study.comparisons.items(): + for key, comp in study.comparisons.items(): print('*' * 10) - print(rec_name, sorter_name) + print(key) # raw counting of tp/fp/... print(comp.count_score) # summary @@ -323,26 +344,27 @@ binary (for recordings) and npz (for sortings). # Collect synthetic dataframes and display # As shown previously, the performance is returned as a pandas dataframe. - # The :py:func:`~spikeinterface.comparison.aggregate_performances_table()` function, + # The :py:func:`~spikeinterface.comparison.get_performance_by_unit()` function, # gathers all the outputs in the study folder and merges them in a single dataframe. + # Same idea for :py:func:`~spikeinterface.comparison.get_count_units()` - dataframes = study.aggregate_dataframes() + # this is a dataframe + perfs = study.get_performance_by_unit() - # Pandas dataframes can be nicely displayed as tables in the notebook. - print(dataframes.keys()) + # this is a dataframe + unit_counts = study.get_count_units() # we can also access run times - print(dataframes['run_times']) + run_times = study.get_run_times() + print(run_times) # Easy plot with seaborn - run_times = dataframes['run_times'] fig1, ax1 = plt.subplots() sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1) ax1.set_title('Run times') ############################################################################## - perfs = dataframes['perf_by_unit'] fig2, ax2 = plt.subplots() sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2) ax2.set_title('Recall') diff --git a/doc/modules/core.rst b/doc/modules/core.rst index fdc4d71fe7..976a82a4a3 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -547,8 +547,7 @@ workflow. In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`, :py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and :py:class:`~spikeinterface.core.NumpySnippets`. These object behave exactly like normal SpikeInterface objects, -but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported. -In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`). +but they are not bound to a file. Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is very similar to Similar to :py:class:`~spikeinterface.core.NumpySorting` but with an unerlying SharedMemory which is usefull for diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 6101b81552..23e9e20d96 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -24,21 +24,21 @@ The merging and splitting operations are handled by the :py:class:`~spikeinterfa from spikeinterface.curation import CurationSorting - sorting = run_sorter('kilosort2', recording) + sorting = run_sorter(sorter_name='kilosort2', recording=recording) - cs = CurationSorting(sorting) + cs = CurationSorting(parent_sorting=sorting) # make a first merge - cs.merge(['#1', '#5', '#15']) + cs.merge(units_to_merge=['#1', '#5', '#15']) # make a second merge - cs.merge(['#11', '#21']) + cs.merge(units_to_merge=['#11', '#21']) # make a split split_index = ... # some criteria on spikes - cs.split('#20', split_index) + cs.split(split_unit_id='#20', indices_list=split_index) - # here the final clean sorting + # here is the final clean sorting clean_sorting = cs.sorting @@ -60,12 +60,12 @@ merges. Therefore, it has many parameters and options. from spikeinterface.curation import MergeUnitsSorting, get_potential_auto_merge - sorting = run_sorter('kilosort', recording) + sorting = run_sorter(sorter_name='kilosort', recording=recording) - we = extract_waveforms(recording, sorting, folder='wf_folder') + we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder') # merges is a list of lists, with unit_ids to be merged. - merges = get_potential_auto_merge(we, minimum_spikes=1000, maximum_distance_um=150., + merges = get_potential_auto_merge(waveform_extractor=we, minimum_spikes=1000, maximum_distance_um=150., peak_sign="neg", bin_ms=0.25, window_ms=100., corr_diff_thresh=0.16, template_diff_thresh=0.25, censored_period_ms=0., refractory_period_ms=1.0, @@ -73,7 +73,7 @@ merges. Therefore, it has many parameters and options. firing_contamination_balance=1.5) # here we apply the merges - clean_sorting = MergeUnitsSorting(sorting, merges) + clean_sorting = MergeUnitsSorting(parent_sorting=sorting, units_to_merge=merges) Manual curation with sorting view @@ -98,24 +98,24 @@ The manual curation (including merges and labels) can be applied to a SpikeInter from spikeinterface.widgets import plot_sorting_summary # run a sorter and export waveforms - sorting = run_sorter('kilosort2', recording) - we = extract_waveforms(recording, sorting, folder='wf_folder') + sorting = run_sorter(sorter_name'kilosort2', recording=recording) + we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder') # some postprocessing is required - _ = compute_spike_amplitudes(we) - _ = compute_unit_locations(we) - _ = compute_template_similarity(we) - _ = compute_correlograms(we) + _ = compute_spike_amplitudes(waveform_extractor=we) + _ = compute_unit_locations(waveform_extractor=we) + _ = compute_template_similarity(waveform_extractor=we) + _ = compute_correlograms(waveform_extractor=we) # This loads the data to the cloud for web-based plotting and sharing - plot_sorting_summary(we, curation=True, backend='sortingview') + plot_sorting_summary(waveform_extractor=we, curation=True, backend='sortingview') # we open the printed link URL in a browswe # - make manual merges and labeling # - from the curation box, click on "Save as snapshot (sha1://)" # copy the uri sha_uri = "sha1://59feb326204cf61356f1a2eb31f04d8e0177c4f1" - clean_sorting = apply_sortingview_curation(sorting, uri_or_json=sha_uri) + clean_sorting = apply_sortingview_curation(sorting=sorting, uri_or_json=sha_uri) Note that you can also "Export as JSON" and pass the json file as :code:`uri_or_json` parameter. diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index fa637f898b..155050ddb0 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -28,15 +28,14 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:` from spikeinterface.exporters import export_to_phy # the waveforms are sparse so it is faster to export to phy - folder = 'waveforms' - we = extract_waveforms(recording, sorting, folder, sparse=True) + we = extract_waveforms(recording=recording, sorting=sorting, folder='waveforms', sparse=True) # some computations are done before to control all options - compute_spike_amplitudes(we) - compute_principal_components(we, n_components=3, mode='by_channel_global') + compute_spike_amplitudes(waveform_extractor=we) + compute_principal_components(waveform_extractor=we, n_components=3, mode='by_channel_global') # the export process is fast because everything is pre-computed - export_to_phy(we, output_folder='path/to/phy_folder') + export_to_phy(wavefor_extractor=we, output_folder='path/to/phy_folder') @@ -72,12 +71,12 @@ with many units! # the waveforms are sparse for more interpretable figures - we = extract_waveforms(recording, sorting, folder='path/to/wf', sparse=True) + we = extract_waveforms(recording=recording, sorting=sorting, folder='path/to/wf', sparse=True) # some computations are done before to control all options - compute_spike_amplitudes(we) - compute_correlograms(we) - compute_quality_metrics(we, metric_names=['snr', 'isi_violation', 'presence_ratio']) + compute_spike_amplitudes(waveform_extractor=we) + compute_correlograms(waveform_extractor=we) + compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'presence_ratio']) # the export process - export_report(we, output_folder='path/to/spikeinterface-report-folder') + export_report(waveform_extractor=we, output_folder='path/to/spikeinterface-report-folder') diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index 5aed24ca41..2d0e047672 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -13,11 +13,12 @@ Most of the :code:`Recording` classes are implemented by wrapping the Most of the :code:`Sorting` classes are instead directly implemented in SpikeInterface. - Although SpikeInterface is object-oriented (class-based), each object can also be loaded with a convenient :code:`read_XXXXX()` function. +.. code-block:: python + import spikeinterface.extractors as se Read one Recording @@ -27,32 +28,34 @@ Every format can be read with a simple function: .. code-block:: python - recording_oe = read_openephys("open-ephys-folder") + recording_oe = read_openephys(folder_path="open-ephys-folder") - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") - recording_blackrock = read_blackrock("blackrock-folder") + recording_blackrock = read_blackrock(folder_path="blackrock-folder") - recording_mearec = read_mearec("mearec_file.h5") + recording_mearec = read_mearec(file_path="mearec_file.h5") Importantly, some formats directly handle the probe information: .. code-block:: python - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") print(recording_spikeglx.get_probe()) - recording_mearec = read_mearec("mearec_file.h5") + recording_mearec = read_mearec(file_path="mearec_file.h5") print(recording_mearec.get_probe()) + + Read one Sorting ---------------- .. code-block:: python - sorting_KS = read_kilosort("kilosort-folder") + sorting_KS = read_kilosort(folder_path="kilosort-folder") Read one Event @@ -60,7 +63,7 @@ Read one Event .. code-block:: python - events_OE = read_openephys_event("open-ephys-folder") + events_OE = read_openephys_event(folder_path="open-ephys-folder") For a comprehensive list of compatible technologies, see :ref:`compatible_formats`. @@ -77,7 +80,7 @@ The actual reading will be done on demand using the :py:meth:`~spikeinterface.co .. code-block:: python # opening a 40GB SpikeGLX dataset is fast - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") # this really does load the full 40GB into memory : not recommended!!!!! traces = recording_spikeglx.get_traces(start_frame=None, end_frame=None, return_scaled=False) diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index afedc4f982..8934ae1ff6 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -77,12 +77,12 @@ We currently have 3 presets: .. code-block:: python # read and preprocess - rec = read_spikeglx('/my/Neuropixel/recording') - rec = bandpass_filter(rec) - rec = common_reference(rec) + rec = read_spikeglx(folder_path='/my/Neuropixel/recording') + rec = bandpass_filter(recording=rec) + rec = common_reference(recording=rec) # then correction is one line of code - rec_corrected = correct_motion(rec, preset="nonrigid_accurate") + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate") The process is quite long due the two first steps (activity profile + motion inference) But the return :code:`rec_corrected` is a lazy recording object that will interpolate traces on the @@ -94,20 +94,20 @@ If you want to user other presets, this is as easy as: .. code-block:: python # mimic kilosort motion - rec_corrected = correct_motion(rec, preset="kilosort_like") + rec_corrected = correct_motion(recording=rec, preset="kilosort_like") # super but less accurate and rigid - rec_corrected = correct_motion(rec, preset="rigid_fast") + rec_corrected = correct_motion(recording=rec, preset="rigid_fast") Optionally any parameter from the preset can be overwritten: .. code-block:: python - rec_corrected = correct_motion(rec, preset="nonrigid_accurate", + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", detect_kwargs=dict( detect_threshold=10.), - estimate_motion_kwargs=dic( + estimate_motion_kwargs=dict( histogram_depth_smooth_um=8., time_horizon_s=120., ), @@ -123,7 +123,7 @@ and checking. The folder will contain the motion vector itself of course but als .. code-block:: python motion_folder = '/somewhere/to/save/the/motion' - rec_corrected = correct_motion(rec, preset="nonrigid_accurate", folder=motion_folder) + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", folder=motion_folder) # and then motion_info = load_motion_info(motion_folder) @@ -156,14 +156,16 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte job_kwargs = dict(chunk_duration="1s", n_jobs=20, progress_bar=True) # Step 1 : activity profile - peaks = detect_peaks(rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) + peaks = detect_peaks(recording=rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) # (optional) sub-select some peaks to speed up the localization - peaks = select_peaks(peaks, ...) - peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation",radius_um=75.0, + peaks = select_peaks(peaks=peaks, ...) + peak_locations = localize_peaks(recording=rec, peaks=peaks, method="monopolar_triangulation",radius_um=75.0, max_distance_um=150.0, **job_kwargs) # Step 2: motion inference - motion, temporal_bins, spatial_bins = estimate_motion(rec, peaks, peak_locations, + motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, + peaks=peaks, + peak_locations=peak_locations, method="decentralized", direction="y", bin_duration_s=2.0, @@ -173,7 +175,9 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte # Step 3: motion interpolation # this step is lazy - rec_corrected = interpolate_motion(rec, motion, temporal_bins, spatial_bins, + rec_corrected = interpolate_motion(recording=rec, motion=motion, + temporal_bins=temporal_bins, + spatial_bins=spatial_bins, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=30.) @@ -196,20 +200,20 @@ different preprocessing chains: one for motion correction and one for spike sort .. code-block:: python - raw_rec = read_spikeglx(...) + raw_rec = read_spikeglx(folder_path='/spikeglx_folder') # preprocessing 1 : bandpass (this is smoother) + cmr - rec1 = si.bandpass_filter(raw_rec, freq_min=300., freq_max=5000.) - rec1 = si.common_reference(rec1, reference='global', operator='median') + rec1 = si.bandpass_filter(recording=raw_rec, freq_min=300., freq_max=5000.) + rec1 = si.common_reference(recording=rec1, reference='global', operator='median') # here the corrected recording is done on the preprocessing 1 # rec_corrected1 will not be used for sorting! motion_folder = '/my/folder' - rec_corrected1 = correct_motion(rec1, preset="nonrigid_accurate", folder=motion_folder) + rec_corrected1 = correct_motion(recording=rec1, preset="nonrigid_accurate", folder=motion_folder) # preprocessing 2 : highpass + cmr - rec2 = si.highpass_filter(raw_rec, freq_min=300.) - rec2 = si.common_reference(rec2, reference='global', operator='median') + rec2 = si.highpass_filter(recording=raw_rec, freq_min=300.) + rec2 = si.common_reference(recording=rec2, reference='global', operator='median') # we use another preprocessing for the final interpolation motion_info = load_motion_info(motion_folder) @@ -220,7 +224,7 @@ different preprocessing chains: one for motion correction and one for spike sort spatial_bins=motion_info['spatial_bins'], **motion_info['parameters']['interpolate_motion_kwargs']) - sorting = run_sorter("montainsort5", rec_corrected2) + sorting = run_sorter(sorter_name="montainsort5", recording=rec_corrected2) References diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index a560f4d5c9..112c6e367d 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -14,9 +14,9 @@ WaveformExtractor extensions There are several postprocessing tools available, and all of them are implemented as a :py:class:`~spikeinterface.core.BaseWaveformExtractorExtension`. All computations on top -of a WaveformExtractor will be saved along side the WaveformExtractor itself (sub folder, zarr path or sub dict). +of a :code:`WaveformExtractor` will be saved along side the :code:`WaveformExtractor` itself (sub folder, zarr path or sub dict). This workflow is convenient for retrieval of time-consuming computations (such as pca or spike amplitudes) when reloading a -WaveformExtractor. +:code:`WaveformExtractor`. :py:class:`~spikeinterface.core.BaseWaveformExtractorExtension` objects are tightly connected to the parent :code:`WaveformExtractor` object, so that operations done on the :code:`WaveformExtractor`, such as saving, @@ -80,9 +80,9 @@ This extension computes the principal components of the waveforms. There are sev * "by_channel_local" (default): fits one PCA model for each by_channel * "by_channel_global": fits the same PCA model to all channels (also termed temporal PCA) -* "concatenated": contatenates all channels and fits a PCA model on the concatenated data +* "concatenated": concatenates all channels and fits a PCA model on the concatenated data -If the input :code:`WaveformExtractor` is sparse, the sparsity is used when computing PCA. +If the input :code:`WaveformExtractor` is sparse, the sparsity is used when computing the PCA. For dense waveforms, sparsity can also be passed as an argument. For more information, see :py:func:`~spikeinterface.postprocessing.compute_principal_components` @@ -127,7 +127,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_locations` -unit locations +unit_locations ^^^^^^^^^^^^^^ diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 7c1f33f298..67f1e52011 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -22,8 +22,8 @@ In this code example, we build a preprocessing chain with two steps: import spikeinterface.preprocessing import bandpass_filter, common_reference # recording is a RecordingExtractor object - recording_f = bandpass_filter(recording, freq_min=300, freq_max=6000) - recording_cmr = common_reference(recording_f, operator="median") + recording_f = bandpass_filter(recording=recording, freq_min=300, freq_max=6000) + recording_cmr = common_reference(recording=recording_f, operator="median") These two preprocessors will not compute anything at instantiation, but the computation will be "on-demand" ("on-the-fly") when getting traces. @@ -38,7 +38,7 @@ save the object: .. code-block:: python # here the spykingcircus2 sorter engine directly uses the lazy "recording_cmr" object - sorting = run_sorter(recording_cmr, 'spykingcircus2') + sorting = run_sorter(recording=recording_cmr, sorter_name='spykingcircus2') Most of the external sorters, however, will need a binary file as input, so we can optionally save the processed recording with the efficient SpikeInterface :code:`save()` function: @@ -64,12 +64,13 @@ dtype (unless specified otherwise): .. code-block:: python + import spikeinterface.extractors as se # spikeGLX is int16 - rec_int16 = read_spikeglx("my_folder") + rec_int16 = se.read_spikeglx(folder_path"my_folder") # by default the int16 is kept - rec_f = bandpass_filter(rec_int16, freq_min=300, freq_max=6000) + rec_f = bandpass_filter(recording=rec_int16, freq_min=300, freq_max=6000) # we can force a float32 casting - rec_f2 = bandpass_filter(rec_int16, freq_min=300, freq_max=6000, dtype='float32') + rec_f2 = bandpass_filter(recording=rec_int16, freq_min=300, freq_max=6000, dtype='float32') Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`. @@ -83,6 +84,8 @@ The full list of preprocessing functions can be found here: :ref:`api_preprocess Here is a full list of possible preprocessing steps, grouped by type of processing: +For all examples :code:`rec` is a :code:`RecordingExtractor`. + filter() / bandpass_filter() / notch_filter() / highpass_filter() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -98,7 +101,7 @@ Important aspects of filtering functions: .. code-block:: python - rec_f = bandpass_filter(rec, freq_min=300, freq_max=6000) + rec_f = bandpass_filter(recording=rec, freq_min=300, freq_max=6000) * :py:func:`~spikeinterface.preprocessing.filter()` @@ -119,7 +122,7 @@ There are various options when combining :code:`operator` and :code:`reference` .. code-block:: python - rec_cmr = common_reference(rec, operator="median", reference="global") + rec_cmr = common_reference(recording=rec, operator="median", reference="global") * :py:func:`~spikeinterface.preprocessing.common_reference()` @@ -144,8 +147,8 @@ difference on artifact removal. .. code-block:: python - rec_shift = phase_shift(rec) - rec_cmr = common_reference(rec_shift, operator="median", reference="global") + rec_shift = phase_shift(recording=rec) + rec_cmr = common_reference(recording=rec_shift, operator="median", reference="global") @@ -168,7 +171,7 @@ centered with unitary variance on each channel. .. code-block:: python - rec_normed = zscore(rec) + rec_normed = zscore(recording=rec) * :py:func:`~spikeinterface.preprocessing.normalize_by_quantile()` * :py:func:`~spikeinterface.preprocessing.scale()` @@ -186,7 +189,7 @@ The whitened traces are then the dot product between the traces and the :code:`W .. code-block:: python - rec_w = whiten(rec) + rec_w = whiten(recording=rec) * :py:func:`~spikeinterface.preprocessing.whiten()` @@ -199,7 +202,7 @@ The :code:`blank_staturation()` function is similar, but it automatically estima .. code-block:: python - rec_w = clip(rec, a_min=-250., a_max=260) + rec_w = clip(recording=rec, a_min=-250., a_max=260) * :py:func:`~spikeinterface.preprocessing.clip()` * :py:func:`~spikeinterface.preprocessing.blank_staturation()` @@ -234,11 +237,11 @@ interpolated with the :code:`interpolate_bad_channels()` function (channels labe .. code-block:: python # detect - bad_channel_ids, channel_labels = detect_bad_channels(rec) + bad_channel_ids, channel_labels = detect_bad_channels(recording=rec) # Case 1 : remove then - rec_clean = recording.remove_channels(bad_channel_ids) + rec_clean = recording.remove_channels(remove_channel_ids=bad_channel_ids) # Case 2 : interpolate then - rec_clean = interpolate_bad_channels(rec, bad_channel_ids) + rec_clean = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids) * :py:func:`~spikeinterface.preprocessing.detect_bad_channels()` @@ -257,13 +260,13 @@ remove_artifacts() Given an external list of trigger times, :code:`remove_artifacts()` function can remove artifacts with several strategies: -* replace with zeros (blank) -* make a linear or cubic interpolation -* remove the median or average template (with optional time jitter and amplitude scaling correction) +* replace with zeros (blank) :code:`'zeros'` +* make a linear (:code:`'linear'`) or cubic (:code:`'cubic'`) interpolation +* remove the median (:code:`'median'`) or average (:code:`'avereage'`) template (with optional time jitter and amplitude scaling correction) .. code-block:: python - rec_clean = remove_artifacts(rec, list_triggers) + rec_clean = remove_artifacts(recording=rec, list_triggers=[100, 200, 300], mode='zeros') * :py:func:`~spikeinterface.preprocessing.remove_artifacts()` @@ -276,7 +279,7 @@ Similarly to :code:`numpy.astype()`, the :code:`astype()` casts the traces to th .. code-block:: python - rec_int16 = astype(rec_float, "int16") + rec_int16 = astype(recording=rec_float, dtype="int16") For recordings whose traces are unsigned (e.g. Maxwell Biosystems), the :code:`unsigned_to_signed()` function makes them @@ -286,7 +289,7 @@ is subtracted, and the traces are finally cast to :code:`int16`: .. code-block:: python - rec_int16 = unsigned_to_signed(rec_uint16) + rec_int16 = unsigned_to_signed(recording=rec_uint16) * :py:func:`~spikeinterface.preprocessing.astype()` * :py:func:`~spikeinterface.preprocessing.unsigned_to_signed()` @@ -300,7 +303,7 @@ required. .. code-block:: python - rec_with_more_channels = zero_channel_pad(rec, 128) + rec_with_more_channels = zero_channel_pad(parent_recording=rec, num_channels=128) * :py:func:`~spikeinterface.preprocessing.zero_channel_pad()` @@ -331,7 +334,7 @@ How to implement "IBL destriping" or "SpikeGLX CatGT" in SpikeInterface SpikeGLX has a built-in function called `CatGT `_ to apply some preprocessing on the traces to remove noise and artifacts. IBL also has a standardized pipeline for preprocessed traces a bit similar to CatGT which is called "destriping" [IBL_spikesorting]_. -In these both cases, the traces are entiely read, processed and written back to a file. +In both these cases, the traces are entirely read, processed and written back to a file. SpikeInterface can reproduce similar results without the need to write back to a file by building a *lazy* preprocessing chain. Optionally, the result can still be written to a binary (or a zarr) file. @@ -341,12 +344,12 @@ Here is a recipe to mimic the **IBL destriping**: .. code-block:: python - rec = read_spikeglx('my_spikeglx_folder') - rec = highpass_filter(rec, n_channel_pad=60) - rec = phase_shift(rec) - bad_channel_ids = detect_bad_channels(rec) - rec = interpolate_bad_channels(rec, bad_channel_ids) - rec = highpass_spatial_filter(rec) + rec = read_spikeglx(folder_path='my_spikeglx_folder') + rec = highpass_filter(recording=rec, n_channel_pad=60) + rec = phase_shift(recording=rec) + bad_channel_ids = detect_bad_channels(recording=rec) + rec = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids) + rec = highpass_spatial_filter(recording=rec) # optional rec.save(folder='clean_traces', n_jobs=10, chunk_duration='1s', progres_bar=True) @@ -356,9 +359,9 @@ Here is a recipe to mimic the **SpikeGLX CatGT**: .. code-block:: python - rec = read_spikeglx('my_spikeglx_folder') - rec = phase_shift(rec) - rec = common_reference(rec, operator="median", reference="global") + rec = read_spikeglx(folder_path='my_spikeglx_folder') + rec = phase_shift(recording=rec) + rec = common_reference(recording=rec, operator="median", reference="global") # optional rec.save(folder='clean_traces', n_jobs=10, chunk_duration='1s', progres_bar=True) @@ -369,7 +372,6 @@ Of course, these pipelines can be enhanced and customized using other available - Preprocessing on Snippets ------------------------- diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 8c7c0a2cc3..ec1788350f 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -25,9 +25,11 @@ For more details about each metric and it's availability and use within SpikeInt :glob: qualitymetrics/amplitude_cutoff + qualitymetrics/amplitude_cv qualitymetrics/amplitude_median qualitymetrics/d_prime qualitymetrics/drift + qualitymetrics/firing_range qualitymetrics/firing_rate qualitymetrics/isi_violations qualitymetrics/isolation_distance @@ -45,16 +47,16 @@ This code snippet shows how to compute quality metrics (with or without principa .. code-block:: python - we = si.load_waveforms(...) # start from a waveform extractor + we = si.load_waveforms(folder='waveforms') # start from a waveform extractor # without PC - metrics = compute_quality_metrics(we, metric_names=['snr']) + metrics = compute_quality_metrics(waveform_extractor=we, metric_names=['snr']) assert 'snr' in metrics.columns # with PCs from spikeinterface.postprocessing import compute_principal_components - pca = compute_principal_components(we, n_components=5, mode='by_channel_local') - metrics = compute_quality_metrics(we) + pca = compute_principal_components(waveform_extractor=we, n_components=5, mode='by_channel_local') + metrics = compute_quality_metrics(waveform_extractor=we) assert 'isolation_distance' in metrics.columns For more information about quality metrics, check out this excellent diff --git a/doc/modules/qualitymetrics/amplitude_cutoff.rst b/doc/modules/qualitymetrics/amplitude_cutoff.rst index 9f747f8d40..a1e4d85d01 100644 --- a/doc/modules/qualitymetrics/amplitude_cutoff.rst +++ b/doc/modules/qualitymetrics/amplitude_cutoff.rst @@ -21,12 +21,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitudes from all spikes - fraction_missing = qm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg") - # fraction_missing is a dict containing the units' IDs as keys, + fraction_missing = sqm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg") + # fraction_missing is a dict containing the unit IDs as keys, # and their estimated fraction of missing spikes as values. Reference diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/qualitymetrics/amplitude_cv.rst new file mode 100644 index 0000000000..81d3b4f12d --- /dev/null +++ b/doc/modules/qualitymetrics/amplitude_cv.rst @@ -0,0 +1,55 @@ +Amplitude CV (:code:`amplitude_cv_median`, :code:`amplitude_cv_range`) +====================================================================== + + +Calculation +----------- + +The amplitude CV (coefficient of variation) is a measure of the amplitude variability. +It is computed as the ratio between the standard deviation and the amplitude mean. +To obtain a better estimate of this measure, it is first computed separately for several temporal bins. +Out of these values, the median and the range (percentile distance, by default between the +5th and 95th percentiles) are computed. + +The computation requires either spike amplitudes (see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes()`) +or amplitude scalings (see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings()`) to be pre-computed. + + +Expectation and use +------------------- + +The amplitude CV median is expected to be relatively low for well-isolated units, indicating a "stereotypical" spike shape. + +The amplitude CV range can be high in the presence of noise contamination, due to amplitude outliers like in +the example below. + +.. image:: amplitudes.png + :width: 600 + + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + # Make recording, sorting and wvf_extractor object for your data. + # It is required to run `compute_spike_amplitudes(wvf_extractor)` or + # `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN) + amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(waveform_extractor=wvf_extractor) + # amplitude_cv_median and amplitude_cv_range are dicts containing the unit ids as keys, + # and their amplitude_cv metrics as values. + + + +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_cv_metrics + + +Literature +---------- + +Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino. diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/qualitymetrics/amplitude_median.rst index ffc45d1cf6..c77a57b033 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/qualitymetrics/amplitude_median.rst @@ -20,12 +20,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitude values from all spikes. - amplitude_medians = qm.compute_amplitude_medians(wvf_extractor) - # amplitude_medians is a dict containing the units' IDs as keys, + amplitude_medians = sqm.compute_amplitude_medians(waveform_extractor=wvf_extractor) + # amplitude_medians is a dict containing the unit IDs as keys, # and their estimated amplitude medians as values. Reference diff --git a/doc/modules/qualitymetrics/amplitudes.png b/doc/modules/qualitymetrics/amplitudes.png new file mode 100644 index 0000000000..0ee4dd1eda Binary files /dev/null and b/doc/modules/qualitymetrics/amplitudes.png differ diff --git a/doc/modules/qualitymetrics/d_prime.rst b/doc/modules/qualitymetrics/d_prime.rst index abb8c1dc74..9b540be743 100644 --- a/doc/modules/qualitymetrics/d_prime.rst +++ b/doc/modules/qualitymetrics/d_prime.rst @@ -32,9 +32,9 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm - d_prime = qm.lda_metrics(all_pcs, all_labels, 0) + d_prime = sqm.lda_metrics(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) Reference diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/qualitymetrics/drift.rst index 0a852f80af..dad2aafe7c 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/qualitymetrics/drift.rst @@ -40,12 +40,13 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm - # It is required to run `compute_spike_locations(wvf_extractor)` + # Make recording, sorting and wvf_extractor object for your data. + # It is required to run `compute_spike_locations(wvf_extractor) first` # (if missing, values will be NaN) - drift_ptps, drift_stds, drift_mads = qm.compute_drift_metrics(wvf_extractor, peak_sign="neg") - # drift_ptps, drift_stds, and drift_mads are dict containing the units' ID as keys, + drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(waveform_extractor=wvf_extractor, peak_sign="neg") + # drift_ptps, drift_stds, and drift_mads are each a dict containing the unit IDs as keys, # and their metrics as values. diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/qualitymetrics/firing_range.rst new file mode 100644 index 0000000000..1cbd903c7a --- /dev/null +++ b/doc/modules/qualitymetrics/firing_range.rst @@ -0,0 +1,40 @@ +Firing range (:code:`firing_range`) +=================================== + + +Calculation +----------- + +The firing range indicates the dispersion of the firing rate of a unit across the recording. It is computed by +taking the difference between the 95th percentile's firing rate and the 5th percentile's firing rate computed over short time bins (e.g. 10 s). + + + +Expectation and use +------------------- + +Very high levels of firing ranges, outside of a physiological range, might indicate noise contamination. + + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + # Make recording, sorting and wvf_extractor object for your data. + firing_range = sqm.compute_firing_ranges(waveform_extractor=wvf_extractor) + # firing_range is a dict containing the unit IDs as keys, + # and their firing firing_range as values (in Hz). + +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_firing_ranges + + +Literature +---------- + +Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino. diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/qualitymetrics/firing_rate.rst index eddef3e48f..ef8cb3d8f4 100644 --- a/doc/modules/qualitymetrics/firing_rate.rst +++ b/doc/modules/qualitymetrics/firing_rate.rst @@ -37,11 +37,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - firing_rate = qm.compute_firing_rates(wvf_extractor) - # firing_rate is a dict containing the units' IDs as keys, + firing_rate = sqm.compute_firing_rates(waveform_extractor=wvf_extractor) + # firing_rate is a dict containing the unit IDs as keys, # and their firing rates across segments as values (in Hz). References diff --git a/doc/modules/qualitymetrics/isi_violations.rst b/doc/modules/qualitymetrics/isi_violations.rst index 947e7d4938..725d9b0fd6 100644 --- a/doc/modules/qualitymetrics/isi_violations.rst +++ b/doc/modules/qualitymetrics/isi_violations.rst @@ -77,11 +77,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - isi_violations_ratio, isi_violations_count = qm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0) + isi_violations_ratio, isi_violations_count = sqm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0) References ---------- diff --git a/doc/modules/qualitymetrics/isolation_distance.rst b/doc/modules/qualitymetrics/isolation_distance.rst index 640a5a8b5a..6ba0d0b1ec 100644 --- a/doc/modules/qualitymetrics/isolation_distance.rst +++ b/doc/modules/qualitymetrics/isolation_distance.rst @@ -23,6 +23,16 @@ Expectation and use Isolation distance can be interpreted as a measure of distance from the cluster to the nearest other cluster. A well isolated unit should have a large isolation distance. +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + iso_distance, _ = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/l_ratio.rst b/doc/modules/qualitymetrics/l_ratio.rst index b37913ba58..ae31ab40a4 100644 --- a/doc/modules/qualitymetrics/l_ratio.rst +++ b/doc/modules/qualitymetrics/l_ratio.rst @@ -37,6 +37,17 @@ Since this metric identifies unit separation, a high value indicates a highly co A well separated unit should have a low L-ratio ([Schmitzer-Torbert]_ et al.). + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + _, l_ratio = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/qualitymetrics/presence_ratio.rst index e4de2248bd..ad0766d37c 100644 --- a/doc/modules/qualitymetrics/presence_ratio.rst +++ b/doc/modules/qualitymetrics/presence_ratio.rst @@ -23,12 +23,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - presence_ratio = qm.compute_presence_ratios(wvf_extractor) - # presence_ratio is a dict containing the units' IDs as keys + presence_ratio = sqm.compute_presence_ratios(waveform_extractor=wvf_extractor) + # presence_ratio is a dict containing the unit IDs as keys # and their presence ratio (between 0 and 1) as values. Links to original implementations diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/qualitymetrics/silhouette_score.rst index b924cdbf73..7da01e0476 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/qualitymetrics/silhouette_score.rst @@ -50,6 +50,16 @@ To reduce complexity the default implementation in SpikeInterface is to use the This can be changes by switching the silhouette method to either 'full' (the Rousseeuw implementation) or ('simplified', 'full') for both methods when entering the qm_params parameter. +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + simple_sil_score = sqm.simplified_silhouette_score(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/qualitymetrics/sliding_rp_violations.rst index 843242c1e8..fd53d7da3b 100644 --- a/doc/modules/qualitymetrics/sliding_rp_violations.rst +++ b/doc/modules/qualitymetrics/sliding_rp_violations.rst @@ -27,11 +27,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - contamination = qm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25) + contamination = sqm.compute_sliding_rp_violations(waveform_extractor=wvf_extractor, bin_size_ms=0.25) References ---------- diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/qualitymetrics/snr.rst index 288ab60515..7f27a5078a 100644 --- a/doc/modules/qualitymetrics/snr.rst +++ b/doc/modules/qualitymetrics/snr.rst @@ -41,12 +41,11 @@ With SpikeInterface: .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - - SNRs = qm.compute_snrs(wvf_extractor) - # SNRs is a dict containing the units' IDs as keys and their SNRs as values. + SNRs = sqm.compute_snrs(waveform_extractor=wvf_extractor) + # SNRs is a dict containing the unit IDs as keys and their SNRs as values. Links to original implementations --------------------------------- diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 2f566bf8a7..d1a3c70a97 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -27,9 +27,9 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(waveform_extractor=wvf_extractor, synchrony_sizes=(2, 4, 8)) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 34ab3d1151..5040b01ec2 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -49,15 +49,15 @@ to easily run spike sorters: from spikeinterface.sorters import run_sorter # run Tridesclous - sorting_TDC = run_sorter("tridesclous", recording, output_folder="/folder_TDC") + sorting_TDC = run_sorter(sorter_name="tridesclous", recording=recording, output_folder="/folder_TDC") # run Kilosort2.5 - sorting_KS2_5 = run_sorter("kilosort2_5", recording, output_folder="/folder_KS2.5") + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2.5") # run IronClust - sorting_IC = run_sorter("ironclust", recording, output_folder="/folder_IC") + sorting_IC = run_sorter(sorter_name="ironclust", recording=recording, output_folder="/folder_IC") # run pyKilosort - sorting_pyKS = run_sorter("pykilosort", recording, output_folder="/folder_pyKS") + sorting_pyKS = run_sorter(sorter_name="pykilosort", recording=recording, output_folder="/folder_pyKS") # run SpykingCircus - sorting_SC = run_sorter("spykingcircus", recording, output_folder="/folder_SC") + sorting_SC = run_sorter(sorter_name="spykingcircus", recording=recording, output_folder="/folder_SC") Then the output, which is a :py:class:`~spikeinterface.core.BaseSorting` object, can be easily @@ -81,10 +81,10 @@ Spike-sorter-specific parameters can be controlled directly from the .. code-block:: python - sorting_TDC = run_sorter('tridesclous', recording, output_folder="/folder_TDC", + sorting_TDC = run_sorter(sorter_name='tridesclous', recording=recording, output_folder="/folder_TDC", detect_threshold=8.) - sorting_KS2_5 = run_sorter("kilosort2_5", recording, output_folder="/folder_KS2.5" + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2.5" do_correction=False, preclust_threshold=6, freq_min=200.) @@ -185,7 +185,7 @@ The following code creates a test recording and runs a containerized spike sorte ) test_recording = test_recording.save(folder="test-docker-folder") - sorting = ss.run_sorter('kilosort3', + sorting = ss.run_sorter(sorter_name='kilosort3', recording=test_recording, output_folder="kilosort3", singularity_image=True) @@ -201,7 +201,7 @@ To run in Docker instead of Singularity, use ``docker_image=True``. .. code-block:: python - sorting = run_sorter('kilosort3', recording=test_recording, + sorting = run_sorter(sorter_name='kilosort3', recording=test_recording, output_folder="/tmp/kilosort3", docker_image=True) To use a specific image, set either ``docker_image`` or ``singularity_image`` to a string, @@ -209,7 +209,7 @@ e.g. ``singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0"``. .. code-block:: python - sorting = run_sorter("kilosort3", + sorting = run_sorter(sorter_name="kilosort3", recording=test_recording, output_folder="kilosort3", singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0") @@ -239,7 +239,7 @@ There are three options: 1. **released PyPi version**: if you installed :code:`spikeinterface` with :code:`pip install spikeinterface`, the latest released version will be installed in the container. -2. **development :code:`main` version**: if you installed :code:`spikeinterface` from source from the cloned repo +2. **development** :code:`main` **version**: if you installed :code:`spikeinterface` from source from the cloned repo (with :code:`pip install .`) or with :code:`pip install git+https://github.com/SpikeInterface/spikeinterface.git`, the current development version from the :code:`main` branch will be installed in the container. @@ -271,7 +271,7 @@ And use the custom image whith the :code:`run_sorter` function: .. code-block:: python - sorting = run_sorter("kilosort3", + sorting = run_sorter(sorter_name="kilosort3", recording=recording, docker_image="my-user/ks3-with-spikeinterface-test:0.1.0") @@ -285,27 +285,26 @@ Running several sorters in parallel The :py:mod:`~spikeinterface.sorters` module also includes tools to run several spike sorting jobs sequentially or in parallel. This can be done with the -:py:func:`~spikeinterface.sorters.run_sorters()` function by specifying +:py:func:`~spikeinterface.sorters.run_sorter_jobs()` function by specifying an :code:`engine` that supports parallel processing (such as :code:`joblib` or :code:`slurm`). .. code-block:: python - recordings = {'rec1' : recording, 'rec2': another_recording} - sorter_list = ['herdingspikes', 'tridesclous'] - sorter_params = { - 'herdingspikes': {'clustering_bandwidth' : 8}, - 'tridesclous': {'detect_threshold' : 5.}, - } - sorting_output = run_sorters(sorter_list, recordings, working_folder='tmp_some_sorters', - mode_if_folder_exists='overwrite', sorter_params=sorter_params) + # here we run 2 sorters on 2 different recordings = 4 jobs + recording = ... + another_recording = ... + + job_list = [ + {'sorter_name': 'tridesclous', 'recording': recording, 'output_folder': 'folder1','detect_threshold': 5.}, + {'sorter_name': 'tridesclous', 'recording': another_recording, 'output_folder': 'folder2', 'detect_threshold': 5.}, + {'sorter_name': 'herdingspikes', 'recording': recording, 'output_folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True}, + {'sorter_name': 'herdingspikes', 'recording': another_recording, 'output_folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True}, + ] + + # run in loop + sortings = run_sorter_jobs(job_list=job_list, engine='loop') - # the output is a dict with (rec_name, sorter_name) as keys - for (rec_name, sorter_name), sorting in sorting_output.items(): - print(rec_name, sorter_name, ':', sorting.get_unit_ids()) -After the jobs are run, the :code:`sorting_outputs` is a dictionary with :code:`(rec_name, sorter_name)` as a key (e.g. -:code:`('rec1', 'tridesclous')` in this example), and the corresponding :py:class:`~spikeinterface.core.BaseSorting` -as a value. :py:func:`~spikeinterface.sorters.run_sorters` has several "engines" available to launch the computation: @@ -315,13 +314,11 @@ as a value. .. code-block:: python - run_sorters(sorter_list, recordings, engine='loop') + run_sorter_jobs(job_list=job_list, engine='loop') - run_sorters(sorter_list, recordings, engine='joblib', - engine_kwargs={'n_jobs': 2}) + run_sorter_jobs(job_list=job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) - run_sorters(sorter_list, recordings, engine='slurm', - engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) + run_sorter_jobs(job_list=job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem': '5G'}) Spike sorting by group @@ -377,7 +374,7 @@ In this example, we create a 16-channel recording with 4 tetrodes: # here the result is a dict of a sorting object sortings = {} for group, sub_recording in recordings.items(): - sorting = run_sorter('kilosort2', recording, output_folder=f"folder_KS2_group{group}") + sorting = run_sorter(sorter_name='kilosort2', recording=recording, output_folder=f"folder_KS2_group{group}") sortings[group] = sorting **Option 2 : Automatic splitting** @@ -385,7 +382,7 @@ In this example, we create a 16-channel recording with 4 tetrodes: .. code-block:: python # here the result is one sorting that aggregates all sub sorting objects - aggregate_sorting = run_sorter_by_property('kilosort2', recording_4_tetrodes, + aggregate_sorting = run_sorter_by_property(sorter_name='kilosort2', recording=recording_4_tetrodes, grouping_property='group', working_folder='working_path') @@ -424,7 +421,7 @@ do not handle multi-segment, and in that case we will use the # multirecording has 4 segments of 10s each # run tridesclous in multi-segment mode - multisorting = si.run_sorter('tridesclous', multirecording) + multisorting = si.run_sorter(sorter_name='tridesclous', recording=multirecording) print(multisorting) # Case 2: the sorter DOES NOT handle multi-segment objects @@ -436,7 +433,7 @@ do not handle multi-segment, and in that case we will use the # multirecording has 1 segment of 40s each # run mountainsort4 in mono-segment mode - multisorting = si.run_sorter('mountainsort4', multirecording) + multisorting = si.run_sorter(sorter_name='mountainsort4', recording=multirecording) See also the :ref:`multi_seg` section. @@ -458,7 +455,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: * **Kilosort** :code:`run_sorter('kilosort')` * **Kilosort2** :code:`run_sorter('kilosort2')` * **Kilosort2.5** :code:`run_sorter('kilosort2_5')` -* **Kilosort3** :code:`run_sorter('Kilosort3')` +* **Kilosort3** :code:`run_sorter('kilosort3')` * **PyKilosort** :code:`run_sorter('pykilosort')` * **Klusta** :code:`run_sorter('klusta')` * **Mountainsort4** :code:`run_sorter('mountainsort4')` @@ -474,7 +471,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: Here a list of internal sorter based on `spikeinterface.sortingcomponents`; they are totally experimental for now: -* **Spyking circus2** :code:`run_sorter('spykingcircus2')` +* **Spyking Circus2** :code:`run_sorter('spykingcircus2')` * **Tridesclous2** :code:`run_sorter('tridesclous2')` In 2023, we expect to add many more sorters to this list. @@ -510,7 +507,7 @@ message will appear indicating how to install the given sorter, .. code:: python - recording = run_sorter('ironclust', recording) + recording = run_sorter(sorter_name='ironclust', recording=recording) throws the error, @@ -543,7 +540,7 @@ From the user's perspective, they behave exactly like the external sorters: .. code-block:: python - sorting = run_sorter("spykingcircus2", recording, "/tmp/folder") + sorting = run_sorter(sorter_name="spykingcircus2", recording=recording, output_folder="/tmp/folder") Contributing diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index aa62ea5b33..1e58972497 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -47,7 +47,8 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) peaks = detect_peaks( - recording, method='by_channel', + recording=recording, + method='by_channel', peak_sign='neg', detect_threshold=5, exclude_sweep_ms=0.2, @@ -94,7 +95,7 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) - peak_locations = localize_peaks(recording, peaks, method='center_of_mass', + peak_locations = localize_peaks(recording=recording, peaks=peaks, method='center_of_mass', radius_um=70., ms_before=0.3, ms_after=0.6, **job_kwargs) @@ -122,7 +123,7 @@ For instance, the 'monopolar_triangulation' method will have: .. note:: - By convention in SpikeInterface, when a probe is described in 2d + By convention in SpikeInterface, when a probe is described in 3d * **'x'** is the width of the probe * **'y'** is the depth * **'z'** is orthogonal to the probe plane @@ -144,11 +145,11 @@ can be *hidden* by this process. from spikeinterface.sortingcomponents.peak_detection import detect_peaks - many_peaks = detect_peaks(...) + many_peaks = detect_peaks(...) # as in above example from spikeinterface.sortingcomponents.peak_selection import select_peaks - some_peaks = select_peaks(many_peaks, method='uniform', n_peaks=10000) + some_peaks = select_peaks(peaks=many_peaks, method='uniform', n_peaks=10000) Implemented methods are the following: @@ -183,15 +184,15 @@ Here is an example with non-rigid motion estimation: .. code-block:: python from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(recording, ...) + peaks = detect_peaks(recording=recording, ...) # as in above example from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(recording, peaks, ...) + peak_locations = localize_peaks(recording=recording, peaks=peaks, ...) # as above from spikeinterface.sortingcomponents.motion_estimation import estimate_motion motion, temporal_bins, spatial_bins, - extra_check = estimate_motion(recording, peaks, peak_locations=peak_locations, + extra_check = estimate_motion(recording=recording, peaks=peaks, peak_locations=peak_locations, direction='y', bin_duration_s=10., bin_um=10., margin_um=0., method='decentralized_registration', rigid=False, win_shape='gaussian', win_step_um=50., win_sigma_um=150., @@ -217,13 +218,13 @@ Here is a short example that depends on the output of "Motion interpolation": from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording - recording_corrected = InterpolateMotionRecording(recording_with_drift, motion, temporal_bins, spatial_bins + recording_corrected = InterpolateMotionRecording(recording=recording_with_drift, motion=motion, temporal_bins=temporal_bins, spatial_bins=spatial_bins spatial_interpolation_method='kriging, border_mode='remove_channels') **Notes**: * :code:`spatial_interpolation_method` "kriging" or "iwd" do not play a big role. - * :code:`border_mode` is a very important parameter. It controls how to deal with the border because motion causes units on the + * :code:`border_mode` is a very important parameter. It controls dealing with the border because motion causes units on the border to not be present throughout the entire recording. We highly recommend the :code:`border_mode='remove_channels'` because this removes channels on the border that will be impacted by drift. Of course the larger the motion is the more channels are removed. @@ -255,10 +256,10 @@ Different methods may need different inputs (for instance some of them require p .. code-block:: python from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(recording, ...) + peaks = detect_peaks(recording, ...) # as in above example from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks - labels, peak_labels = find_cluster_from_peaks(recording, peaks, method="sliding_hdbscan") + labels, peak_labels = find_cluster_from_peaks(recording=recording, peaks=peaks, method="sliding_hdbscan") * **labels** : contains all possible labels @@ -278,7 +279,7 @@ At the moment, there are five methods implemented: * 'naive': a very naive implemenation used as a reference for benchmarks * 'tridesclous': the algorithm for template matching implemented in Tridesclous * 'circus': the algorithm for template matching implemented in SpyKING-Circus - * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal macthing + * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal matching pursuit) * 'wobble' : an algorithm loosely based on YASS that scales template amplitudes and shifts them in time to match detected spikes diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 86c541dfd0..f37b2a5a6f 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -14,6 +14,9 @@ Since version 0.95.0, the :py:mod:`spikeinterface.widgets` module supports multi * | :code:`sortingview`: web-based and interactive rendering using the `sortingview `_ | and `FIGURL `_ packages. +Version 0.100.0, also come with this new backend: +* | :code:`ephyviewer`: interactive Qt based using the `ephyviewer `_ package + Installing backends ------------------- @@ -85,6 +88,28 @@ Finally, if you wish to set up another cloud provider, follow the instruction fr `kachery-cloud `_ package ("Using your own storage bucket"). +ephyviewer +^^^^^^^^^^ + +This backend is Qt based with PyQt5, PyQt6 or PySide6 support. Qt is sometimes tedious to install. + + +For a pip-based installation, run: + +.. code-block:: bash + + pip install PySide6 ephyviewer + + +Anaconda users will have a better experience with this: + +.. code-block:: bash + + conda install pyqt=5 + pip install ephyviewer + + + Usage ----- @@ -123,7 +148,7 @@ The :code:`plot_*(..., backend="matplotlib")` functions come with the following .. code-block:: python # matplotlib backend - w = plot_traces(recording, backend="matplotlib") + w = plot_traces(recording=recording, backend="matplotlib") **Output:** @@ -148,7 +173,7 @@ Each function has the following additional arguments: # ipywidgets backend also supports multiple "layers" for plot_traces rec_dict = dict(filt=recording, cmr=common_reference(recording)) - w = sw.plot_traces(rec_dict, backend="ipywidgets") + w = sw.plot_traces(recording=rec_dict, backend="ipywidgets") **Output:** @@ -171,8 +196,8 @@ The functions have the following additional arguments: .. code-block:: python # sortingview backend - w_ts = sw.plot_traces(recording, backend="ipywidgets") - w_ss = sw.plot_sorting_summary(recording, backend="sortingview") + w_ts = sw.plot_traces(recording=recording, backend="ipywidgets") + w_ss = sw.plot_sorting_summary(recording=recording, backend="sortingview") **Output:** @@ -215,6 +240,21 @@ For example, here is how to combine the timeseries and sorting summary generated print(url) +ephyviewer +^^^^^^^^^^ + + +The :code:`ephyviewer` backend is currently only available for the :py:func:`~spikeinterface.widgets.plot_traces()` function. + + +.. code-block:: python + + plot_traces(recording=recording, backend="ephyviewer", mode="line", show_channel_ids=True) + + +.. image:: ../images/plot_traces_ephyviewer.png + + Available plotting functions ---------------------------- @@ -229,7 +269,7 @@ Available plotting functions * :py:func:`~spikeinterface.widgets.plot_spikes_on_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`) * :py:func:`~spikeinterface.widgets.plot_template_metrics` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_template_similarity` (backends: ::code:`matplotlib`, :code:`sortingview`) -* :py:func:`~spikeinterface.widgets.plot_timeseries` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) +* :py:func:`~spikeinterface.widgets.plot_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`, :code:`ephyviewer`) * :py:func:`~spikeinterface.widgets.plot_unit_depths` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_unit_locations` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_unit_summary` (backends: :code:`matplotlib`) diff --git a/examples/modules_gallery/core/plot_4_waveform_extractor.py b/examples/modules_gallery/core/plot_4_waveform_extractor.py index 6c886c1eb0..bee8f4061b 100644 --- a/examples/modules_gallery/core/plot_4_waveform_extractor.py +++ b/examples/modules_gallery/core/plot_4_waveform_extractor.py @@ -49,7 +49,8 @@ ############################################################################### # A :py:class:`~spikeinterface.core.WaveformExtractor` object can be created with the -# :py:func:`~spikeinterface.core.extract_waveforms` function: +# :py:func:`~spikeinterface.core.extract_waveforms` function (this defaults to a sparse +# representation of the waveforms): folder = 'waveform_folder' we = extract_waveforms( @@ -87,6 +88,7 @@ recording, sorting, folder, + sparse=False, ms_before=3., ms_after=4., max_spikes_per_unit=500, @@ -149,7 +151,7 @@ # # Option 1) Save a dense waveform extractor to sparse: # -# In this case, from an existing waveform extractor, we can first estimate a +# In this case, from an existing (dense) waveform extractor, we can first estimate a # sparsity (which channels each unit is defined on) and then save to a new # folder in sparse mode: @@ -173,7 +175,7 @@ ############################################################################### -# Option 2) Directly extract sparse waveforms: +# Option 2) Directly extract sparse waveforms (current spikeinterface default): # # We can also directly extract sparse waveforms. To do so, dense waveforms are # extracted first using a small number of spikes (:code:`'num_spikes_for_sparsity'`) diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 209f357457..7b6aae3e30 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -30,7 +30,7 @@ # because it contains a reference to the "Recording" and the "Sorting" objects: folder = 'waveforms_mearec' -we = si.extract_waveforms(recording, sorting, folder, +we = si.extract_waveforms(recording, sorting, folder, sparse=False, ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_durations='1s') print(we) diff --git a/examples/modules_gallery/qualitymetrics/plot_4_curation.py b/examples/modules_gallery/qualitymetrics/plot_4_curation.py index c66f55f221..edd7a85ce5 100644 --- a/examples/modules_gallery/qualitymetrics/plot_4_curation.py +++ b/examples/modules_gallery/qualitymetrics/plot_4_curation.py @@ -6,6 +6,8 @@ quality metrics. """ +############################################################################# +# Import the modules and/or functions necessary from spikeinterface import spikeinterface as si import spikeinterface.extractors as se @@ -15,22 +17,21 @@ ############################################################################## -# First, let's download a simulated dataset -# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' +# Let's download a simulated dataset +# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' # # Let's imagine that the ground-truth sorting is in fact the output of a sorter. -# local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') -recording, sorting = se.read_mearec(local_path) +recording, sorting = se.read_mearec(file_path=local_path) print(recording) print(sorting) ############################################################################## -# First, we extract waveforms and compute their PC scores: +# First, we extract waveforms (to be saved in the folder 'wfs_mearec') and +# compute their PC scores: -folder = 'wfs_mearec' -we = si.extract_waveforms(recording, sorting, folder, +we = si.extract_waveforms(recording, sorting, folder='wfs_mearec', ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_size=30000) print(we) @@ -47,12 +48,15 @@ ############################################################################## # We can now threshold each quality metric and select units based on some rules. # -# The easiest and most intuitive way is to use boolean masking with dataframe: +# The easiest and most intuitive way is to use boolean masking with a dataframe. +# +# Then create a list of unit ids that we want to keep keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_ratio'] < 0.2) & (metrics['nn_hit_rate'] > 0.90) print(keep_mask) keep_unit_ids = keep_mask[keep_mask].index.values +keep_unit_ids = [unit_id for unit_id in keep_unit_ids] print(keep_unit_ids) ############################################################################## @@ -61,4 +65,5 @@ curated_sorting = sorting.select_units(keep_unit_ids) print(curated_sorting) -se.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.pnz') + +se.NpzSortingExtractor.write_sorting(sorting=curated_sorting, save_path='curated_sorting.npz') diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py index a390bb7689..bff85dde4a 100644 --- a/src/spikeinterface/comparison/__init__.py +++ b/src/spikeinterface/comparison/__init__.py @@ -28,12 +28,11 @@ compare_multiple_templates, MultiTemplateComparison, ) -from .collisioncomparison import CollisionGTComparison -from .correlogramcomparison import CorrelogramGTComparison + from .groundtruthstudy import GroundTruthStudy -from .collisionstudy import CollisionGTStudy -from .correlogramstudy import CorrelogramGTStudy -from .studytools import aggregate_performances_table +from .collision import CollisionGTComparison, CollisionGTStudy +from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy + from .hybrid import ( HybridSpikesRecording, HybridUnitsRecording, diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 79c784491a..5af20d79b5 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~np.in1d(indexes, order1)] + possible = indexes[~np.isin(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~np.in1d(indexes, order1)] + remain = indexes[~np.isin(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/collisioncomparison.py b/src/spikeinterface/comparison/collision.py similarity index 64% rename from src/spikeinterface/comparison/collisioncomparison.py rename to src/spikeinterface/comparison/collision.py index 3b279717b7..dd04b2c72d 100644 --- a/src/spikeinterface/comparison/collisioncomparison.py +++ b/src/spikeinterface/comparison/collision.py @@ -1,13 +1,15 @@ -import numpy as np - from .paircomparisons import GroundTruthComparison +from .groundtruthstudy import GroundTruthStudy from .comparisontools import make_collision_events +import numpy as np + class CollisionGTComparison(GroundTruthComparison): """ - This class is an extension of GroundTruthComparison by focusing - to benchmark spike in collision + This class is an extension of GroundTruthComparison by focusing to benchmark spike in collision. + + This class needs maintenance and need a bit of refactoring. collision_lag: float @@ -156,3 +158,73 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good pair_names = pair_names[order] return similarities, recall_scores, pair_names + + +class CollisionGTStudy(GroundTruthStudy): + def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): + _kwargs = dict() + _kwargs.update(kwargs) + _kwargs["exhaustive_gt"] = exhaustive_gt + _kwargs["collision_lag"] = collision_lag + _kwargs["nbins"] = nbins + GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) + self.exhaustive_gt = exhaustive_gt + self.collision_lag = collision_lag + + def get_lags(self, key): + comp = self.comparisons[key] + fs = comp.sorting1.get_sampling_frequency() + lags = comp.bins / fs * 1000.0 + return lags + + def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): + import sklearn + + if case_keys is None: + case_keys = self.cases.keys() + + self.all_similarities = {} + self.all_recall_scores = {} + self.good_only = good_only + + for key in case_keys: + templates = self.get_templates(key) + flat_templates = templates.reshape(templates.shape[0], -1) + similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + comp = self.comparisons[key] + similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( + similarity, good_only=good_only, min_accuracy=min_accuracy + ) + self.all_similarities[key] = similarities + self.all_recall_scores[key] = recall_scores + + def get_mean_over_similarity_range(self, similarity_range, key): + idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) + all_similarities = self.all_similarities[key][idx] + all_recall_scores = self.all_recall_scores[key][idx] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_recall_scores = all_recall_scores[order, :] + + mean_recall_scores = np.nanmean(all_recall_scores, axis=0) + + return mean_recall_scores + + def get_lag_profile_over_similarity_bins(self, similarity_bins, key): + all_similarities = self.all_similarities[key] + all_recall_scores = self.all_recall_scores[key] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_recall_scores = all_recall_scores[order, :] + + result = {} + + for i in range(similarity_bins.size - 1): + cmin, cmax = similarity_bins[i], similarity_bins[i + 1] + amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) + mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) + result[(cmin, cmax)] = mean_recall_scores + + return result diff --git a/src/spikeinterface/comparison/collisionstudy.py b/src/spikeinterface/comparison/collisionstudy.py deleted file mode 100644 index 34a556e8b9..0000000000 --- a/src/spikeinterface/comparison/collisionstudy.py +++ /dev/null @@ -1,88 +0,0 @@ -from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting -from .collisioncomparison import CollisionGTComparison - -import numpy as np - - -class CollisionGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CollisionGTComparison( - gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins - ) - self.comparisons[(rec_name, sorter_name)] = comp - self.exhaustive_gt = exhaustive_gt - self.collision_lag = collision_lag - - def get_lags(self): - fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency() - lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000 - return lags - - def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9): - if not hasattr(self, "_good_only") or self._good_only != good_only: - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_similarities = [] - all_recall_scores = [] - - for rec_name in self.rec_names: - if (rec_name, sorter_name) in self.comparisons.keys(): - comp = self.comparisons[(rec_name, sorter_name)] - similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy - ) - - all_similarities.append(similarities) - all_recall_scores.append(recall_scores) - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0) - - def get_mean_over_similarity_range(self, similarity_range, sorter_name): - idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & ( - self.all_similarities[sorter_name] <= similarity_range[1] - ) - all_similarities = self.all_similarities[sorter_name][idx] - all_recall_scores = self.all_recall_scores[sorter_name][idx] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = np.nanmean(all_recall_scores, axis=0) - - return mean_recall_scores - - def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_recall_scores = self.all_recall_scores[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) - result[(cmin, cmax)] = mean_recall_scores - - return result diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index db45e2b25b..20ee7910b4 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/comparison/correlogramcomparison.py b/src/spikeinterface/comparison/correlogram.py similarity index 64% rename from src/spikeinterface/comparison/correlogramcomparison.py rename to src/spikeinterface/comparison/correlogram.py index 80e881a152..aaffef1887 100644 --- a/src/spikeinterface/comparison/correlogramcomparison.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -1,16 +1,17 @@ -import numpy as np from .paircomparisons import GroundTruthComparison +from .groundtruthstudy import GroundTruthStudy from spikeinterface.postprocessing import compute_correlograms +import numpy as np + + class CorrelogramGTComparison(GroundTruthComparison): """ This class is an extension of GroundTruthComparison by focusing - to benchmark correlation reconstruction - + to benchmark correlation reconstruction. - collision_lag: float - Collision lag in ms. + This class needs maintenance and need a bit of refactoring. """ @@ -105,6 +106,62 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): order = np.argsort(similarities) similarities = similarities[order] - errors = errors[order, :] + errors = errors[order] return similarities, errors + + +class CorrelogramGTStudy(GroundTruthStudy): + def run_comparisons( + self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs + ): + _kwargs = dict() + _kwargs.update(kwargs) + _kwargs["exhaustive_gt"] = exhaustive_gt + _kwargs["window_ms"] = window_ms + _kwargs["bin_ms"] = bin_ms + _kwargs["well_detected_score"] = well_detected_score + GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) + self.exhaustive_gt = exhaustive_gt + + @property + def time_bins(self): + for key, value in self.comparisons.items(): + return value.time_bins + + def precompute_scores_by_similarities(self, case_keys=None, good_only=True): + import sklearn.metrics + + if case_keys is None: + case_keys = self.cases.keys() + + self.all_similarities = {} + self.all_errors = {} + + for key in case_keys: + templates = self.get_templates(key) + flat_templates = templates.reshape(templates.shape[0], -1) + similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + comp = self.comparisons[key] + similarities, errors = comp.compute_correlogram_by_similarity(similarity) + + self.all_similarities[key] = similarities + self.all_errors[key] = errors + + def get_error_profile_over_similarity_bins(self, similarity_bins, key): + all_similarities = self.all_similarities[key] + all_errors = self.all_errors[key] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_errors = all_errors[order, :] + + result = {} + + for i in range(similarity_bins.size - 1): + cmin, cmax = similarity_bins[i], similarity_bins[i + 1] + amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) + mean_errors = np.nanmean(all_errors[amin:amax], axis=0) + result[(cmin, cmax)] = mean_errors + + return result diff --git a/src/spikeinterface/comparison/correlogramstudy.py b/src/spikeinterface/comparison/correlogramstudy.py deleted file mode 100644 index fb00c08157..0000000000 --- a/src/spikeinterface/comparison/correlogramstudy.py +++ /dev/null @@ -1,76 +0,0 @@ -from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting -from .correlogramcomparison import CorrelogramGTComparison - -import numpy as np - - -class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CorrelogramGTComparison( - gt_sorting, - sorting, - exhaustive_gt=exhaustive_gt, - window_ms=window_ms, - bin_ms=bin_ms, - well_detected_score=well_detected_score, - ) - self.comparisons[(rec_name, sorter_name)] = comp - - self.exhaustive_gt = exhaustive_gt - - @property - def time_bins(self): - for key, value in self.comparisons.items(): - return value.time_bins - - def precompute_scores_by_similarities(self, good_only=True): - if not hasattr(self, "_computed"): - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_errors = {} - self._computed = True - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_errors = [] - all_similarities = [] - for rec_name in self.rec_names: - try: - comp = self.comparisons[(rec_name, sorter_name)] - similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name]) - all_similarities.append(similarities) - all_errors.append(errors) - except Exception: - pass - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0) - - def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_errors = self.all_errors[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_errors = np.nanmean(all_errors[amin:amax], axis=0) - result[(cmin, cmax)] = mean_errors - - return result diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 7b146f07bc..e5f4ce8b31 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -1,327 +1,406 @@ from pathlib import Path import shutil +import os +import json +import pickle import numpy as np -from spikeinterface.core import load_extractor -from spikeinterface.extractors import NpzSortingExtractor -from spikeinterface.sorters import sorter_dict, run_sorters +from spikeinterface.core import load_extractor, extract_waveforms, load_waveforms +from spikeinterface.core.core_tools import SIJsonEncoder + +from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder from spikeinterface import WaveformExtractor from spikeinterface.qualitymetrics import compute_quality_metrics -from .paircomparisons import compare_sorter_to_ground_truth - -from .studytools import ( - setup_comparison_study, - get_rec_names, - get_recordings, - iter_working_folder, - iter_computed_names, - iter_computed_sorting, - collect_run_times, -) +from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison -class GroundTruthStudy: - def __init__(self, study_folder=None): - import pandas as pd +# TODO later : save comparison in folders when comparison object will be able to serialize - self.study_folder = Path(study_folder) - self._is_scanned = False - self.computed_names = None - self.rec_names = None - self.sorter_names = None - self.scan_folder() +# This is to separate names when the key are tuples when saving folders +_key_separator = " ## " - self.comparisons = None - self.exhaustive_gt = None - def __repr__(self): - t = "Ground truth study\n" - t += " " + str(self.study_folder) + "\n" - t += " recordings: {} {}\n".format(len(self.rec_names), self.rec_names) - if len(self.sorter_names): - t += " sorters: {} {}\n".format(len(self.sorter_names), self.sorter_names) +class GroundTruthStudy: + """ + This class is an helper function to run any comparison on several "cases" for many ground-truth dataset. - return t + "cases" refer to: + * several sorters for comparisons + * same sorter with differents parameters + * any combination of these (and more) - def scan_folder(self): - self.rec_names = get_rec_names(self.study_folder) - # scan computed names - self.computed_names = list(iter_computed_names(self.study_folder)) # list of pair (rec_name, sorter_name) - self.sorter_names = np.unique([e for _, e in iter_computed_names(self.study_folder)]).tolist() - self._is_scanned = True + For increased flexibility, cases keys can be a tuple so that we can vary complexity along several + "levels" or "axis" (paremeters or sorters). + In this case, the result dataframes will have `MultiIndex` to handle the different levels. - @classmethod - def create(cls, study_folder, gt_dict, **job_kwargs): - setup_comparison_study(study_folder, gt_dict, **job_kwargs) - return cls(study_folder) + A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see + :py:fun:`~spikeinterface.core.generate.generate_ground_truth_recording()`). - def run_sorters(self, sorter_list, mode_if_folder_exists="keep", remove_sorter_folders=False, **kwargs): - sorter_folders = self.study_folder / "sorter_folders" - recording_dict = get_recordings(self.study_folder) - - run_sorters( - sorter_list, - recording_dict, - sorter_folders, - with_output=False, - mode_if_folder_exists=mode_if_folder_exists, - **kwargs, - ) - - # results are copied so the heavy sorter_folders can be removed - self.copy_sortings() - - if remove_sorter_folders: - shutil.rmtree(self.study_folder / "sorter_folders") - - def _check_rec_name(self, rec_name): - if not self._is_scanned: - self.scan_folder() - if len(self.rec_names) > 1 and rec_name is None: - raise Exception("Pass 'rec_name' parameter to select which recording to use.") - elif len(self.rec_names) == 1: - rec_name = self.rec_names[0] - else: - rec_name = self.rec_names[self.rec_names.index(rec_name)] - return rec_name - - def get_ground_truth(self, rec_name=None): - rec_name = self._check_rec_name(rec_name) - sorting = load_extractor(self.study_folder / "ground_truth" / rec_name) - return sorting - - def get_recording(self, rec_name=None): - rec_name = self._check_rec_name(rec_name) - rec = load_extractor(self.study_folder / "raw_files" / rec_name) - return rec - - def get_sorting(self, sort_name, rec_name=None): - rec_name = self._check_rec_name(rec_name) - - selected_sorting = None - if sort_name in self.sorter_names: - for r_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - if sort_name == sorter_name and r_name == rec_name: - selected_sorting = sorting - return selected_sorting - - def copy_sortings(self): - sorter_folders = self.study_folder / "sorter_folders" - sorting_folders = self.study_folder / "sortings" - log_olders = self.study_folder / "sortings" / "run_log" - - log_olders.mkdir(parents=True, exist_ok=True) - - for rec_name, sorter_name, output_folder in iter_working_folder(sorter_folders): - SorterClass = sorter_dict[sorter_name] - fname = rec_name + "[#]" + sorter_name - npz_filename = sorting_folders / (fname + ".npz") - - try: - sorting = SorterClass.get_result_from_folder(output_folder) - NpzSortingExtractor.write_sorting(sorting, npz_filename) - except: - if npz_filename.is_file(): - npz_filename.unlink() - if (output_folder / "spikeinterface_log.json").is_file(): - shutil.copyfile( - output_folder / "spikeinterface_log.json", sorting_folders / "run_log" / (fname + ".json") - ) + This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. + Note that the underlying folder structure is not backward compatible! + """ - self.scan_folder() + def __init__(self, study_folder): + self.folder = Path(study_folder) - def run_comparisons(self, exhaustive_gt=False, **kwargs): + self.datasets = {} + self.cases = {} + self.sortings = {} self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt, **kwargs) - self.comparisons[(rec_name, sorter_name)] = sc - self.exhaustive_gt = exhaustive_gt - def aggregate_run_times(self): - return collect_run_times(self.study_folder) - - def aggregate_performance_by_unit(self): - assert self.comparisons is not None, "run_comparisons first" + self.scan_folder() - perf_by_unit = [] - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - comp = self.comparisons[(rec_name, sorter_name)] + @classmethod + def create(cls, study_folder, datasets={}, cases={}, levels=None): + # check that cases keys are homogeneous + key0 = list(cases.keys())[0] + if isinstance(key0, str): + assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(key0, tuple): + assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" + num_levels = len(key0) + assert all( + len(key) == num_levels for key in cases.keys() + ), "Keys for cases are not homogeneous, tuple negth differ" + if levels is None: + levels = [f"level{i}" for i in range(num_levels)] + else: + levels = list(levels) + assert len(levels) == num_levels + else: + raise ValueError("Keys for cases must str or tuple") - perf = comp.get_performance(method="by_unit", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - perf_by_unit.append(perf) + study_folder = Path(study_folder) + study_folder.mkdir(exist_ok=False, parents=True) - import pandas as pd + (study_folder / "datasets").mkdir() + (study_folder / "datasets" / "recordings").mkdir() + (study_folder / "datasets" / "gt_sortings").mkdir() + (study_folder / "sorters").mkdir() + (study_folder / "sortings").mkdir() + (study_folder / "sortings" / "run_logs").mkdir() + (study_folder / "metrics").mkdir() - perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(["rec_name", "sorter_name", "gt_unit_id"]) + for key, (rec, gt_sorting) in datasets.items(): + assert "/" not in key, "'/' cannot be in the key name!" + assert "\\" not in key, "'\\' cannot be in the key name!" - return perf_by_unit + # recordings are pickled + rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") - def aggregate_count_units(self, well_detected_score=None, redundant_score=None, overmerged_score=None): - assert self.comparisons is not None, "run_comparisons first" + # sortings are pickled + saved as NumpyFolderSorting + gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") + gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") - import pandas as pd + info = {} + info["levels"] = levels + (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - index = pd.MultiIndex.from_tuples(self.computed_names, names=["rec_name", "sorter_name"]) + # cases is dumped to a pickle file, json is not possible because of the tuple key + (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) - count_units = pd.DataFrame( - index=index, - columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"], - dtype=int, - ) + return cls(study_folder) - if self.exhaustive_gt: - count_units["num_false_positive"] = pd.Series(dtype=int) - count_units["num_bad"] = pd.Series(dtype=int) + def scan_folder(self): + if not (self.folder / "datasets").exists(): + raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = self.comparisons[(rec_name, sorter_name)] + with open(self.folder / "info.json", "r") as f: + self.info = json.load(f) - count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units( - well_detected_score - ) - if self.exhaustive_gt: - count_units.loc[(rec_name, sorter_name), "num_overmerged"] = comp.count_overmerged_units( - overmerged_score - ) - count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units(redundant_score) - count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units( - redundant_score - ) - count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() + self.levels = self.info["levels"] - return count_units + for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): + key = rec_file.stem + rec = load_extractor(rec_file) + gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) + self.datasets[key] = (rec, gt_sorting) - def aggregate_dataframes(self, copy_into_folder=True, **karg_thresh): - dataframes = {} - dataframes["run_times"] = self.aggregate_run_times().reset_index() - perfs = self.aggregate_performance_by_unit() + with open(self.folder / "cases.pickle", "rb") as f: + self.cases = pickle.load(f) - dataframes["perf_by_unit"] = perfs.reset_index() - dataframes["count_units"] = self.aggregate_count_units(**karg_thresh).reset_index() + self.comparisons = {k: None for k in self.cases} - if copy_into_folder: - tables_folder = self.study_folder / "tables" - tables_folder.mkdir(parents=True, exist_ok=True) + self.sortings = {} + for key in self.cases: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + if sorting_folder.exists(): + sorting = load_extractor(sorting_folder) + else: + sorting = None + self.sortings[key] = sorting - for name, df in dataframes.items(): - df.to_csv(str(tables_folder / (name + ".csv")), sep="\t", index=False) - - return dataframes + def __repr__(self): + t = f"{self.__class__.__name__} {self.folder.stem} \n" + t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" + t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" + num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None]) + t += f" computed: {num_computed}\n" - def get_waveform_extractor(self, rec_name, sorter_name=None): - rec = self.get_recording(rec_name) + return t - if sorter_name is None: - name = "GroundTruth" - sorting = self.get_ground_truth(rec_name) + def key_to_str(self, key): + if isinstance(key, str): + return key + elif isinstance(key, tuple): + return _key_separator.join(key) else: - assert sorter_name in self.sorter_names - name = sorter_name - sorting = self.get_sorting(sorter_name, rec_name) - - waveform_folder = self.study_folder / "waveforms" / f"waveforms_{name}_{rec_name}" + raise ValueError("Keys for cases must str or tuple") + + def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): + if case_keys is None: + case_keys = self.cases.keys() + + job_list = [] + for key in case_keys: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + sorting_exists = sorting_folder.exists() + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + sorter_folder_exists = sorter_folder.exists() + + if keep: + if sorting_exists: + continue + if sorter_folder_exists: + # the sorter folder exists but havent been copied to sortings folder + sorting = read_sorter_folder(sorter_folder, raise_error=False) + if sorting is not None: + # save and skip + self.copy_sortings(case_keys=[key]) + continue + + if sorting_exists: + # delete older sorting + log before running sorters + shutil.rmtree(sorting_folder) + log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" + if log_file.exists(): + log_file.unlink() + + if sorter_folder_exists: + shutil.rmtree(sorter_folder) + + params = self.cases[key]["run_sorter_params"].copy() + # this ensure that sorter_name is given + recording, _ = self.datasets[self.cases[key]["dataset"]] + sorter_name = params.pop("sorter_name") + job = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=sorter_folder, + ) + job.update(params) + # the verbose is overwritten and global to all run_sorters + job["verbose"] = verbose + job["with_output"] = False + job_list.append(job) + + run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) + + # TODO later create a list in laucher for engine blocking and non-blocking + if engine not in ("slurm",): + self.copy_sortings(case_keys) + + def copy_sortings(self, case_keys=None, force=True): + if case_keys is None: + case_keys = self.cases.keys() + + for key in case_keys: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" + + if (sorter_folder / "spikeinterface_log.json").exists(): + sorting = read_sorter_folder( + sorter_folder, raise_error=False, register_recording=False, sorting_info=False + ) + else: + sorting = None + + if sorting is not None: + if sorting_folder.exists(): + if force: + # delete folder + log + shutil.rmtree(sorting_folder) + if log_file.exists(): + log_file.unlink() + else: + continue + + sorting = sorting.save(format="numpy_folder", folder=sorting_folder) + self.sortings[key] = sorting + + # copy logs + shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file) + + def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): + if case_keys is None: + case_keys = self.cases.keys() + + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + _, gt_sorting = self.datasets[dataset_key] + sorting = self.sortings[key] + if sorting is None: + self.comparisons[key] = None + continue + comp = comparison_class(gt_sorting, sorting, **kwargs) + self.comparisons[key] = comp + + def get_run_times(self, case_keys=None): + import pandas as pd - if waveform_folder.is_dir(): - we = WaveformExtractor.load(waveform_folder) - else: - we = WaveformExtractor.create(rec, sorting, waveform_folder) + if case_keys is None: + case_keys = self.cases.keys() + + log_folder = self.folder / "sortings" / "run_logs" + + run_times = {} + for key in case_keys: + log_file = log_folder / f"{self.key_to_str(key)}.json" + with open(log_file, mode="r") as logfile: + log = json.load(logfile) + run_time = log.get("run_time", None) + run_times[key] = run_time + + return pd.Series(run_times, name="run_time") + + def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): + if case_keys is None: + case_keys = self.cases.keys() + + base_folder = self.folder / "waveforms" + base_folder.mkdir(exist_ok=True) + + dataset_keys = [self.cases[key]["dataset"] for key in case_keys] + dataset_keys = set(dataset_keys) + for dataset_key in dataset_keys: + # the waveforms depend on the dataset key + wf_folder = base_folder / self.key_to_str(dataset_key) + recording, gt_sorting = self.datasets[dataset_key] + we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs) + + def get_waveform_extractor(self, key): + # some recording are not dumpable to json and the waveforms extactor need it! + # so we load it with and put after + # this should be fixed in PR 2027 so remove this after + + dataset_key = self.cases[key]["dataset"] + wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) + we = load_waveforms(wf_folder, with_recording=False) + recording, _ = self.datasets[dataset_key] + we.set_recording(recording) return we - def compute_waveforms( - self, - rec_name, - sorter_name=None, - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - n_jobs=-1, - total_memory="1G", - ): - we = self.get_waveform_extractor(rec_name, sorter_name) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory) - - def get_templates(self, rec_name, sorter_name=None, mode="median"): - """ - Get template for a given recording. - - If sorter_name=None then template are from the ground truth. - - """ - we = self.get_waveform_extractor(rec_name, sorter_name=sorter_name) + def get_templates(self, key, mode="average"): + we = self.get_waveform_extractor(key) templates = we.get_all_templates(mode=mode) return templates - def compute_metrics( - self, - rec_name, - metric_names=["snr"], - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - n_jobs=-1, - total_memory="1G", - ): - we = self.get_waveform_extractor(rec_name) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory) - - # metrics - metrics = compute_quality_metrics(we, metric_names=metric_names) - folder = self.study_folder / "metrics" - folder.mkdir(exist_ok=True) - filename = folder / f"metrics _{rec_name}.txt" - metrics.to_csv(filename, sep="\t", index=True) + def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): + if case_keys is None: + case_keys = self.cases.keys() + + done = [] + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + if dataset_key in done: + # some case can share the same waveform extractor + continue + done.append(dataset_key) + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if filename.exists(): + if force: + os.remove(filename) + else: + continue + we = self.get_waveform_extractor(key) + metrics = compute_quality_metrics(we, metric_names=metric_names) + metrics.to_csv(filename, sep="\t", index=True) + + def get_metrics(self, key): + import pandas as pd + + dataset_key = self.cases[key]["dataset"] + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if not filename.exists(): + return + metrics = pd.read_csv(filename, sep="\t", index_col=0) + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + metrics.index = gt_sorting.unit_ids return metrics - def get_metrics(self, rec_name=None, **metric_kwargs): - """ - Load or compute units metrics for a given recording. - """ - rec_name = self._check_rec_name(rec_name) - metrics_folder = self.study_folder / "metrics" - metrics_folder.mkdir(parents=True, exist_ok=True) + def get_units_snr(self, key): + """ """ + return self.get_metrics(key)["snr"] - filename = self.study_folder / "metrics" / f"metrics _{rec_name}.txt" + def get_performance_by_unit(self, case_keys=None): import pandas as pd - if filename.is_file(): - metrics = pd.read_csv(filename, sep="\t", index_col=0) - gt_sorting = self.get_ground_truth(rec_name) - metrics.index = gt_sorting.unit_ids + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" + + perf = comp.get_performance(method="by_unit", output="pandas") + if isinstance(key, str): + perf[self.levels] = key + elif isinstance(key, tuple): + for col, k in zip(self.levels, key): + perf[col] = k + + perf = perf.reset_index() + perf_by_unit.append(perf) + + perf_by_unit = pd.concat(perf_by_unit) + perf_by_unit = perf_by_unit.set_index(self.levels) + return perf_by_unit + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) else: - metrics = self.compute_metrics(rec_name, **metric_kwargs) + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - metrics.index.name = "unit_id" - # add rec name columns - metrics["rec_name"] = rec_name + columns = ["num_gt", "num_sorter", "num_well_detected"] + comp = self.comparisons[case_keys[0]] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) - return metrics + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" - def get_units_snr(self, rec_name=None, **metric_kwargs): - """ """ - metric = self.get_metrics(rec_name=rec_name, **metric_kwargs) - return metric["snr"] - - def concat_all_snr(self): - metrics = [] - for rec_name in self.rec_names: - df = self.get_metrics(rec_name) - df = df.reset_index() - metrics.append(df) - metrics = pd.concat(metrics) - metrics = metrics.set_index(["rec_name", "unit_id"]) - return metrics["snr"] + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index af410255b9..e0c98cd772 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording): The refractory period of the injected spike train (in ms). injected_sorting_folder: str | Path | None If given, the injected sorting is saved to this folder. - It must be specified if injected_sorting is None or not dumpable. + It must be specified if injected_sorting is None or not serialisable to file. Returns ------- @@ -84,7 +84,8 @@ def __init__( ) # save injected sorting if necessary self.injected_sorting = injected_sorting - if not self.injected_sorting.check_if_json_serializable(): + if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) @@ -137,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording): this refractory period. injected_sorting_folder: str | Path | None If given, the injected sorting is saved to this folder. - It must be specified if injected_sorting is None or not dumpable. + It must be specified if injected_sorting is None or not serializable to file. Returns ------- @@ -180,7 +181,8 @@ def __init__( self.injected_sorting = injected_sorting # save injected sorting if necessary - if not self.injected_sorting.check_if_json_serializable(): + if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 9e02fd5b2d..f44e14c4c4 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -1,6 +1,7 @@ from pathlib import Path import json import pickle +import warnings import numpy as np @@ -180,9 +181,16 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou return sorting def save_to_folder(self, save_folder): + warnings.warn( + "save_to_folder() is deprecated. " + "You should save and load the multi sorting comparison object using pickle." + "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", + DeprecationWarning, + stacklevel=2, + ) for sorting in self.object_list: - assert ( - sorting.check_if_json_serializable() + assert sorting.check_serializablility( + "json" ), "MultiSortingComparison.save_to_folder() need json serializable sortings" save_folder = Path(save_folder) @@ -205,6 +213,13 @@ def save_to_folder(self, save_folder): @staticmethod def load_from_folder(folder_path): + warnings.warn( + "load_from_folder() is deprecated. " + "You should save and load the multi sorting comparison object using pickle." + "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", + DeprecationWarning, + stacklevel=2, + ) folder_path = Path(folder_path) with (folder_path / "kwargs.json").open() as f: kwargs = json.load(f) @@ -244,7 +259,8 @@ def __init__( BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) - self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = True if len(unit_ids) > 0: for k in ("agreement_number", "avg_agreement", "unit_ids"): diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py deleted file mode 100644 index 79227c865f..0000000000 --- a/src/spikeinterface/comparison/studytools.py +++ /dev/null @@ -1,316 +0,0 @@ -""" -High level tools to run many ground-truth comparison with -many sorter on many recordings and then collect and aggregate results -in an easy way. - -The all mechanism is based on an intrinsic organization -into a "study_folder" with several subfolder: - * raw_files : contain a copy in binary format of recordings - * sorter_folders : contains output of sorters - * ground_truth : contains a copy of sorting ground in npz format - * sortings: contains light copy of all sorting in npz format - * tables: some table in cvs format -""" - -from pathlib import Path -import shutil -import json -import os - - -from spikeinterface.core import load_extractor -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.extractors import NpzSortingExtractor -from spikeinterface.sorters import sorter_dict -from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output - -from .comparisontools import _perf_keys -from .paircomparisons import compare_sorter_to_ground_truth - - -def setup_comparison_study(study_folder, gt_dict, **job_kwargs): - """ - Based on a dict of (recording, sorting) create the study folder. - - Parameters - ---------- - study_folder: str - The study folder. - gt_dict : a dict of tuple (recording, sorting_gt) - Dict of tuple that contain recording and sorting ground truth - """ - job_kwargs = fix_job_kwargs(job_kwargs) - study_folder = Path(study_folder) - assert not study_folder.is_dir(), "'study_folder' already exists. Please remove it" - - study_folder.mkdir(parents=True, exist_ok=True) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - log_folder.mkdir(parents=True, exist_ok=True) - tables_folder = study_folder / "tables" - tables_folder.mkdir(parents=True, exist_ok=True) - - for rec_name, (recording, sorting_gt) in gt_dict.items(): - # write recording using save with binary - folder = study_folder / "ground_truth" / rec_name - sorting_gt.save(folder=folder, format="numpy_folder") - folder = study_folder / "raw_files" / rec_name - recording.save(folder=folder, format="binary", **job_kwargs) - - # make an index of recording names - with open(study_folder / "names.txt", mode="w", encoding="utf8") as f: - for rec_name in gt_dict: - f.write(rec_name + "\n") - - -def get_rec_names(study_folder): - """ - Get list of keys of recordings. - Read from the 'names.txt' file in study folder. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - rec_names: list - List of names. - """ - study_folder = Path(study_folder) - with open(study_folder / "names.txt", mode="r", encoding="utf8") as f: - rec_names = f.read()[:-1].split("\n") - return rec_names - - -def get_recordings(study_folder): - """ - Get ground recording as a dict. - - They are read from the 'raw_files' folder with binary format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - recording_dict: dict - Dict of recording. - """ - study_folder = Path(study_folder) - - rec_names = get_rec_names(study_folder) - recording_dict = {} - for rec_name in rec_names: - rec = load_extractor(study_folder / "raw_files" / rec_name) - recording_dict[rec_name] = rec - - return recording_dict - - -def get_ground_truths(study_folder): - """ - Get ground truth sorting extractor as a dict. - - They are read from the 'ground_truth' folder with npz format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - ground_truths: dict - Dict of sorting_gt. - """ - study_folder = Path(study_folder) - rec_names = get_rec_names(study_folder) - ground_truths = {} - for rec_name in rec_names: - sorting = load_extractor(study_folder / "ground_truth" / rec_name) - ground_truths[rec_name] = sorting - return ground_truths - - -def iter_computed_names(study_folder): - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - yield rec_name, sorter_name - - -def iter_computed_sorting(study_folder): - """ - Iter over sorting files. - """ - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - sorting = NpzSortingExtractor(sorting_folder / filename) - yield rec_name, sorter_name, sorting - - -def collect_run_times(study_folder): - """ - Collect run times in a working folder and store it in CVS files. - - The output is list of (rec_name, sorter_name, run_time) - """ - import pandas as pd - - study_folder = Path(study_folder) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - tables_folder = study_folder / "tables" - - tables_folder.mkdir(parents=True, exist_ok=True) - - run_times = [] - for filename in os.listdir(log_folder): - if filename.endswith(".json") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".json", "").split("[#]") - with open(log_folder / filename, encoding="utf8", mode="r") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - run_times.append((rec_name, sorter_name, run_time)) - - run_times = pd.DataFrame(run_times, columns=["rec_name", "sorter_name", "run_time"]) - run_times = run_times.set_index(["rec_name", "sorter_name"]) - - return run_times - - -def aggregate_sorting_comparison(study_folder, exhaustive_gt=False): - """ - Loop over output folder in a tree to collect sorting output and run - ground_truth_comparison on them. - - Parameters - ---------- - study_folder: str - The study folder. - exhaustive_gt: bool (default True) - Tell if the ground true is "exhaustive" or not. In other world if the - GT have all possible units. It allows more performance measurement. - For instance, MEArec simulated dataset have exhaustive_gt=True - - Returns - ---------- - comparisons: a dict of SortingComparison - - """ - - study_folder = Path(study_folder) - - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - comparisons = {} - for (rec_name, sorter_name), sorting in results.items(): - gt_sorting = ground_truths[rec_name] - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt) - comparisons[(rec_name, sorter_name)] = sc - - return comparisons - - -def aggregate_performances_table(study_folder, exhaustive_gt=False, **karg_thresh): - """ - Aggregate some results into dataframe to have a "study" overview on all recordingXsorter. - - Tables are: - * run_times: run times per recordingXsorter - * perf_pooled_with_sum: GroundTruthComparison.see get_performance - * perf_pooled_with_average: GroundTruthComparison.see get_performance - * count_units: given some threshold count how many units : 'well_detected', 'redundant', 'false_postive_units, 'bad' - - Parameters - ---------- - study_folder: str - The study folder. - karg_thresh: dict - Threshold parameters used for the "count_units" table. - - Returns - ------- - dataframes: a dict of DataFrame - Return several useful DataFrame to compare all results. - Note that count_units depend on karg_thresh. - """ - import pandas as pd - - study_folder = Path(study_folder) - sorter_folders = study_folder / "sorter_folders" - tables_folder = study_folder / "tables" - - comparisons = aggregate_sorting_comparison(study_folder, exhaustive_gt=exhaustive_gt) - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - study_folder = Path(study_folder) - - dataframes = {} - - # get run times: - run_times = pd.read_csv(str(tables_folder / "run_times.csv"), sep="\t") - run_times.columns = ["rec_name", "sorter_name", "run_time"] - run_times = run_times.set_index( - [ - "rec_name", - "sorter_name", - ] - ) - dataframes["run_times"] = run_times - - perf_pooled_with_sum = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_sum"] = perf_pooled_with_sum - - perf_pooled_with_average = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_average"] = perf_pooled_with_average - - count_units = pd.DataFrame( - index=run_times.index, columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant"] - ) - dataframes["count_units"] = count_units - if exhaustive_gt: - count_units["num_false_positive"] = None - count_units["num_bad"] = None - - perf_by_spiketrain = [] - - for (rec_name, sorter_name), comp in comparisons.items(): - gt_sorting = ground_truths[rec_name] - sorting = results[(rec_name, sorter_name)] - - perf = comp.get_performance(method="pooled_with_sum", output="pandas") - perf_pooled_with_sum.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="pooled_with_average", output="pandas") - perf_pooled_with_average.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="by_spiketrain", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - - perf_by_spiketrain.append(perf) - - count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units(**karg_thresh) - count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units() - if exhaustive_gt: - count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units() - count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() - - perf_by_spiketrain = pd.concat(perf_by_spiketrain) - perf_by_spiketrain = perf_by_spiketrain.set_index(["rec_name", "sorter_name", "gt_unit_id"]) - dataframes["perf_by_spiketrain"] = perf_by_spiketrain - - return dataframes diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 70f8a63c8c..91c8c640e0 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -1,19 +1,11 @@ -import importlib import shutil import pytest from pathlib import Path -from spikeinterface.extractors import toy_example -from spikeinterface.sorters import installed_sorters +from spikeinterface import generate_ground_truth_recording +from spikeinterface.preprocessing import bandpass_filter from spikeinterface.comparison import GroundTruthStudy -try: - import tridesclous - - HAVE_TDC = True -except ImportError: - HAVE_TDC = False - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "comparison" @@ -27,61 +19,85 @@ def setup_module(): if study_folder.is_dir(): shutil.rmtree(study_folder) - _setup_comparison_study() + create_a_study(study_folder) + + +def simple_preprocess(rec): + return bandpass_filter(rec) -def _setup_comparison_study(): - rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1) - rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1) +def create_a_study(study_folder): + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=91) - gt_dict = { + datasets = { "toy_tetrode": (rec0, gt_sorting0), "toy_probe32": (rec1, gt_sorting1), + "toy_probe32_preprocess": (simple_preprocess(rec1), gt_sorting1), } - study = GroundTruthStudy.create(study_folder, gt_dict) + # cases can also be generated via simple loops + cases = { + # + ("tdc2", "no-preprocess", "tetrode"): { + "label": "tridesclous2 without preprocessing and standard params", + "dataset": "toy_tetrode", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + "comparison_params": {}, + }, + # + ("tdc2", "with-preprocess", "probe32"): { + "label": "tridesclous2 with preprocessing standar params", + "dataset": "toy_probe32_preprocess", + "run_sorter_params": { + "sorter_name": "tridesclous2", + }, + "comparison_params": {}, + }, + # we comment this at the moement because SC2 is quite slow for testing + # ("sc2", "no-preprocess", "tetrode"): { + # "label": "spykingcircus2 without preprocessing standar params", + # "dataset": "toy_tetrode", + # "run_sorter_params": { + # "sorter_name": "spykingcircus2", + # }, + # "comparison_params": { + # }, + # }, + } -@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -def test_run_study_sorters(): - study = GroundTruthStudy(study_folder) - sorter_list = [ - "tridesclous", - ] - print( - f"\n#################################\nINSTALLED SORTERS\n#################################\n" - f"{installed_sorters()}" + study = GroundTruthStudy.create( + study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] ) - study.run_sorters(sorter_list) + # print(study) -@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -def test_extract_sortings(): +def test_GroundTruthStudy(): study = GroundTruthStudy(study_folder) + print(study) - study.copy_sortings() - - for rec_name in study.rec_names: - gt_sorting = study.get_ground_truth(rec_name) - - for rec_name in study.rec_names: - metrics = study.get_metrics(rec_name=rec_name) + study.run_sorters(verbose=True) - snr = study.get_units_snr(rec_name=rec_name) + print(study.sortings) - study.copy_sortings() + print(study.comparisons) + study.run_comparisons() + print(study.comparisons) - run_times = study.aggregate_run_times() + study.extract_waveforms_gt(n_jobs=-1) - study.run_comparisons(exhaustive_gt=True) + study.compute_metrics() - perf = study.aggregate_performance_by_unit() + for key in study.cases: + metrics = study.get_metrics(key) + print(metrics) - count_units = study.aggregate_count_units() - dataframes = study.aggregate_dataframes() - print(dataframes) + study.get_performance_by_unit() + study.get_count_units() if __name__ == "__main__": - # setup_module() - # test_run_study_sorters() - test_extract_sortings() + setup_module() + test_GroundTruthStudy() diff --git a/src/spikeinterface/comparison/tests/test_studytools.py b/src/spikeinterface/comparison/tests/test_studytools.py deleted file mode 100644 index dbc39d5e1d..0000000000 --- a/src/spikeinterface/comparison/tests/test_studytools.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import shutil -from pathlib import Path - -import pytest - -from spikeinterface.extractors import toy_example -from spikeinterface.comparison.studytools import ( - setup_comparison_study, - iter_computed_names, - iter_computed_sorting, - get_rec_names, - get_ground_truths, - get_recordings, -) - -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "comparison" -else: - cache_folder = Path("cache_folder") / "comparison" - - -study_folder = cache_folder / "test_studytools" - - -def setup_module(): - if study_folder.is_dir(): - shutil.rmtree(study_folder) - - -def test_setup_comparison_study(): - rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1) - rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1) - - gt_dict = { - "toy_tetrode": (rec0, gt_sorting0), - "toy_probe32": (rec1, gt_sorting1), - } - setup_comparison_study(study_folder, gt_dict) - - -def test_get_ground_truths(): - names = get_rec_names(study_folder) - d = get_ground_truths(study_folder) - d = get_recordings(study_folder) - - -def test_loops(): - names = list(iter_computed_names(study_folder)) - for rec_name, sorter_name, sorting in iter_computed_sorting(study_folder): - print(rec_name, sorter_name) - print(sorting) - - -if __name__ == "__main__": - setup_module() - test_setup_comparison_study() - test_get_ground_truths() - test_loops() diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 87c0805630..ad31b97d8e 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -45,8 +45,12 @@ def __init__(self, main_ids: Sequence) -> None: self._kwargs = {} # 'main_ids' will either be channel_ids or units_ids - # They is used for properties + # They are used for properties self._main_ids = np.array(main_ids) + if len(self._main_ids) > 0: + assert ( + self._main_ids.dtype.kind in "uiSU" + ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}" # dict at object level self._annotations = {} @@ -57,8 +61,7 @@ def __init__(self, main_ids: Sequence) -> None: # * number of units for sorting self._properties = {} - self._is_dumpable = True - self._is_json_serializable = True + self._serializablility = {"memory": True, "json": True, "pickle": True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -425,14 +428,15 @@ def from_dict(dictionary: dict, base_folder: Optional[Union[Path, str]] = None) extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ - if dictionary["relative_paths"]: + # for pickle dump relative_path was not in the dict, this ensure compatibility + if dictionary.get("relative_paths", False): assert base_folder is not None, "When relative_paths=True, need to provide base_folder" dictionary = _make_paths_absolute(dictionary, base_folder) extractor = _load_extractor_from_dict(dictionary) folder_metadata = dictionary.get("folder_metadata", None) if folder_metadata is not None: folder_metadata = Path(folder_metadata) - if dictionary["relative_paths"]: + if dictionary.get("relative_paths", False): folder_metadata = base_folder / folder_metadata extractor.load_metadata_from_folder(folder_metadata) return extractor @@ -471,24 +475,33 @@ def clone(self) -> "BaseExtractor": clone = BaseExtractor.from_dict(d) return clone - def check_if_dumpable(self): - """Check if the object is dumpable, including nested objects. + def check_serializablility(self, type): + kwargs = self._kwargs + for value in kwargs.values(): + # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + if isinstance(value, BaseExtractor): + if not value.check_serializablility(type=type): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + return self._serializablility[type] + + def check_if_memory_serializable(self): + """ + Check if the object is serializable to memory with pickle, including nested objects. Returns ------- bool - True if the object is dumpable, False otherwise. + True if the object is memory serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_dumpable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_dumpable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_dumpable() for k, v in value.items()]) - return self._is_dumpable + return self.check_serializablility("memory") def check_if_json_serializable(self): """ @@ -499,16 +512,13 @@ def check_if_json_serializable(self): bool True if the object is json serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_json_serializable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_json_serializable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_json_serializable() for k, v in value.items()]) - return self._is_json_serializable + # we keep this for backward compatilibity or not ???? + # is this needed ??? I think no. + return self.check_serializablility("json") + + def check_if_pickle_serializable(self): + # is this needed ??? I think no. + return self.check_serializablility("pickle") @staticmethod def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: @@ -557,7 +567,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) + self.dump_to_pickle(file_path, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -576,7 +586,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_json_serializable(), "The extractor is not json serializable" + assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict if relative_to: @@ -610,18 +620,19 @@ def dump_to_pickle( Parameters ---------- file_path: str - Path of the json file + Path of the pickle file include_properties: bool If True, all properties are dumped folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_dumpable(), "The extractor is not dumpable" + assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, folder_metadata=folder_metadata, + relative_to=None, recursive=False, ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) @@ -653,8 +664,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") - if "warning" in d and "not dumpable" in d["warning"]: - print("The extractor was not dumpable") + if "warning" in d: + print("The extractor was not serializable to file") return None extractor = BaseExtractor.from_dict(d, base_folder=base_folder) return extractor @@ -814,10 +825,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - if self.check_if_json_serializable(): + if self.check_serializablility("json"): self.dump(provenance_file) else: - provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") + provenance_file.write_text( + json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8" + ) self.save_metadata_to_folder(folder) @@ -911,7 +924,7 @@ def save_to_zarr( zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options) - if self.check_if_dumpable(): + if self.check_if_json_serializable(): zarr_root.attrs["provenance"] = check_json(self.to_dict()) else: zarr_root.attrs["provenance"] = None @@ -975,7 +988,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: class_name = None if "kwargs" not in dic: - raise Exception(f"This dict cannot be load into extractor {dic}") + raise Exception(f"This dict cannot be loaded into extractor {dic}") # Create new kwargs to avoid modifying the original dict["kwargs"] new_kwargs = dict() @@ -996,7 +1009,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class" if not _check_same_version(class_name, dic["version"]): warnings.warn( - f"Versions are not the same. This might lead compatibility errors. " + f"Versions are not the same. This might lead to compatibility errors. " f"Using {class_name.split('.')[0]}=={dic['version']} is recommended" ) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index af4970a4ad..2977211c25 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -305,7 +305,8 @@ def get_traces( if not self.has_scaled(): raise ValueError( - "This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" + "This recording does not support return_scaled=True (need gain_to_uV and offset_" + "to_uV properties)" ) else: gains = self.get_property("gain_to_uV") @@ -416,8 +417,8 @@ def set_times(self, times, segment_index=None, with_warning=True): if with_warning: warn( "Setting times with Recording.set_times() is not recommended because " - "times are not always propagated to across preprocessing" - "Use use this carefully!" + "times are not always propagated across preprocessing" + "Use this carefully!" ) def sample_index_to_time(self, sample_ind, segment_index=None): @@ -592,7 +593,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index affde8a75e..d411f38d2a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations from pathlib import Path import numpy as np @@ -19,7 +19,7 @@ class BaseRecordingSnippets(BaseExtractor): has_default_locations = False - def __init__(self, sampling_frequency: float, channel_ids: List, dtype): + def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype): BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = sampling_frequency self._dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 737087abc1..b4e3c11f55 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -1,10 +1,8 @@ from typing import List, Union -from pathlib import Path from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets import numpy as np from warnings import warn -from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes # snippets segments? @@ -139,7 +137,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..2a06a699cb 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True): if check_spike_frames: if has_exceeding_spikes(recording, self): warnings.warn( - "Some spikes are exceeding the recording's duration! " + "Some spikes exceed the recording's duration! " "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` " "Might be necessary for further postprocessing." ) @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting @@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if not concatenated: spikes_ = [] for segment_index in range(self.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes_.append(spikes[s0:s1]) spikes = spikes_ diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 72a95637f6..b45290caa5 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -91,7 +91,7 @@ def __init__( file_path_list = [Path(file_paths)] if t_starts is not None: - assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths" + assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths" t_starts = [float(t_start) for t_start in t_starts] dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index d36e168f8d..8714580821 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments): times_kargs0 = parent_segment0.get_times_kwargs() if times_kargs0["time_vector"] is None: for ps in parent_segments: - assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set" + assert ps.get_times_kwargs()["time_vector"] is None, "All segments should not have times set" else: for ps in parent_segments: assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], ( - "All segment should have the same " "t_start" + "All segments should have the same " "t_start" ) BaseRecordingSegment.__init__(self, **times_kargs0) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index ebd1b7db03..3a21e356a6 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) ), "ChannelSliceRecording: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceRecording : channel_ids not unique" + ), "ChannelSliceRecording : channel_ids are not unique" sampling_frequency = parent_recording.get_sampling_frequency() @@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): ), "ChannelSliceSnippets: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceSnippets : channel_ids not unique" + ), "ChannelSliceSnippets : channel_ids are not unique" sampling_frequency = parent_snippets.get_sampling_frequency() diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 968f27c6ad..b8574c506f 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording): def __init__(self, parent_recording, start_frame=None, end_frame=None): channel_ids = parent_recording.get_channel_ids() - assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment" + assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment" parent_size = parent_recording.get_num_samples(0) if start_frame is None: diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 5da5350f06..ed1391b0e2 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting): def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True): unit_ids = parent_sorting.get_unit_ids() - assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment" + assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment" if start_frame is None: start_frame = 0 @@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = parent_n_samples assert ( end_frame <= parent_n_samples - ), "`end_frame` should be smaller than the sortings total number of samples." + ), "`end_frame` should be smaller than the sortings' total number of samples." assert ( start_frame <= parent_n_samples - ), "`start_frame` should be smaller than the sortings total number of samples." + ), "`start_frame` should be smaller than the sortings' total number of samples." if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting): raise ValueError( "The sorting object has spikes exceeding the recording duration. You have to remove those spikes " @@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = max_spike_time + 1 assert start_frame < end_frame, ( - "`start_frame` should be greater than `end_frame`. " + "`start_frame` should be less than `end_frame`. " "This may be due to start_frame >= max_spike_time, if the end frame " "was not specified explicitly." ) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..dc84d31987 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -123,6 +123,9 @@ def generate_sorting( firing_rates=3.0, empty_units=None, refractory_period_ms=3.0, # in ms + add_spikes_on_borders=False, + num_spikes_per_border=3, + border_size_samples=20, seed=None, ): """ @@ -142,6 +145,12 @@ def generate_sorting( List of units that will have no spikes. (used for testing mainly). refractory_period_ms : float, default: 3.0 The refractory period in ms + add_spikes_on_borders : bool, default: False + If True, spikes will be added close to the borders of the segments. + num_spikes_per_border : int, default: 3 + The number of spikes to add close to the borders of the segments. + border_size_samples : int, default: 20 + The size of the border in samples to add border spikes. seed : int, default: None The random seed @@ -151,11 +160,13 @@ def generate_sorting( The sorting object """ seed = _ensure_seed(seed) + rng = np.random.default_rng(seed) num_segments = len(durations) unit_ids = np.arange(num_units) spikes = [] for segment_index in range(num_segments): + num_samples = int(sampling_frequency * durations[segment_index]) times, labels = synthesize_random_firings( num_units=num_units, sampling_frequency=sampling_frequency, @@ -166,7 +177,7 @@ def generate_sorting( ) if empty_units is not None: - keep = ~np.in1d(labels, empty_units) + keep = ~np.isin(labels, empty_units) times = times[keep] labels = labels[keep] @@ -175,7 +186,23 @@ def generate_sorting( spikes_in_seg["unit_index"] = labels spikes_in_seg["segment_index"] = segment_index spikes.append(spikes_in_seg) + + if add_spikes_on_borders: + spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) + spikes_on_borders["segment_index"] = segment_index + spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True) + # at start + spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers( + 0, border_size_samples, num_spikes_per_border + ) + # at end + spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers( + num_samples - border_size_samples, num_samples, num_spikes_per_border + ) + spikes.append(spikes_on_borders) + spikes = np.concatenate(spikes) + spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) @@ -219,7 +246,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue @@ -596,6 +623,7 @@ def __init__( dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") + assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) @@ -626,6 +654,7 @@ def __init__( "num_channels": num_channels, "durations": durations, "sampling_frequency": sampling_frequency, + "noise_level": noise_level, "dtype": dtype, "seed": seed, "strategy": strategy, @@ -848,13 +877,13 @@ def generate_single_fake_waveform( default_unit_params_range = dict( - alpha=(5_000.0, 15_000.0), + alpha=(6_000.0, 9_000.0), depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), recovery_ms=(1.0, 1.5), positive_amplitude=(0.05, 0.15), smooth_ms=(0.03, 0.07), - decay_power=(1.2, 1.8), + decay_power=(1.4, 1.8), ) @@ -1056,6 +1085,8 @@ def __init__( dtype = parent_recording.dtype if parent_recording is not None else templates.dtype BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + # Important : self._serializablility is not change here because it will depend on the sorting parents itself. + n_units = len(sorting.unit_ids) assert len(templates) == n_units self.spike_vector = sorting.to_spike_vector() @@ -1071,11 +1102,11 @@ def __init__( # handle also upsampling and jitter upsample_factor = templates.shape[3] elif templates.ndim == 5: - # handle also dirft + # handle also drift raise NotImplementedError("Drift will be implented soon...") # upsample_factor = templates.shape[3] else: - raise ValueError("templates have wring dim should 3 or 4") + raise ValueError("templates have wrong dim should 3 or 4") if upsample_factor is not None: assert upsample_vector is not None @@ -1431,5 +1462,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) return recording, sorting diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index e5581c7a67..d039206296 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -96,7 +96,7 @@ def is_set_global_dataset_folder(): ######################################## global global_job_kwargs -global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) +global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global global_job_kwargs_set global_job_kwargs_set = False diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c0ee77d2fd..cf7a67489c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -167,11 +167,11 @@ def ensure_n_jobs(recording, n_jobs=1): print(f"Python {sys.version} does not support parallel processing") n_jobs = 1 - if not recording.check_if_dumpable(): + if not recording.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not dumpable and can't be processed in parallel. " - "You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1." + "Recording is not serializable to memory and can't be processed in parallel. " + "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." ) return n_jobs @@ -380,10 +380,6 @@ def run(self): self.gather_func(res) else: n_jobs = min(self.n_jobs, len(all_chunks)) - ######## Do you want to limit the number of threads per process? - ######## It has to be done to speed up numpy a lot if multicores - ######## Otherwise, np.dot will be slow. How to do that, up to you - ######## This is just a suggestion, but here it adds a dependency # parallel with ProcessPoolExecutor( @@ -436,3 +432,59 @@ def function_wrapper(args): else: with threadpool_limits(limits=max_threads_per_process): return _func(segment_index, start_frame, end_frame, _worker_ctx) + + +# Here some utils copy/paste from DART (Charlie Windolf) + + +class MockFuture: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__(self, f, *args): + self.f = f + self.args = args + + def result(self): + return self.f(*self.args) + + +class MockPoolExecutor: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__( + self, + max_workers=None, + mp_context=None, + initializer=None, + initargs=None, + context=None, + ): + if initializer is not None: + initializer(*initargs) + self.map = map + self.imap = map + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + def submit(self, f, *args): + return MockFuture(f, *args) + + +class MockQueue: + """Another helper class for turning off concurrency when debugging.""" + + def __init__(self): + self.q = [] + self.put = self.q.append + self.get = lambda: self.q.pop(0) + + +def get_poolexecutor(n_jobs): + if n_jobs == 1: + return MockPoolExecutor + else: + return ProcessPoolExecutor diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index b11f40a441..a0ded216d1 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -111,8 +111,7 @@ def __init__(self, recording, peaks): # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(peaks["segment_index"], segment_index) - i1 = np.searchsorted(peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -125,8 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -183,8 +181,7 @@ def __init__( # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(self.peaks["segment_index"], segment_index) - i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -197,8 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -436,6 +432,7 @@ def run_node_pipeline( job_name="pipeline", mp_context=None, gather_mode="memory", + gather_kwargs={}, squeeze_output=True, folder=None, names=None, @@ -452,7 +449,7 @@ def run_node_pipeline( if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": - gather_func = GatherToNpy(folder, names) + gather_func = GatherToNpy(folder, names, **gather_kwargs) else: raise ValueError(f"wrong gather_mode : {gather_mode}") @@ -597,9 +594,9 @@ class GatherToNpy: * create the npy v1.0 header at the end with the correct shape and dtype """ - def __init__(self, folder, names, npy_header_size=1024): + def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.folder = Path(folder) - self.folder.mkdir(parents=True, exist_ok=False) + self.folder.mkdir(parents=True, exist_ok=exist_ok) assert names is not None self.names = names self.npy_header_size = npy_header_size diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 80979ce6c9..69c48356e5 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -27,6 +27,9 @@ def __init__( num_segments = len(file_paths) data = np.load(file_paths[0], mmap_mode="r") + if channel_ids is None: + channel_ids = np.arange(data["snippet"].shape[2]) + BaseSnippets.__init__( self, sampling_frequency, @@ -84,7 +87,7 @@ def write_snippets(snippets, file_paths, dtype=None): arr = np.empty(n, dtype=snippets_t, order="F") arr["frame"] = snippets.get_frames(segment_index=i) arr["snippet"] = snippets.get_snippets(segment_index=i).astype(dtype, copy=False) - + file_paths[i].parent.mkdir(parents=True, exist_ok=True) np.save(file_paths[i], arr) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 97f22615df..3d7ec6cd1a 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -64,7 +64,8 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" t_starts = [float(t_start) for t_start in t_starts] - self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for i, traces in enumerate(traces_list): if t_starts is None: @@ -126,8 +127,10 @@ def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = True - self._is_json_serializable = False + self._serializablility["memory"] = True + self._serializablility["json"] = False + # theorically this should be False but for simplicity make generators simples we still need this. + self._serializablility["pickle"] = True if spikes.size == 0: nseg = 1 @@ -338,8 +341,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): if self.spikes_in_seg is None: # the slicing of segment is done only once the first time # this fasten the constructor a lot - s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") - s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") + s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1]) self.spikes_in_seg = self.spikes[s0:s1] unit_index = self.unit_ids.index(unit_id) @@ -358,8 +360,10 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting" BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = True - self._is_json_serializable = False + + self._serializablility["memory"] = True + self._serializablility["json"] = False + self._serializablility["pickle"] = False self.shm = SharedMemory(shm_name, create=False) self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) @@ -517,8 +521,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore dtype=dtype, ) - self._is_dumpable = False - self._is_json_serializable = False + self._serializablility["memory"] = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for snippets, spikesframes in zip(snippets_list, spikesframes_list): snp_segment = NumpySnippetsSegment(snippets, spikesframes) diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 1ff31127f4..879700cc15 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -181,9 +181,10 @@ def __init__(self, oldapi_recording_extractor): dtype=oldapi_recording_extractor.get_dtype(return_scaled=False), ) - # set _is_dumpable to False to use dumping mechanism of old extractor - self._is_dumpable = False - self._is_json_serializable = False + # set to False to use dumping mechanism of old extractor + self._serializablility["memory"] = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False self.annotate(is_filtered=oldapi_recording_extractor.is_filtered) @@ -268,8 +269,9 @@ def __init__(self, oldapi_sorting_extractor): sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor) self.add_sorting_segment(sorting_segment) - self._is_dumpable = False - self._is_json_serializable = False + self._serializablility["memory"] = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False # add old properties copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index e5901d7ee0..ff9cd99389 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -302,7 +302,7 @@ def get_chunk_with_margin( return traces_chunk, left_margin, right_margin -def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): +def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), flip=False): """ Order channels by depth, by first ordering the x-axis, and then the y-axis. @@ -316,6 +316,9 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default: False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. Returns ------- @@ -341,6 +344,8 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): assert dim < ndim, "Invalid dimensions!" locations_to_sort += (locations[:, dim],) order_f = np.lexsort(locations_to_sort) + if flip: + order_f = order_f[::-1] order_r = np.argsort(order_f, kind="stable") return order_f, order_r diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index f70c45bfe5..85e36cf7a5 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # Return (0 * num_channels) array of correct dtype return self.parent_segments[0].get_traces(0, 0, channel_indices) - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) @@ -469,8 +468,7 @@ def get_unit_spike_train( if end_frame is None: end_frame = self.get_num_samples() - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 4c3680b021..896e3800d7 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .recording_tools import get_channel_distances, get_noise_levels @@ -33,7 +35,9 @@ class ChannelSparsity: """ - Handle channel sparsity for a set of units. + Handle channel sparsity for a set of units. That is, for every unit, + it indicates which channels are used to represent the waveform and the rest + of the non-represented channels are assumed to be zero. Internally, sparsity is stored as a boolean mask. @@ -92,13 +96,21 @@ def __init__(self, mask, unit_ids, channel_ids): assert self.mask.shape[0] == self.unit_ids.shape[0] assert self.mask.shape[1] == self.channel_ids.shape[0] - # some precomputed dict + # Those are computed at first call self._unit_id_to_channel_ids = None self._unit_id_to_channel_indices = None + self.num_channels = self.channel_ids.size + self.num_units = self.unit_ids.size + if self.mask.shape[0]: + self.max_num_active_channels = self.mask.sum(axis=1).max() + else: + # empty sorting without units + self.max_num_active_channels = 0 + def __repr__(self): - ratio = np.mean(self.mask) - txt = f"ChannelSparsity - units: {self.unit_ids.size} - channels: {self.channel_ids.size} - ratio: {ratio:0.2f}" + density = np.mean(self.mask) + txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}" return txt @property @@ -119,6 +131,85 @@ def unit_id_to_channel_indices(self): self._unit_id_to_channel_indices[unit_id] = channel_inds return self._unit_id_to_channel_indices + def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: + """ + Sparsify the waveforms according to a unit_id corresponding sparsity. + + + Given a unit_id, this method selects only the active channels for + that unit and removes the rest. + + Parameters + ---------- + waveforms : np.array + Dense waveforms with shape (num_waveforms, num_samples, num_channels) or a + single dense waveform (template) with shape (num_samples, num_channels). + unit_id : str + The unit_id for which to sparsify the waveform. + + Returns + ------- + sparsified_waveforms : np.array + Sparse waveforms with shape (num_waveforms, num_samples, num_active_channels) + or a single sparsified waveform (template) with shape (num_samples, num_active_channels). + """ + + assert_msg = ( + "Waveforms must be dense to sparsify them. " + f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}" + ) + assert self.are_waveforms_dense(waveforms=waveforms), assert_msg + + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + sparsified_waveforms = waveforms[..., non_zero_indices] + + return sparsified_waveforms + + def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: + """ + Densify sparse waveforms that were sparisified according to a unit's channel sparsity. + + Given a unit_id its sparsified waveform, this method places the waveform back + into its original form within a dense array. + + Parameters + ---------- + waveforms : np.array + The sparsified waveforms array of shape (num_waveforms, num_samples, num_active_channels) or a single + sparse waveform (template) with shape (num_samples, num_active_channels). + unit_id : str + The unit_id that was used to sparsify the waveform. + + Returns + ------- + densified_waveforms : np.array + The densified waveforms array of shape (num_waveforms, num_samples, num_channels) or a single dense + waveform (template) with shape (num_samples, num_channels). + + """ + + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + + assert_msg = ( + "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " + f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels." + ) + assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg + + densified_shape = waveforms.shape[:-1] + (self.num_channels,) + densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) + densified_waveforms[..., non_zero_indices] = waveforms + + return densified_waveforms + + def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: + return waveforms.shape[-1] == self.num_channels + + def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool: + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + return waveforms.shape[-1] == num_active_channels + @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): """ @@ -144,16 +235,16 @@ def to_dict(self): ) @classmethod - def from_dict(cls, d): + def from_dict(cls, dictionary: dict): unit_id_to_channel_ids_corrected = {} - for unit_id in d["unit_ids"]: - if unit_id in d["unit_id_to_channel_ids"]: - unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][unit_id] + for unit_id in dictionary["unit_ids"]: + if unit_id in dictionary["unit_id_to_channel_ids"]: + unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][unit_id] else: - unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][str(unit_id)] - d["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected + unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][str(unit_id)] + dictionary["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected - return cls.from_unit_id_to_channel_ids(**d) + return cls.from_unit_id_to_channel_ids(**dictionary) ## Some convinient function to compute sparsity from several strategy @classmethod diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 95278b76da..b6022e27c0 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np import warnings @@ -5,7 +6,9 @@ from .recording_tools import get_channel_distances, get_noise_levels -def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"): +def get_template_amplitudes( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum" +): """ Get amplitude per channel for each unit. @@ -13,9 +16,9 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index @@ -24,8 +27,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st peak_values: dict Dictionary with unit ids as keys and template amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore @@ -57,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st def get_template_extremum_channel( - waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id" + waveform_extractor, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "extremum", + outputs: "id" | "index" = "id", ): """ Compute the channel with the extremum peak for each unit. @@ -66,12 +72,12 @@ def get_template_extremum_channel( ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index - outputs: str + outputs: "id" | "index", default: "id" * 'id': channel id * 'index': channel index @@ -159,7 +165,7 @@ def get_template_channel_sparsity( get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc) -def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"): +def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. @@ -169,8 +175,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels Returns ------- @@ -203,7 +209,9 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str return shifts -def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"): +def get_template_extremum_amplitude( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index" +): """ Computes amplitudes on the best channel. @@ -211,9 +219,9 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "at_index" Where the amplitude is computed 'extremum': max or min 'at_index': take value at spike index @@ -223,8 +231,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", amplitudes: dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index ea1a9cf0d2..a944be3da0 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -31,39 +31,39 @@ def make_nested_extractors(extractor): ) -def test_check_if_dumpable(): +def test_check_if_memory_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) - # make a list of dumpable objects - extractors_dumpable = make_nested_extractors(test_extractor) - for extractor in extractors_dumpable: - assert extractor.check_if_dumpable() + # make a list of memory serializable objects + extractors_mem_serializable = make_nested_extractors(test_extractor) + for extractor in extractors_mem_serializable: + assert extractor.check_if_memory_serializable() - # make not dumpable - test_extractor._is_dumpable = False - extractors_not_dumpable = make_nested_extractors(test_extractor) - for extractor in extractors_not_dumpable: - assert not extractor.check_if_dumpable() + # make not not memory serilizable + test_extractor._serializablility["memory"] = False + extractors_not_mem_serializable = make_nested_extractors(test_extractor) + for extractor in extractors_not_mem_serializable: + assert not extractor.check_if_memory_serializable() -def test_check_if_json_serializable(): +def test_check_if_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) - # make a list of dumpable objects - test_extractor._is_json_serializable = True + # make a list of json serializable objects + test_extractor._serializablility["json"] = True extractors_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_json_serializable: print(extractor) - assert extractor.check_if_json_serializable() + assert extractor.check_serializablility("json") - # make not dumpable - test_extractor._is_json_serializable = False + # make of not json serializable objects + test_extractor._serializablility["json"] = False extractors_not_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_not_json_serializable: print(extractor) - assert not extractor.check_if_json_serializable() + assert not extractor.check_serializablility("json") if __name__ == "__main__": - test_check_if_dumpable() - test_check_if_json_serializable() + test_check_if_memory_serializable() + test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index a3cd0caa92..223b2a8a3a 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -142,7 +142,6 @@ def test_write_memory_recording(): recording = NoiseGeneratorRecording( num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" ) - # make dumpable recording = recording.save() # write with loop diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9ba5de42d6..9a9c61766f 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -26,15 +26,44 @@ def test_generate_recording(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass def test_generate_sorting(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass +def test_generate_sorting_with_spikes_on_borders(): + num_spikes_on_borders = 10 + border_size_samples = 10 + segment_duration = 10 + for nseg in [1, 2, 3]: + sorting = generate_sorting( + durations=[segment_duration] * nseg, + sampling_frequency=30000, + num_units=10, + add_spikes_on_borders=True, + num_spikes_per_border=num_spikes_on_borders, + border_size_samples=border_size_samples, + ) + # check that segments are correctly sorted + all_spikes = sorting.to_spike_vector() + np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"])) + + spikes = sorting.to_spike_vector(concatenated=False) + # at least num_border spikes at borders for all segments + for spikes_in_segment in spikes: + # check that sample indices are correctly sorted within segments + np.testing.assert_array_equal(spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"])) + num_samples = int(segment_duration * 30000) + assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders + assert ( + np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples) >= num_spikes_on_borders + ) + + def measure_memory_allocation(measure_in_process: bool = True) -> float: """ A local utility to measure memory allocation at a specific point in time. @@ -399,7 +428,7 @@ def test_generate_ground_truth_recording(): if __name__ == "__main__": strategy = "tile_pregenerated" # strategy = "on_the_fly" - test_noise_generator_memory() + # test_noise_generator_memory() # test_noise_generator_under_giga() # test_noise_generator_correct_shape(strategy) # test_noise_generator_consistency_across_calls(strategy, 0, 5) @@ -410,3 +439,4 @@ def test_generate_ground_truth_recording(): # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() + test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 8216a4aae6..d0672405d6 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -37,16 +37,20 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True) + job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict( + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict( + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 7d7af6025b..a904e4dd32 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -36,7 +36,7 @@ def test_ensure_n_jobs(): n_jobs = ensure_n_jobs(recording, n_jobs=1) assert n_jobs == 1 - # dumpable + # check serializable n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1) assert n_jobs > 1 @@ -45,7 +45,7 @@ def test_ensure_chunk_size(): recording = generate_recording(num_channels=2) dtype = recording.get_dtype() assert dtype == "float32" - # make dumpable + # make serializable recording = recording.save() chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2) @@ -90,7 +90,7 @@ def init_func(arg1, arg2, arg3): def test_ChunkRecordingExecutor(): recording = generate_recording(num_channels=2) - # make dumpable + # make serializable recording = recording.save() init_args = "a", 120, "yep" diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index 473648c5ec..1c491bd7a6 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -142,9 +142,11 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract self.extractor_list = extractor_list self.extractor_dict = extractor_dict + BaseExtractor.__init__(self, main_ids=["1", "2"]) # this already the case by default - self._is_dumpable = True - self._is_json_serializable = True + self._serializablility["memory"] = True + self._serializablility["json"] = True + self._serializablility["pickle"] = True self._kwargs = { "attribute": attribute, @@ -195,3 +197,8 @@ def test_encoding_numpy_scalars_within_nested_extractors_list(nested_extractor_l def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_dict): json.dumps(nested_extractor_dict, cls=SIJsonEncoder) + + +if __name__ == "__main__": + nested_extractor = nested_extractor() + test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 6e92d155fe..1d99b192ee 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -138,11 +138,13 @@ def test_order_channels_by_depth(): order_1d, order_r1d = order_channels_by_depth(rec, dimensions="y") order_2d, order_r2d = order_channels_by_depth(rec, dimensions=("x", "y")) locations_rev = locations_copy[order_1d][order_r1d] + order_2d_fliped, order_r2d_fliped = order_channels_by_depth(rec, dimensions=("x", "y"), flip=True) assert np.array_equal(locations[:, 1], locations_copy[order_1d][:, 1]) assert np.array_equal(locations_copy[order_1d][:, 1], locations_copy[order_2d][:, 1]) assert np.array_equal(locations, locations_copy[order_2d]) assert np.array_equal(locations_copy, locations_copy[order_2d][order_r2d]) + assert np.array_equal(order_2d[::-1], order_2d_fliped) if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a6b94c9b84..ac114ac161 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(np.in1d(v, channel_ids)) + assert np.all(np.isin(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids @@ -55,5 +55,93 @@ def test_ChannelSparsity(): assert np.array_equal(sparsity.mask, sparsity4.mask) +def test_sparsify_waveforms(): + seed = 0 + rng = np.random.default_rng(seed=seed) + + num_units = 3 + num_samples = 5 + num_channels = 4 + + is_mask_valid = False + while not is_mask_valid: + sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool") + is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0) + + unit_ids = np.arange(num_units) + channel_ids = np.arange(num_channels) + sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids) + + for unit_id in unit_ids: + waveforms_dense = rng.random(size=(num_units, num_samples, num_channels)) + + # Test are_waveforms_dense + assert sparsity.are_waveforms_dense(waveforms_dense) + + # Test sparsify + waveforms_sparse = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id) + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) + + # Test round-trip (note that this is loosy) + unit_id = unit_ids[unit_id] + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) + assert np.array_equal(waveforms_dense[..., non_zero_indices], waveforms_dense2[..., non_zero_indices]) + + # Test sparsify with one waveform (template) + template_dense = waveforms_dense.mean(axis=0) + template_sparse = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id) + assert template_sparse.shape == (num_samples, num_active_channels) + + # Test round trip with template + template_dense2 = sparsity.densify_waveforms(template_sparse, unit_id=unit_id) + assert np.array_equal(template_dense[..., non_zero_indices], template_dense2[:, non_zero_indices]) + + +def test_densify_waveforms(): + seed = 0 + rng = np.random.default_rng(seed=seed) + + num_units = 3 + num_samples = 5 + num_channels = 4 + + is_mask_valid = False + while not is_mask_valid: + sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool") + is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0) + + unit_ids = np.arange(num_units) + channel_ids = np.arange(num_channels) + sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids) + + for unit_id in unit_ids: + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + waveforms_sparse = rng.random(size=(num_units, num_samples, num_active_channels)) + + # Test are waveforms sparse + assert sparsity.are_waveforms_sparse(waveforms_sparse, unit_id=unit_id) + + # Test densify + waveforms_dense = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) + assert waveforms_dense.shape == (num_units, num_samples, num_channels) + + # Test round-trip + waveforms_sparse2 = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id) + assert np.array_equal(waveforms_sparse, waveforms_sparse2) + + # Test densify with one waveform (template) + template_sparse = waveforms_sparse.mean(axis=0) + template_dense = sparsity.densify_waveforms(template_sparse, unit_id=unit_id) + assert template_dense.shape == (num_samples, num_channels) + + # Test round trip with template + template_sparse2 = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id) + assert np.array_equal(template_sparse, template_sparse2) + + if __name__ == "__main__": test_ChannelSparsity() diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 107ef5f180..204f796c0e 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -6,7 +6,13 @@ import zarr -from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity +from spikeinterface.core import ( + generate_recording, + generate_sorting, + NumpySorting, + ChannelSparsity, + generate_ground_truth_recording, +) from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms from spikeinterface.core.waveform_extractor import precompute_sparsity @@ -309,7 +315,7 @@ def test_recordingless(): recording = recording.save(folder=cache_folder / "recording1") sorting = sorting.save(folder=cache_folder / "sorting1") - # recording and sorting are not dumpable + # recording and sorting are not serializable wf_folder = cache_folder / "wf_recordingless" # save with relative paths @@ -340,6 +346,8 @@ def test_recordingless(): # delete original recording and rely on rec_attributes if platform.system() != "Windows": + # this avoid reference on the folder + del we, recording shutil.rmtree(cache_folder / "recording1") we_loaded = WaveformExtractor.load(wf_folder, with_recording=False) assert not we_loaded.has_recording() @@ -510,10 +518,45 @@ def test_compute_sparsity(): print(sparsity) +def test_non_json_object(): + recording, sorting = generate_ground_truth_recording( + durations=[30, 40], + sampling_frequency=30000.0, + num_channels=32, + num_units=5, + ) + + # recording is not save to keep it in memory + sorting = sorting.save() + + wf_folder = cache_folder / "test_waveform_extractor" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + + we = extract_waveforms( + recording, + sorting, + wf_folder, + mode="folder", + sparsity=None, + sparse=False, + ms_before=1.0, + ms_after=1.6, + max_spikes_per_unit=50, + n_jobs=4, + chunk_size=30000, + progress_bar=True, + ) + + # This used to fail because of json + we = load_waveforms(wf_folder) + + if __name__ == "__main__": - test_WaveformExtractor() + # test_WaveformExtractor() # test_extract_waveforms() - # test_sparsity() # test_portability() - # test_recordingless() + test_recordingless() # test_compute_sparsity() + # test_non_json_object() + test_empty_sorting() diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 32158f00df..4e98864ba9 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None): try: property_dict[prop_name] = np.concatenate((property_dict[prop_name], values)) except Exception as e: - print(f"Skipping property '{prop_name}' for shape inconsistency") + print(f"Skipping property '{prop_name}' due to shape inconsistency") del property_dict[prop_name] break for prop_name, prop_values in property_dict.items(): diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 877c9fb00c..d4ae140b90 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -4,6 +4,7 @@ import shutil from typing import Iterable, Literal, Optional import json +import os import numpy as np from copy import deepcopy @@ -87,6 +88,7 @@ def __init__( self._template_cache = {} self._params = {} self._loaded_extensions = dict() + self._is_read_only = False self.sparsity = sparsity self.folder = folder @@ -103,6 +105,8 @@ def __init__( if (self.folder / "params.json").is_file(): with open(str(self.folder / "params.json"), "r") as f: self._params = json.load(f) + if not os.access(self.folder, os.W_OK): + self._is_read_only = True else: # this is in case of in-memory self.format = "memory" @@ -155,14 +159,28 @@ def load_from_folder( else: rec_attributes["probegroup"] = None else: - try: - recording = load_extractor(folder / "recording.json", base_folder=folder) - rec_attributes = None - except: + recording = None + if (folder / "recording.json").exists(): + try: + recording = load_extractor(folder / "recording.json", base_folder=folder) + except: + pass + elif (folder / "recording.pickle").exists(): + try: + recording = load_extractor(folder / "recording.pickle") + except: + pass + if recording is None: raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument") + rec_attributes = None if sorting is None: - sorting = load_extractor(folder / "sorting.json", base_folder=folder) + if (folder / "sorting.json").exists(): + sorting = load_extractor(folder / "sorting.json", base_folder=folder) + elif (folder / "sorting.pickle").exists(): + sorting = load_extractor(folder / "sorting.pickle") + else: + raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)") # the sparsity is the sparsity of the saved/cached waveforms arrays sparsity_file = folder / "sparsity.json" @@ -267,14 +285,22 @@ def create( else: relative_to = None - if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump(folder / "recording.json", relative_to=relative_to) - if sorting.check_if_json_serializable(): + elif recording.check_serializablility("pickle"): + # In this case we loose the relative_to!! + recording.dump(folder / "recording.pickle") + + if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif sorting.check_serializablility("pickle"): + # In this case we loose the relative_to!! + # TODO later the dump to pickle should dump the dictionary and so relative could be put back + sorting.dump(folder / "sorting.pickle") else: warn( - "Sorting object is not dumpable, which might result in downstream errors for " - "parallel processing. To make the sorting dumpable, use the `sorting.save()` function." + "Sorting object is not serializable to file, which might result in downstream errors for " + "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." ) # dump some attributes of the recording for the mode with_recording=False at next load @@ -399,6 +425,9 @@ def return_scaled(self) -> bool: def dtype(self): return self._params["dtype"] + def is_read_only(self) -> bool: + return self._is_read_only + def has_recording(self) -> bool: return self._recording is not None @@ -516,6 +545,10 @@ def is_extension(self, extension_name) -> bool: """ if self.folder is None: return extension_name in self._loaded_extensions + + if extension_name in self._loaded_extensions: + # extension already loaded in memory + return True else: if self.format == "binary": return (self.folder / extension_name).is_dir() and ( @@ -800,14 +833,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids) else: sparsity = None - we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) - we.set_params(**self._params) + if self.has_recording(): + we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) + else: + we = WaveformExtractor( + recording=None, + sorting=sorting, + folder=None, + sparsity=sparsity, + rec_attributes=self._rec_attributes, + allow_unfiltered=True, + ) + we._params = self._params # copy memory objects if self.has_waveforms(): we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}} for unit_id in unit_ids: - we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] - we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id] + if self.format == "memory": + we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] + we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][ + unit_id + ] + else: + we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id) + we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id) # finally select extensions data for ext_name in self.get_available_extension_names(): @@ -868,14 +917,19 @@ def save( (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") if self.has_recording(): - if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) - if self.sorting.check_if_json_serializable(): + elif self.recording.check_serializablility("pickle"): + self.recording.dump(folder / "recording.pickle") + + if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif self.sorting.check_serializablility("pickle"): + self.sorting.dump(folder / "sorting.pickle") else: warn( - "Sorting object is not dumpable, which might result in downstream errors for " - "parallel processing. To make the sorting dumpable, use the `sorting.save()` function." + "Sorting object is not serializable to file, which might result in downstream errors for " + "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." ) # dump some attributes of the recording for the mode with_recording=False at next load @@ -920,16 +974,16 @@ def save( # write metadata zarr_root.attrs["params"] = check_json(self._params) if self.has_recording(): - if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["recording"] = check_json(rec_dict) - if self.sorting.check_if_json_serializable(): + if self.sorting.check_serializablility("json"): sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["sorting"] = check_json(sort_dict) else: warn( - "Sorting object is not dumpable, which might result in downstream errors for " - "parallel processing. To make the sorting dumpable, use the `sorting.save()` function." + "Sorting object is not json serializable, which might result in downstream errors for " + "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." ) recording_info = zarr_root.create_group("recording_info") recording_info.attrs["recording_attributes"] = check_json(rec_attributes) @@ -1419,13 +1473,13 @@ def extract_waveforms( folder=None, mode="folder", precompute_template=("average",), - ms_before=3.0, - ms_after=4.0, + ms_before=1.0, + ms_after=2.0, max_spikes_per_unit=500, overwrite=False, return_scaled=True, dtype=None, - sparse=False, + sparse=True, sparsity=None, num_spikes_for_sparsity=100, allow_unfiltered=False, @@ -1469,7 +1523,7 @@ def extract_waveforms( If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. dtype: dtype or None Dtype of the output waveforms. If None, the recording dtype is maintained. - sparse: bool (default False) + sparse: bool, default: True If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. @@ -1688,6 +1742,7 @@ def precompute_sparsity( max_spikes_per_unit=num_spikes_for_sparsity, return_scaled=False, allow_unfiltered=allow_unfiltered, + sparse=False, **job_kwargs, ) local_sparsity = compute_sparsity(local_we, **sparse_kwargs) @@ -1740,13 +1795,33 @@ def __init__(self, waveform_extractor): if self.format == "binary": self.extension_folder = self.folder / self.extension_name if not self.extension_folder.is_dir(): - self.extension_folder.mkdir() + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_folder.mkdir() + else: import zarr - zarr_root = zarr.open(self.folder, mode="r+") + mode = "r+" if not self.waveform_extractor.is_read_only() else "r" + zarr_root = zarr.open(self.folder, mode=mode) if self.extension_name not in zarr_root.keys(): - self.extension_group = zarr_root.create_group(self.extension_name) + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_group = zarr_root.create_group(self.extension_name) else: self.extension_group = zarr_root[self.extension_name] else: @@ -1863,6 +1938,9 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): + # Only save if not read only + if self.waveform_extractor.is_read_only(): + return if self.format == "binary": import pandas as pd @@ -1900,7 +1978,9 @@ def _save(self, **kwargs): self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif isinstance(ext_data, pd.DataFrame): ext_data.to_xarray().to_zarr( - store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a" + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", ) self.extension_group[ext_data_name].attrs["dataframe"] = True else: @@ -1952,6 +2032,9 @@ def set_params(self, **params): params = self._set_params(**params) self._params = params + if self.waveform_extractor.is_read_only(): + return + params_to_save = params.copy() if "sparsity" in params and params["sparsity"] is not None: assert isinstance( diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index da8e3d64b6..a2f1296e31 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -344,15 +344,15 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes with the correct segment_index # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -562,8 +562,7 @@ def _init_worker_distribute_single_buffer( # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices @@ -590,8 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -685,8 +685,7 @@ def has_exceeding_spikes(recording, sorting): """ spike_vector = sorting.to_spike_vector() for segment_index in range(recording.get_num_segments()): - start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index) - end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind] if len(spike_vector_seg) > 0: if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..5295cc76d8 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting): ---------- parent_sorting: Recording The sorting object - units_to_merge: list of lists + units_to_merge: list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids: None or list @@ -24,6 +24,7 @@ class MergeUnitsSorting(BaseSorting): Default: 'keep' delta_time_ms: float or None Number of ms to consider for duplicated spikes. None won't check for duplications + Returns ------- sorting: Sorting @@ -33,7 +34,7 @@ class MergeUnitsSorting(BaseSorting): def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): self._parent_sorting = parent_sorting - if not isinstance(units_to_merge[0], list): + if not isinstance(units_to_merge[0], (list, tuple)): # keep backward compatibility : the previous behavior was only one merge units_to_merge = [units_to_merge] @@ -59,7 +60,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +69,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 6adf9effd4..626ea79eb9 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -57,37 +57,47 @@ def apply_sortingview_curation( unit_ids_dtype = sorting.unit_ids.dtype # STEP 1: merge groups + labels_dict = sortingview_curation_dict["labelsByUnit"] if "mergeGroups" in sortingview_curation_dict and not skip_merge: merge_groups = sortingview_curation_dict["mergeGroups"] - for mg in merge_groups: + for merge_group in merge_groups: + # Store labels of units that are about to be merged + labels_to_inherit = [] + for unit in merge_group: + labels_to_inherit.extend(labels_dict.get(str(unit), [])) + labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates + if verbose: - print(f"Merging {mg}") + print(f"Merging {merge_group}") if unit_ids_dtype.kind in ("U", "S"): # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(mg) + new_unit_id = "-".join(merge_group) + curation_sorting.merge(merge_group, new_unit_id=new_unit_id) else: # in this case, the CurationSorting takes care of finding a new unused int - new_unit_id = None - curation_sorting.merge(mg, new_unit_id=new_unit_id) + curation_sorting.merge(merge_group, new_unit_id=None) + new_unit_id = curation_sorting.max_used_id # merged unit id + labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. # For example, the first 3 units could be labeled as "accept". # In this case, the first 3 values of the property "accept" will be True, the rest False - labels_dict = sortingview_curation_dict["labelsByUnit"] - properties = {} - for _, labels in labels_dict.items(): - for label in labels: - if label not in properties: - properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - labels_unit = [] - for unit_label, labels in labels_dict.items(): - if unit_label in str(unit_id): - labels_unit.extend(labels) - for label in labels_unit: - properties[label][u_i] = True + + # Initialize the properties dictionary + properties = { + label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + for labels in labels_dict.values() + for label in labels + } + + # Populate the properties dictionary + for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): + unit_id_str = str(unit_id) + if unit_id_str in labels_dict: + for label in labels_dict[unit_id_str]: + properties[label][unit_index] = True + for prop_name, prop_values in properties.items(): curation_sorting.current_sorting.set_property(prop_name, prop_values) @@ -103,5 +113,4 @@ def apply_sortingview_curation( units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) units_to_remove = np.unique(units_to_remove) curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json new file mode 100644 index 0000000000..48881388bb --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json @@ -0,0 +1,19 @@ +{ + "labelsByUnit": { + "1": [ + "accept" + ], + "2": [ + "artifact" + ], + "12": [ + "artifact" + ] + }, + "mergeGroups": [ + [ + 2, + 12 + ] + ] +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json new file mode 100644 index 0000000000..2047c514ce --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "1": [ + "mua" + ], + "2": [ + "mua" + ], + "3": [ + "reject" + ], + "4": [ + "noise" + ], + "5": [ + "accept" + ], + "6": [ + "accept" + ], + "7": [ + "accept" + ] + }, + "mergeGroups": [ + [ + 1, + 2 + ], + [ + 3, + 4 + ], + [ + 5, + 6 + ] + ] +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json new file mode 100644 index 0000000000..2585b5cc50 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "a": [ + "mua" + ], + "b": [ + "mua" + ], + "c": [ + "reject" + ], + "d": [ + "noise" + ], + "e": [ + "accept" + ], + "f": [ + "accept" + ], + "g": [ + "accept" + ] + }, + "mergeGroups": [ + [ + "a", + "b" + ], + [ + "c", + "d" + ], + [ + "e", + "f" + ] + ] +} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 9177cb5536..ce6c7dd5a6 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -1,8 +1,11 @@ import pytest from pathlib import Path import os +import json +import numpy as np import spikeinterface as si +import spikeinterface.extractors as se from spikeinterface.extractors import read_mearec from spikeinterface import set_global_tmp_folder from spikeinterface.postprocessing import ( @@ -19,7 +22,6 @@ cache_folder = Path("cache_folder") / "curation" parent_folder = Path(__file__).parent - ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_gh_curation(): + """ + Test curation using GitHub URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) - - # from GH # curated link: # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22} gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True) - print(f"From GH: {sorting_curated_gh}") assert len(sorting_curated_gh.unit_ids) == 9 assert "#8-#9" in sorting_curated_gh.unit_ids @@ -78,6 +80,9 @@ def test_gh_curation(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): + """ + Test curation using SHA1 URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) @@ -86,14 +91,14 @@ def test_sha1_curation(): # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22} sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22" sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True) - print(f"From SHA: {sorting_curated_sha1}") + # print(f"From SHA: {sorting_curated_sha1}") assert len(sorting_curated_sha1.unit_ids) == 9 assert "#8-#9" in sorting_curated_sha1.unit_ids assert "accept" in sorting_curated_sha1.get_property_keys() assert "mua" in sorting_curated_sha1.get_property_keys() assert "artifact" in sorting_curated_sha1.get_property_keys() - + unit_ids = sorting_curated_sha1.unit_ids sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"]) sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"]) sorting_curated_sha1_art_mua = apply_sortingview_curation( @@ -105,13 +110,16 @@ def test_sha1_curation(): def test_json_curation(): + """ + Test curation using a JSON file. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" + # print(f"Sorting: {sorting.get_unit_ids()}") sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) - print(f"From JSON: {sorting_curated_json}") assert len(sorting_curated_json.unit_ids) == 9 assert "#8-#9" in sorting_curated_json.unit_ids @@ -131,8 +139,133 @@ def test_json_curation(): assert len(sorting_curated_json_mua1.unit_ids) == 5 +def test_false_positive_curation(): + """ + Test curation for false positives. + """ + # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_units = 20 + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, num_units + 1, size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + # print("Sorting: {}".format(sorting.get_unit_ids())) + + json_file = parent_folder / "sv-sorting-curation-false-positive.json" + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + # print("Curated:", sorting_curated_json.get_unit_ids()) + + # Assertions + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") + assert 21 in sorting_curated_json.unit_ids + + +def test_label_inheritance_int(): + """ + Test curation for label inheritance for integer unit IDs. + """ + # Setup + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + num_units = 7 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7 + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + + json_file = parent_folder / "sv-sorting-curation-int.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) + + # Assertions for merged units + # print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2 + assert not sorting_merge.get_unit_property(unit_id=8, key="reject") + assert not sorting_merge.get_unit_property(unit_id=8, key="noise") + assert not sorting_merge.get_unit_property(unit_id=8, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4 + assert sorting_merge.get_unit_property(unit_id=9, key="reject") + assert sorting_merge.get_unit_property(unit_id=9, key="noise") + assert not sorting_merge.get_unit_property(unit_id=9, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6 + assert not sorting_merge.get_unit_property(unit_id=10, key="reject") + assert not sorting_merge.get_unit_property(unit_id=10, key="noise") + assert sorting_merge.get_unit_property(unit_id=10, key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert 9 not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert 8 not in sorting_include_accept.get_unit_ids() + assert 9 not in sorting_include_accept.get_unit_ids() + assert 10 in sorting_include_accept.get_unit_ids() + + +def test_label_inheritance_str(): + """ + Test curation for label inheritance for string unit IDs. + """ + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + # print(f"Sorting: {sorting.get_unit_ids()}") + + # Apply curation + json_file = parent_folder / "sv-sorting-curation-str.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + + # Assertions for merged units + # print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id="a-b", key="mua") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua") + assert sorting_merge.get_unit_property(unit_id="c-d", key="reject") + assert sorting_merge.get_unit_property(unit_id="c-d", key="noise") + assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise") + assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert "c-d" not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert "a-b" not in sorting_include_accept.get_unit_ids() + assert "c-d" not in sorting_include_accept.get_unit_ids() + assert "e-f" in sorting_include_accept.get_unit_ids() + + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() test_gh_curation() test_json_curation() + test_false_positive_curation() + test_label_inheritance_int() + test_label_inheritance_str() diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 7528f0ebf9..39bb875ea8 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -78,7 +78,7 @@ def test_export_to_phy_by_property(): recording = recording.save(folder=rec_folder) sorting = sorting.save(folder=sort_folder) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_group = compute_sparsity(waveform_extractor, method="by_property", by_property="group") export_to_phy( waveform_extractor, @@ -96,7 +96,7 @@ def test_export_to_phy_by_property(): # Remove one channel recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm) + waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) sparsity_group = compute_sparsity(waveform_extractor_rm, method="by_property", by_property="group") export_to_phy( @@ -130,7 +130,7 @@ def test_export_to_phy_by_sparsity(): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_radius = compute_sparsity(waveform_extractor, method="radius", radius_um=50.0) export_to_phy( waveform_extractor, diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 5615402fdb..31a452f389 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -35,6 +35,7 @@ def export_to_phy( template_mode: str = "median", dtype: Optional[npt.DTypeLike] = None, verbose: bool = True, + use_relative_path: bool = False, **job_kwargs, ): """ @@ -64,6 +65,9 @@ def export_to_phy( Dtype to save binary data verbose: bool If True, output is verbose + use_relative_path : bool, default: False + If True and `copy_binary=True` saves the binary file `dat_path` in the `params.py` relative to `output_folder` (ie `dat_path=r'recording.dat'`). If `copy_binary=False`, then uses a path relative to the `output_folder` + If False, uses an absolute path in the `params.py` (ie `dat_path=r'path/to/the/recording.dat'`) {} """ @@ -90,11 +94,12 @@ def export_to_phy( if waveform_extractor.is_sparse(): used_sparsity = waveform_extractor.sparsity + assert sparsity is None elif sparsity is not None: used_sparsity = sparsity else: used_sparsity = ChannelSparsity.create_dense(waveform_extractor) - # convinient sparsity dict for the 3 cases to retrieve channl_inds + # convenient sparsity dict for the 3 cases to retrieve channl_inds sparse_dict = used_sparsity.unit_id_to_channel_indices empty_flag = False @@ -106,7 +111,7 @@ def export_to_phy( empty_flag = True unit_ids = non_empty_units if empty_flag: - warnings.warn("Empty units have been removed when being exported to Phy") + warnings.warn("Empty units have been removed while exporting to Phy") if len(unit_ids) == 0: raise Exception("No non-empty units in the sorting result, can't save to Phy.") @@ -149,7 +154,15 @@ def export_to_phy( # write params.py with (output_folder / "params.py").open("w") as f: - f.write(f"dat_path = r'{str(rec_path)}'\n") + if use_relative_path: + if copy_binary: + f.write(f"dat_path = r'recording.dat'\n") + elif rec_path == "None": + f.write(f"dat_path = {rec_path}\n") + else: + f.write(f"dat_path = r'{str(Path(rec_path).relative_to(output_folder))}'\n") + else: + f.write(f"dat_path = r'{str(rec_path)}'\n") f.write(f"n_channels_dat = {num_chans}\n") f.write(f"dtype = '{dtype_str}'\n") f.write(f"offset = 0\n") @@ -178,7 +191,11 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") + if waveform_extractor.is_extension("similarity"): + tmc = waveform_extractor.load_extension("similarity") + template_similarity = tmc.get_data() + else: + template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") np.save(str(output_folder / "templates.npy"), templates) np.save(str(output_folder / "template_ind.npy"), templates_ind) diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 02e7d5677d..8b70722652 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = np.in1d(channel_ids, recording_channel_ids) + keep = np.isin(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 3dde998ca1..bd56208ebe 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -6,13 +6,6 @@ from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts from spikeinterface.core.core_tools import define_function_from_class -try: - import mtscomp - - HAVE_MTSCOMP = True -except: - HAVE_MTSCOMP = False - class CompressedBinaryIblExtractor(BaseRecording): """Load IBL data as an extractor object. @@ -42,7 +35,6 @@ class CompressedBinaryIblExtractor(BaseRecording): """ extractor_name = "CompressedBinaryIbl" - installed = HAVE_MTSCOMP mode = "folder" installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" name = "cbin_ibl" @@ -51,7 +43,10 @@ def __init__(self, folder_path, load_sync_channel=False, stream_name="ap"): # this work only for future neo from neo.rawio.spikeglxrawio import read_meta_file, extract_stream_info - assert HAVE_MTSCOMP + try: + import mtscomp + except: + raise ImportError(self.installation_mesg) folder_path = Path(folder_path) # check bands diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index b40b998103..0980e89f1c 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -40,7 +40,6 @@ def __init__( sampling_frequency: float | None = None, session_info_file_path: str | Path | None = None, spikes_matfile_path: str | Path | None = None, - session_info_matfile_path: str | Path | None = None, ): try: from pymatreader import read_mat @@ -67,26 +66,6 @@ def __init__( ) file_path = spikes_matfile_path if file_path is None else file_path - if session_info_matfile_path is not None: - # Raise an error if the warning period has expired - deprecation_issued = datetime.datetime(2023, 4, 1) - deprecation_deadline = deprecation_issued + datetime.timedelta(days=180) - if datetime.datetime.now() > deprecation_deadline: - raise ValueError( - "The session_info_matfile_path argument is no longer supported in. Use session_info_file_path instead." - ) - - # Otherwise, issue a DeprecationWarning - else: - warnings.warn( - "The session_info_matfile_path argument is deprecated and will be removed in six months. " - "Use session_info_file_path instead.", - DeprecationWarning, - ) - session_info_file_path = ( - session_info_matfile_path if session_info_file_path is None else session_info_file_path - ) - self.spikes_cellinfo_path = Path(file_path) self.session_path = self.spikes_cellinfo_path.parent self.session_id = self.spikes_cellinfo_path.stem.split(".")[0] @@ -139,7 +118,7 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].tolist() + unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64) diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index ebff40fae0..235dd705dc 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -11,6 +11,8 @@ NumpySorting, NpySnippetsExtractor, ZarrRecordingExtractor, + read_binary, + read_zarr, ) # sorting/recording/event from neo diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 815c617677..1eb0182318 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -216,10 +216,14 @@ def write_sorting(sorting, save_path, write_primary_channels=False): times_list = [] labels_list = [] primary_channels_list = [] - for unit_id in unit_ids: + for unit_index, unit_id in enumerate(unit_ids): times = sorting.get_unit_spike_train(unit_id=unit_id) times_list.append(times) - labels_list.append(np.ones(times.shape) * unit_id) + # unit id may not be numeric + if unit_id.dtype.kind in "iu": + labels_list.append(np.ones(times.shape, dtype=unit_id.dtype) * unit_id) + else: + labels_list.append(np.ones(times.shape, dtype=int) * unit_index) if write_primary_channels: if "max_channel" in sorting.get_unit_property_names(unit_id): primary_channels_list.append([sorting.get_unit_property(unit_id, "max_channel")] * times.shape[0]) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index a771dc47b1..bb3ae3435a 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -22,6 +22,19 @@ from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts +def drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs): + # Temporary function until neo version 0.13.0 is released + from packaging.version import Version + from importlib.metadata import version as lib_version + + neo_version = lib_version("neo") + # The possibility of ignoring timestamps errors is not present in neo <= 0.12.0 + if Version(neo_version) <= Version("0.12.0"): + neo_kwargs.pop("ignore_timestamps_errors") + + return neo_kwargs + + class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): """ Class for reading data saved by the Open Ephys GUI. @@ -45,14 +58,24 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks (experiments), specify the block index you want to load. all_annotations: bool (default False) Load exhaustively all annotation from neo. + ignore_timestamps_errors: bool (default False) + Ignore the discontinuous timestamps errors in neo. """ mode = "folder" NeoRawIOClass = "OpenEphysRawIO" name = "openephyslegacy" - def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): - neo_kwargs = self.map_to_neo_kwargs(folder_path) + def __init__( + self, + folder_path, + stream_id=None, + stream_name=None, + block_index=None, + all_annotations=False, + ignore_timestamps_errors=False, + ): + neo_kwargs = self.map_to_neo_kwargs(folder_path, ignore_timestamps_errors) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, @@ -64,8 +87,9 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod - def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + def map_to_neo_kwargs(cls, folder_path, ignore_timestamps_errors=False): + neo_kwargs = {"dirname": str(folder_path), "ignore_timestamps_errors": ignore_timestamps_errors} + neo_kwargs = drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs) return neo_kwargs @@ -159,7 +183,10 @@ def __init__( probe = None if probe is not None: - self = self.set_probe(probe, in_place=True) + if probe.shank_ids is not None: + self.set_probe(probe, in_place=True, group_mode="by_shank") + else: + self.set_probe(probe, in_place=True) probe_name = probe.annotations["probe_name"] # load num_channels_per_adc depending on probe type if "2.0" in probe_name: diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index c91aed644d..05aee160f5 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import Optional from pathlib import Path import numpy as np @@ -13,10 +16,14 @@ class BasePhyKilosortSortingExtractor(BaseSorting): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py) - exclude_cluster_groups: list or str, optional + exclude_cluster_groups: list or str, default: None Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). keep_good_only : bool, default: True Whether to only keep good units. + remove_empty_units : bool, default: True + If True, empty units are removed from the sorting extractor. + load_all_cluster_properties : bool, default: True + If True, all cluster properties are loaded from the tsv/csv files. """ extractor_name = "BasePhyKilosortSorting" @@ -29,11 +36,11 @@ class BasePhyKilosortSortingExtractor(BaseSorting): def __init__( self, - folder_path, - exclude_cluster_groups=None, - keep_good_only=False, - remove_empty_units=False, - load_all_cluster_properties=True, + folder_path: Path | str, + exclude_cluster_groups: Optional[list[str] | str] = None, + keep_good_only: bool = False, + remove_empty_units: bool = False, + load_all_cluster_properties: bool = True, ): try: import pandas as pd @@ -195,20 +202,33 @@ class PhySortingExtractor(BasePhyKilosortSortingExtractor): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py). - exclude_cluster_groups: list or str, optional + exclude_cluster_groups: list or str, default: None Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). + load_all_cluster_properties : bool, default: True + If True, all cluster properties are loaded from the tsv/csv files. Returns ------- extractor : PhySortingExtractor - The loaded data. + The loaded Sorting object. """ extractor_name = "PhySorting" name = "phy" - def __init__(self, folder_path, exclude_cluster_groups=None): - BasePhyKilosortSortingExtractor.__init__(self, folder_path, exclude_cluster_groups, keep_good_only=False) + def __init__( + self, + folder_path: Path | str, + exclude_cluster_groups: Optional[list[str] | str] = None, + load_all_cluster_properties: bool = True, + ): + BasePhyKilosortSortingExtractor.__init__( + self, + folder_path, + exclude_cluster_groups, + keep_good_only=False, + load_all_cluster_properties=load_all_cluster_properties, + ) self._kwargs = { "folder_path": str(Path(folder_path).absolute()), @@ -223,8 +243,6 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py). - exclude_cluster_groups: list or str, optional - Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). keep_good_only : bool, default: True Whether to only keep good units. If True, only Kilosort-labeled 'good' units are returned. @@ -234,13 +252,13 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): Returns ------- extractor : KiloSortSortingExtractor - The loaded data. + The loaded Sorting object. """ extractor_name = "KiloSortSorting" name = "kilosort" - def __init__(self, folder_path, keep_good_only=False, remove_empty_units=True): + def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove_empty_units: bool = True): BasePhyKilosortSortingExtractor.__init__( self, folder_path, diff --git a/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py b/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py index 35de8a23e2..c4c8d0c993 100644 --- a/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py +++ b/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py @@ -26,7 +26,7 @@ class CellExplorerSortingTest(SortingCommonTestSuite, unittest.TestCase): ( "cellexplorer/dataset_2/20170504_396um_0um_merge.spikes.cellinfo.mat", { - "session_info_matfile_path": local_folder + "session_info_file_path": local_folder / "cellexplorer/dataset_2/20170504_396um_0um_merge.sessionInfo.mat" }, ), diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 223bda5e30..d7e1ffac01 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -10,7 +10,6 @@ from .template_metrics import ( TemplateMetricsCalculator, compute_template_metrics, - calculate_template_metrics, get_template_metric_names, ) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..7e6c95a875 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -16,6 +16,7 @@ class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ extension_name = "amplitude_scalings" + handle_sparsity = True def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) @@ -47,9 +48,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) @@ -68,7 +69,6 @@ def _run(self, **job_kwargs): delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) return_scaled = we._params["return_scaled"] - unit_ids = we.unit_ids if ms_before is not None: assert ( @@ -82,25 +82,28 @@ def _run(self, **job_kwargs): cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter - if we.is_sparse(): + if we.is_sparse() and self._params["sparsity"] is None: sparsity = we.sparsity - elif self._params["sparsity"] is not None: + elif we.is_sparse() and self._params["sparsity"] is not None: + sparsity = self._params["sparsity"] + # assert provided sparsity is sparser than the one in the waveform extractor + waveform_sparsity = we.sparsity + assert np.all( + np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0 + ), "The provided sparsity needs to be sparser than the one in the waveform extractor!" + elif not we.is_sparse() and self._params["sparsity"] is not None: sparsity = self._params["sparsity"] else: if self._params["max_dense_channels"] is not None: assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) - sparsity_inds = sparsity.unit_id_to_channel_indices - - # easier to use in chunk function as spikes use unit_index instead o id - unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} + sparsity_mask = sparsity.mask all_templates = we.get_all_templates() # precompute segment slice segment_slices = [] for segment_index in range(we.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index) - i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append(slice(i0, i1)) # and run @@ -113,7 +116,7 @@ def _run(self, **job_kwargs): self.spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -262,7 +265,7 @@ def _init_worker_amplitude_scalings( spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -282,7 +285,7 @@ def _init_worker_amplitude_scalings( worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after worker_ctx["return_scaled"] = return_scaled - worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["handle_collisions"] = handle_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples @@ -306,7 +309,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) recording = worker_ctx["recording"] all_templates = worker_ctx["all_templates"] segment_slices = worker_ctx["segment_slices"] - unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"] + sparsity_mask = worker_ctx["sparsity_mask"] nbefore = worker_ctx["nbefore"] cut_out_before = worker_ctx["cut_out_before"] cut_out_after = worker_ctx["cut_out_after"] @@ -317,8 +320,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) if i0 != i1: local_spikes = spikes_in_segment[i0:i1] @@ -335,11 +337,12 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) - i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + i0_margin, i1_margin = np.searchsorted( + spikes_in_segment["sample_index"], [start_frame - left, end_frame + right] + ) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( - local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices + local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask ) else: collisions_local = {} @@ -354,7 +357,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = unit_inds_to_channel_indices[unit_index] + (sparse_indices,) = np.nonzero(sparsity_mask[unit_index]) template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -365,7 +368,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) template = template[cut_out_before - sample_index :] elif sample_index + cut_out_after > end_frame + right: local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - end_frame)] + template = template[: -(sample_index + cut_out_after - (end_frame + right))] else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape @@ -393,7 +396,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ) @@ -410,14 +413,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) ### Collision handling ### -def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): +def _are_unit_indices_overlapping(sparsity_mask, i, j): """ Returns True if the unit indices i and j are overlapping, False otherwise Parameters ---------- - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask i: int The first unit index j: int @@ -428,13 +431,13 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + if np.any(sparsity_mask[i] & sparsity_mask[j]): return True else: return False -def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_mask): """ Finds the collisions between spikes. @@ -446,8 +449,8 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ An array of spikes within the added margin delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask Returns ------- @@ -462,14 +465,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] # find the possible spikes per and post within delta_collision_samples - consecutive_window_pre = np.searchsorted( + consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - spike["sample_index"] - delta_collision_samples, - ) - consecutive_window_post = np.searchsorted( - spikes_w_margin["sample_index"], - spike["sample_index"] + delta_collision_samples, + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples], ) + # exclude the spike itself (it is included in the collision_spikes by construction) pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) @@ -480,7 +480,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ # find the overlapping spikes in space as well for possible_overlapping_spike_index in possible_overlapping_spike_indices: if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, + sparsity_mask, spike["unit_index"], spikes_w_margin[possible_overlapping_spike_index]["unit_index"], ): @@ -501,7 +501,7 @@ def fit_collision( right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ): @@ -528,8 +528,8 @@ def fit_collision( The number of samples before the spike to consider for the fit. all_templates: np.ndarray A numpy array of shape (n_units, n_samples, n_channels) containing the templates. - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices. + sparsity_mask: boolean mask + The sparsity mask cut_out_before: int The number of samples to cut out before the spike. cut_out_after: int @@ -547,14 +547,16 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.array([], dtype="int") + common_sparse_mask = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + mask_i = sparsity_mask[spike["unit_index"]] + common_sparse_mask = np.logical_or(common_sparse_mask, mask_i) + (sparse_indices,) = np.nonzero(common_sparse_mask) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + num_samples_local_waveform = local_waveform.shape[0] y = local_waveform.T.flatten() X = np.zeros((len(y), len(collision))) @@ -567,8 +569,10 @@ def fit_collision( # deal with borders if sample_centered - cut_out_before < 0: full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] - elif sample_centered + cut_out_after > end_frame + right: - full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + elif sample_centered + cut_out_after > num_samples_local_waveform: + full_template[sample_centered - cut_out_before :] = template_cut[ + : -(cut_out_after + sample_centered - num_samples_local_waveform) + ] else: full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 6cd5238abd..6e693635eb 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -137,8 +137,8 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ def compute_correlograms( waveform_or_sorting_extractor, load_if_exists=False, - window_ms: float = 100.0, - bin_ms: float = 5.0, + window_ms: float = 50.0, + bin_ms: float = 1.0, method: str = "auto", ): """Compute auto and cross correlograms. diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..8383dcbb43 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -72,7 +72,7 @@ def _select_extension_data(self, unit_ids): new_extension_data[k] = v return new_extension_data - def get_projections(self, unit_id): + def get_projections(self, unit_id, sparse=False): """ Returns the computed projections for the sampled waveforms of a unit id. @@ -80,13 +80,22 @@ def get_projections(self, unit_id): ---------- unit_id : int or str The unit id to return PCA projections for + sparse: bool, default False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- - proj: np.array - The PCA projections (num_waveforms, num_components, num_channels) + projections: np.array + The PCA projections (num_waveforms, num_components, num_channels). + In case sparsity is used, only the projections on sparse channels are returned. """ - return self._extension_data[f"pca_{unit_id}"] + projections = self._extension_data[f"pca_{unit_id}"] + mode = self._params["mode"] + if mode in ("by_channel_local", "by_channel_global") and sparse: + sparsity = self.get_sparsity() + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] + return projections def get_pca_model(self): """ @@ -134,7 +143,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): all_labels = [] #  can be unit_id or unit_index all_projections = [] for unit_index, unit_id in enumerate(unit_ids): - proj = self.get_projections(unit_id) + proj = self.get_projections(unit_id, sparse=False) if channel_ids is not None: chan_inds = self.waveform_extractor.channel_ids_to_indices(channel_ids) proj = proj[:, :, chan_inds] @@ -151,7 +160,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): return all_labels, all_projections - def project_new(self, new_waveforms, unit_id=None): + def project_new(self, new_waveforms, unit_id=None, sparse=False): """ Projects new waveforms or traces snippets on the PC components. @@ -161,6 +170,8 @@ def project_new(self, new_waveforms, unit_id=None): Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels) unit_id: int or str In case PCA is sparse and mode is by_channel_local, the unit_id of 'new_waveforms' + sparse: bool, default: False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- @@ -211,6 +222,10 @@ def project_new(self, new_waveforms, unit_id=None): wfs_flat = new_waveforms.reshape(new_waveforms.shape[0], -1) projections = pca_model.transform(wfs_flat) + # take care of sparsity (not in case of concatenated) + if mode in ("by_channel_local", "by_channel_global") and sparse: + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections def get_sparsity(self): @@ -600,8 +615,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) if i0 != i1: # protect from spikes on border : spike_time<0 or spike_time>seg_size diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..ccd2121174 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data @@ -73,12 +73,6 @@ def _run(self, **job_kwargs): func = _spike_amplitudes_chunk init_func = _init_worker_spike_amplitudes n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) - if n_jobs != 1: - # TODO: avoid dumping sorting and use spike vector and peak pipeline instead - assert sorting.check_if_dumpable(), ( - "The sorting object is not dumpable and cannot be processed in parallel. You can use the " - "`sorting.save()` function to make it dumpable" - ) init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled) processor = ChunkRecordingExecutor( recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs @@ -218,9 +212,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): d = np.diff(spike_times) assert np.all(d >= 0) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) - + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) n_spikes = i1 - i0 amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype()) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..4cbe4d665e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 681f6f3e84..3f47c505ad 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -4,16 +4,29 @@ 22/04/2020 """ import numpy as np +import warnings +from typing import Optional from copy import deepcopy -from ..core import WaveformExtractor +from ..core import WaveformExtractor, ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.waveform_extractor import BaseWaveformExtractorExtension -import warnings + + +global DEBUG +DEBUG = False + + +def get_single_channel_template_metric_names(): + return deepcopy(list(_single_channel_metric_name_to_func.keys())) + + +def get_multi_channel_template_metric_names(): + return deepcopy(list(_multi_channel_metric_name_to_func.keys())) def get_template_metric_names(): - return deepcopy(list(_metric_name_to_func.keys())) + return get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() class TemplateMetricsCalculator(BaseWaveformExtractorExtension): @@ -26,20 +39,31 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): """ extension_name = "template_metrics" + min_channels_for_multi_channel_warning = 10 - def __init__(self, waveform_extractor): + def __init__(self, waveform_extractor: WaveformExtractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - def _set_params(self, metric_names=None, peak_sign="neg", upsampling_factor=10, sparsity=None, window_slope_ms=0.7): + def _set_params( + self, + metric_names=None, + peak_sign="neg", + upsampling_factor=10, + sparsity=None, + metrics_kwargs=None, + include_multi_channel_metrics=False, + ): if metric_names is None: - metric_names = get_template_metric_names() - + metric_names = get_single_channel_template_metric_names() + if include_multi_channel_metrics: + metric_names += get_multi_channel_template_metric_names() + metrics_kwargs = metrics_kwargs or dict() params = dict( metric_names=[str(name) for name in metric_names], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - window_slope_ms=float(window_slope_ms), + metrics_kwargs=metrics_kwargs, ) return params @@ -60,6 +84,9 @@ def _run(self): unit_ids = self.waveform_extractor.sorting.unit_ids sampling_frequency = self.waveform_extractor.sampling_frequency + metrics_single_channel = [m for m in metric_names if m in get_single_channel_template_metric_names()] + metrics_multi_channel = [m for m in metric_names if m in get_multi_channel_template_metric_names()] + if sparsity is None: extremum_channels_ids = get_template_extremum_channel( self.waveform_extractor, peak_sign=peak_sign, outputs="id" @@ -79,6 +106,8 @@ def _run(self): template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) all_templates = self.waveform_extractor.get_all_templates() + channel_locations = self.waveform_extractor.get_channel_locations() + for unit_index, unit_id in enumerate(unit_ids): template_all_chans = all_templates[unit_index] chan_ids = np.array(extremum_channels_ids[unit_id]) @@ -87,6 +116,7 @@ def _run(self): chan_ind = self.waveform_extractor.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] + # compute single_channel metrics for i, template_single in enumerate(template.T): if sparsity is None: index = unit_id @@ -100,15 +130,50 @@ def _run(self): template_upsampled = template_single sampling_frequency_up = sampling_frequency - for metric_name in metric_names: + trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) + + for metric_name in metrics_single_channel: func = _metric_name_to_func[metric_name] value = func( template_upsampled, sampling_frequency=sampling_frequency_up, - window_ms=self._params["window_slope_ms"], + trough_idx=trough_idx, + peak_idx=peak_idx, + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value + # compute metrics multi_channel + for metric_name in metrics_multi_channel: + # retrieve template (with sparsity if waveform extractor is sparse) + template = self.waveform_extractor.get_template(unit_id=unit_id) + + if template.shape[1] < self.min_channels_for_multi_channel_warning: + warnings.warn( + f"With less than {self.min_channels_for_multi_channel_warning} channels, " + "multi-channel metrics might not be reliable." + ) + if self.waveform_extractor.is_sparse(): + channel_locations_sparse = channel_locations[self.waveform_extractor.sparsity.mask[unit_index]] + else: + channel_locations_sparse = channel_locations + + if upsampling_factor > 1: + assert isinstance(upsampling_factor, (int, np.integer)), "'upsample' must be an integer" + template_upsampled = resample_poly(template, up=upsampling_factor, down=1, axis=0) + sampling_frequency_up = upsampling_factor * sampling_frequency + else: + template_upsampled = template + sampling_frequency_up = sampling_frequency + + func = _metric_name_to_func[metric_name] + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self._params["metrics_kwargs"], + ) + template_metrics.at[index, metric_name] = value self._extension_data["metrics"] = template_metrics def get_data(self): @@ -132,14 +197,31 @@ def get_extension_function(): WaveformExtractor.register_extension(TemplateMetricsCalculator) +_default_function_kwargs = dict( + recovery_window_ms=0.7, + peak_relative_threshold=0.2, + peak_width_ms=0.1, + depth_direction="y", + min_channels_for_velocity=5, + min_r2_velocity=0.5, + exp_peak_function="ptp", + min_r2_exp_decay=0.5, + spread_threshold=0.2, + spread_smooth_um=20, + column_range=None, +) + + def compute_template_metrics( waveform_extractor, - load_if_exists=False, - metric_names=None, - peak_sign="neg", - upsampling_factor=10, - sparsity=None, - window_slope_ms=0.7, + load_if_exists: bool = False, + metric_names: Optional[list[str]] = None, + peak_sign: Optional[str] = "neg", + upsampling_factor: int = 10, + sparsity: Optional[ChannelSparsity] = None, + include_multi_channel_metrics: bool = False, + metrics_kwargs: dict = None, + debug_plots: bool = False, ): """ Compute template metrics including: @@ -148,6 +230,14 @@ def compute_template_metrics( * halfwidth * repolarization_slope * recovery_slope + * num_positive_peaks + * num_negative_peaks + + Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): + * velocity_above + * velocity_below + * exp_decay + * spread Parameters ---------- @@ -157,34 +247,77 @@ def compute_template_metrics( Whether to load precomputed template metrics, if they already exist. metric_names : list, optional List of metrics to compute (see si.postprocessing.get_template_metric_names()), by default None - peak_sign : str, optional - "pos" | "neg", by default 'neg' - upsampling_factor : int, optional - Upsample factor, by default 10 - sparsity: dict or None - Default is sparsity=None and template metric is computed on extremum channel only. - If given, the dictionary should contain a unit ids as keys and a channel id or a list of channel ids as values. - For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. - window_slope_ms: float - Window in ms after the positiv peak to compute slope, by default 0.7 + peak_sign : {"neg", "pos"}, default: "neg" + Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. + upsampling_factor : int, default: 10 + The upsampling factor to upsample the templates + sparsity: ChannelSparsity or None, default: None + If None, template metrics are computed on the extremum channel only. + If sparsity is given, template metrics are computed on all sparse channels of each unit. + For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. + include_multi_channel_metrics: bool, default: False + Whether to compute multi-channel metrics + metrics_kwargs: dict + Additional arguments to pass to the metric functions. Including: + * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 + * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 + * peak_width_ms: the width in samples to detect peaks, default: 0.2 + * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) + * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 + * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 + * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 + * spread_threshold: the threshold to compute the spread, default: 0.2 + * spread_smooth_um: the smoothing in um to compute the spread, default: 20 + * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None + - If None, all channels all channels are considered + - If 0 or 1, only the "column" that includes the max channel is considered + - If > 1, only channels within range (+/-) um from the max channel horizontal position are used Returns ------- - tempalte_metrics : pd.DataFrame + template_metrics : pd.DataFrame Dataframe with the computed template metrics. If 'sparsity' is None, the index is the unit_id. If 'sparsity' is given, the index is a multi-index (unit_id, channel_id) + + Notes + ----- + If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, + so that one metric value will be computed per unit. + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". """ + if debug_plots: + global DEBUG + DEBUG = True if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: tmc = TemplateMetricsCalculator(waveform_extractor) + # For 2D metrics, external sparsity must be None, so that one metric value will be computed per unit. + if include_multi_channel_metrics or ( + metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) + ): + assert sparsity is None, ( + "If multi-channel metrics are computed, sparsity must be None, " + "so that each unit will correspond to 1 row of the output dataframe." + ) + assert ( + waveform_extractor.get_channel_locations().shape[1] == 2 + ), "If multi-channel metrics are computed, channel locations must be 2D." + default_kwargs = _default_function_kwargs.copy() + if metrics_kwargs is None: + metrics_kwargs = default_kwargs + else: + default_kwargs.update(metrics_kwargs) + metrics_kwargs = default_kwargs tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, - window_slope_ms=window_slope_ms, + include_multi_channel_metrics=include_multi_channel_metrics, + metrics_kwargs=metrics_kwargs, ) tmc.run() @@ -197,7 +330,19 @@ def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough (minimum of template) and peak (maximum of template, after trough). - Assumes negative trough and positive peak + Assumes negative trough and positive peak. + + Parameters + ---------- + template: numpy.ndarray + The 1D template waveform + + Returns + ------- + trough_idx: int + The index of the trough + peak_idx: int + The index of the peak """ assert template.ndim == 1 trough_idx = np.argmin(template) @@ -205,41 +350,92 @@ def get_trough_and_peak_idx(template): return trough_idx, peak_idx -def get_peak_to_valley(template, **kwargs): +######################################################################################### +# Single-channel metrics +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ - Time between trough and peak in s + Return the peak to valley duration in seconds of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + ptv: float + The peak to valley duration in seconds """ - sampling_frequency = kwargs["sampling_frequency"] - trough_idx, peak_idx = get_trough_and_peak_idx(template) + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) ptv = (peak_idx - trough_idx) / sampling_frequency return ptv -def get_peak_trough_ratio(template, **kwargs): +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs): """ - Ratio between peak heigth and trough depth + Return the peak to trough ratio of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + ptratio: float + The peak to trough ratio """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - ptratio = template[peak_idx] / template[trough_idx] + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + ptratio = template_single[peak_idx] / template_single[trough_idx] return ptratio -def get_half_width(template, **kwargs): +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ - Width of waveform at its half of amplitude in s + Return the half width of input waveforms in seconds. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + hw: float + The half width in seconds """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - sampling_frequency = kwargs["sampling_frequency"] + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) if peak_idx == 0: return np.nan - trough_val = template[trough_idx] + trough_val = template_single[trough_idx] # threshold is half of peak heigth (assuming baseline is 0) threshold = 0.5 * trough_val - (cpre_idx,) = np.where(template[:trough_idx] < threshold) - (cpost_idx,) = np.where(template[trough_idx:] < threshold) + (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) + (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) if len(cpre_idx) == 0 or len(cpost_idx) == 0: hw = np.nan @@ -254,7 +450,7 @@ def get_half_width(template, **kwargs): return hw -def get_repolarization_slope(template, **kwargs): +def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): """ Return slope of repolarization period between trough and baseline @@ -264,17 +460,25 @@ def get_repolarization_slope(template, **kwargs): Optionally the function returns also the indices per waveform where the potential crosses baseline. - """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - sampling_frequency = kwargs["sampling_frequency"] + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + """ + if trough_idx is None: + trough_idx = get_trough_and_peak_idx(template_single) - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if trough_idx == 0: return np.nan - (rtrn_idx,) = np.nonzero(template[trough_idx:] >= 0) + (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) if len(rtrn_idx) == 0: return np.nan # first time after trough, where template is at baseline @@ -285,11 +489,11 @@ def get_repolarization_slope(template, **kwargs): import scipy.stats - res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template[trough_idx:return_to_base_idx]) + res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template_single[trough_idx:return_to_base_idx]) return res.slope -def get_recovery_slope(template, window_ms=0.7, **kwargs): +def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): """ Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes untill it peaks. The recovery slope is the @@ -299,41 +503,450 @@ def get_recovery_slope(template, window_ms=0.7, **kwargs): Takes a numpy array of waveforms and returns an array with recovery slopes per waveform. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - recovery_window_ms: the window in ms after the peak to compute the recovery_slope """ + import scipy.stats - trough_idx, peak_idx = get_trough_and_peak_idx(template) - sampling_frequency = kwargs["sampling_frequency"] + assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" + recovery_window_ms = kwargs["recovery_window_ms"] + if peak_idx is None: + _, peak_idx = get_trough_and_peak_idx(template_single) - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if peak_idx == 0: return np.nan - max_idx = int(peak_idx + ((window_ms / 1000) * sampling_frequency)) - max_idx = np.min([max_idx, template.shape[0]]) + max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) + max_idx = np.min([max_idx, template_single.shape[0]]) - import scipy.stats - - res = scipy.stats.linregress(times[peak_idx:max_idx], template[peak_idx:max_idx]) + res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) return res.slope -_metric_name_to_func = { +def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): + """ + Count the number of positive peaks in the template. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) + + pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(pos_peaks[0]) + + +def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): + """ + Count the number of negative peaks in the template. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) + + neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(neg_peaks[0]) + + +_single_channel_metric_name_to_func = { "peak_to_valley": get_peak_to_valley, "peak_trough_ratio": get_peak_trough_ratio, "half_width": get_half_width, "repolarization_slope": get_repolarization_slope, "recovery_slope": get_recovery_slope, + "num_positive_peaks": get_num_positive_peaks, + "num_negative_peaks": get_num_negative_peaks, } -# back-compatibility -def calculate_template_metrics(*args, **kwargs): - warnings.warn( - "The 'calculate_template_metrics' function is deprecated. " "Use 'compute_template_metrics' instead", - DeprecationWarning, - stacklevel=2, - ) - return compute_template_metrics(*args, **kwargs) +######################################################################################### +# Multi-channel metrics + + +def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + """ + Transform template anch channel locations based on column range. + """ + column_dim = 0 if depth_direction == "y" else 1 + if column_range is None: + template_column_range = template + channel_locations_column_range = channel_locations + else: + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range + template_column_range = template[:, column_mask] + channel_locations_column_range = channel_locations[column_mask] + return template_column_range, channel_locations_column_range + +def sort_template_and_locations(template, channel_locations, depth_direction="y"): + """ + Sort template and locations. + """ + depth_dim = 1 if depth_direction == "y" else 0 + sort_indices = np.argsort(channel_locations[:, depth_dim]) + return template[:, sort_indices], channel_locations[sort_indices, :] + + +def fit_velocity(peak_times, channel_dist): + """ + Fit velocity from peak times and channel distances using ribust Theilsen estimator. + """ + # from scipy.stats import linregress + # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) + + from sklearn.linear_model import TheilSenRegressor + + theil = TheilSenRegressor() + theil.fit(peak_times.reshape(-1, 1), channel_dist) + slope = theil.coef_[0] + intercept = theil.intercept_ + score = theil.score(peak_times.reshape(-1, 1), channel_dist) + return slope, intercept, score + + +def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): + """ + Compute the velocity above the max channel of the template. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_velocity: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" + + depth_direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_velocity = kwargs["min_r2_velocity"] + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 + max_channel_location = channel_locations[max_channel_idx] + + channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim] + + # we only consider samples forward in time with respect to the max channel + # TODO: not sure + # template_above = template[max_sample_idx:, channels_above] + template_above = template[:, channels_above] + channel_locations_above = channel_locations[channels_above] + + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time + distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) + velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_above,) = np.nonzero(channels_above) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_above else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_above, distances_um_above, "o") + x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + axs[1].plot(x, intercept + x * velocity_above) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" + ) + plt.show() + + if np.sum(channels_above) < min_channels_for_velocity or score < min_r2_velocity: + velocity_above = np.nan + + return velocity_above + + +def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): + """ + Compute the velocity below the max channel of the template. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_velocity: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" + + depth_direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_velocity = kwargs["min_r2_velocity"] + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 + max_channel_location = channel_locations[max_channel_idx] + + channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim] + + # we only consider samples forward in time with respect to the max channel + # template_below = template[max_sample_idx:, channels_below] + template_below = template[:, channels_below] + channel_locations_below = channel_locations[channels_below] + + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time + distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) + velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_below,) = np.nonzero(channels_below) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_below else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_below, distances_um_below, "o") + x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) + axs[1].plot(x, intercept + x * velocity_below) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" + ) + plt.show() + + if np.sum(channels_below) < min_channels_for_velocity or score < min_r2_velocity: + velocity_below = np.nan + + return velocity_below + + +def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): + """ + Compute the exponential decay of the template amplitude over distance. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + - min_r2_exp_decay: the minimum r2 to accept the exp decay fit + """ + from scipy.optimize import curve_fit + from sklearn.metrics import r2_score + + def exp_decay(x, decay, amp0, offset): + return amp0 * np.exp(-decay * x) + offset + + assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" + exp_peak_function = kwargs["exp_peak_function"] + assert "min_r2_exp_decay" in kwargs, "min_r2_exp_decay must be given as kwarg" + min_r2_exp_decay = kwargs["min_r2_exp_decay"] + # exp decay fit + if exp_peak_function == "ptp": + fun = np.ptp + elif exp_peak_function == "min": + fun = np.min + peak_amplitudes = np.abs(fun(template, axis=0)) + max_channel_location = channel_locations[np.argmax(peak_amplitudes)] + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) + # np.float128 avoids overflow error + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.float128) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.float128) + try: + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] + + if r2 < min_r2_exp_decay: + exp_decay_value = np.nan + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + ax.plot(x, exp_decay(x, *popt)) + ax.set_xlabel("Distance from max channel (um)") + ax.set_ylabel("Peak amplitude") + ax.set_title( + f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " + f"R2: {np.round(r2, 4)}" + ) + fig.suptitle("Exp decay") + plt.show() + except: + exp_decay_value = np.nan + + return exp_decay_value + + +def get_spread(template, channel_locations, sampling_frequency, **kwargs): + """ + Compute the spread of the template amplitude over distance. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - spread_threshold: the threshold to compute the spread + - column_range: the range in um in the x-direction to consider channels for velocity + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + depth_direction = kwargs["depth_direction"] + assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" + spread_threshold = kwargs["spread_threshold"] + assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" + spread_smooth_um = kwargs["spread_smooth_um"] + assert "column_range" in kwargs, "column_range must be given as kwarg" + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + MM = np.ptp(template, 0) + MM = MM / np.max(MM) + channel_depths = channel_locations[:, depth_dim] + + if spread_smooth_um is not None and spread_smooth_um > 0: + from scipy.ndimage import gaussian_filter1d + + spread_sigma = spread_smooth_um / np.median(np.diff(np.unique(channel_depths))) + MM = gaussian_filter1d(MM, spread_sigma) + + channel_locations_above_theshold = channel_locations[MM > spread_threshold] + channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim] + spread = np.ptp(channel_depth_above_theshold) + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + axs[0].imshow( + template.T, + aspect="auto", + origin="lower", + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], + ) + axs[1].plot(channel_depths, MM, "o-") + axs[1].axhline(spread_threshold, ls="--", color="r") + axs[1].set_xlabel("Depth (um)") + axs[1].set_ylabel("Amplitude") + axs[1].set_title(f"Spread: {np.round(spread, 3)} um") + fig.suptitle("Spread") + plt.show() + + return spread + + +_multi_channel_metric_name_to_func = { + "velocity_above": get_velocity_above, + "velocity_below": get_velocity_below, + "exp_decay": get_exp_decay, + "spread": get_spread, +} -calculate_template_metrics.__doc__ = compute_template_metrics.__doc__ +_metric_name_to_func = {**_single_channel_metric_name_to_func, **_multi_channel_metric_name_to_func} diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..50e2ecdb57 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -2,10 +2,11 @@ import numpy as np import pandas as pd import shutil +import platform from pathlib import Path -from spikeinterface import extract_waveforms, load_extractor, compute_sparsity -from spikeinterface.extractors import toy_example +from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity +from spikeinterface.core.generate import generate_ground_truth_recording if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "postprocessing" @@ -26,7 +27,18 @@ def setUp(self): self.cache_folder = cache_folder # 1-segment - recording, sorting = toy_example(num_segments=1, num_units=10, num_channels=12) + recording, sorting = generate_ground_truth_recording( + durations=[10], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) + + # add gains and offsets and save gain = 0.1 recording.set_channel_gains(gain) recording.set_channel_offsets(0) @@ -45,6 +57,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -53,7 +66,16 @@ def setUp(self): self.sparsity1 = compute_sparsity(we1, method="radius", radius_um=50) # 2-segments - recording, sorting = toy_example(num_segments=2, num_units=10) + recording, sorting = generate_ground_truth_recording( + durations=[10, 5], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) recording.set_channel_gains(gain) recording.set_channel_offsets(0) if (cache_folder / "toy_rec_2seg").is_dir(): @@ -71,16 +93,28 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, ) self.we2 = we2 + + # make we read-only + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) + # change permissions (R+X) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) + self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( recording, sorting, mode="memory", + sparse=False, ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, @@ -97,6 +131,12 @@ def setUp(self): folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True ) + def tearDown(self): + # allow pytest to delete RO folder + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) + def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: extension_function_kwargs_list = [dict()] @@ -177,3 +217,11 @@ def test_extension(self): assert ext_data_mem.equals(ext_data_zarr) else: print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") + + # read-only - Extension is memory only + if platform.system() != "Windows": + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 5d64525b52..49591d9b89 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -86,14 +86,18 @@ def test_sparse(self): pc.set_params(n_components=5, mode=mode, sparsity=sparsity) pc.run() for i, unit_id in enumerate(unit_ids): - proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, 4) + proj_sparse = pc.get_projections(unit_id, sparse=True) + assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_dense = pc.get_projections(unit_id, sparse=False) + assert proj_dense.shape[1:] == (5, num_channels) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) - new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, 4) + new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) + assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) + assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) if DEBUG: import matplotlib.pyplot as plt @@ -197,8 +201,8 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUp() - test.test_extension() - test.test_shapes() - test.test_compute_for_all_spikes() + # test.test_extension() + # test.test_shapes() + # test.test_compute_for_all_spikes() test.test_sparse() - test.test_project_new() + # test.test_project_new() diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 9895e2ec4c..a27ccc77f8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -17,9 +17,13 @@ def test_sparse_metrics(self): tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) print(tm_sparse) + def test_multi_channel_metrics(self): + tm_multi = self.extension_class.get_extension_function()(self.we1, include_multi_channel_metrics=True) + print(tm_multi) + if __name__ == "__main__": test = TemplateMetricsExtensionTest() test.setUp() - test.test_extension() - test.test_sparse_metrics() + # test.test_extension() + test.test_multi_channel_metrics() diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 740fdd234b..48ceb34a4e 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -96,7 +96,7 @@ def get_extension_function(): def compute_unit_locations( - waveform_extractor, load_if_exists=False, method="center_of_mass", outputs="numpy", **method_kwargs + waveform_extractor, load_if_exists=False, method="monopolar_triangulation", outputs="numpy", **method_kwargs ): """ Localize units in 2D or 3D with several methods given the template. @@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): + import sklearn.metrics + x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights( template_positions[:, 0] = all_x.flatten() template_positions[:, 1] = all_y.flatten() - import sklearn - # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) nearest_template_mask = dist < radius_um diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index a2349c1ee9..cc18d51d2e 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -97,7 +97,7 @@ def __init__( chunk_size=500, seed=0, ): - assert direction in ("upper", "lower", "both") + assert direction in ("upper", "lower", "both"), "'direction' must be 'upper', 'lower', or 'both'" if fill_value is None or quantile_threshold is not None: random_data = get_random_data_chunks( diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d2ac227217..6d6ce256de 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -83,7 +83,7 @@ def __init__( ref_channel_ids = np.asarray(ref_channel_ids) assert np.all( [ch in recording.get_channel_ids() for ch in ref_channel_ids] - ), "Some wrong 'ref_channel_ids'!" + ), "Some 'ref_channel_ids' are wrong!" elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 0b8d8a730b..55e34ba5dd 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -18,13 +18,18 @@ class DepthOrderRecording(ChannelSliceRecording): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default: False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. """ name = "depth_order" installed = True - def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): - order_f, order_r = order_channels_by_depth(parent_recording, channel_ids=channel_ids, dimensions=dimensions) + def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y"), flip=False): + order_f, order_r = order_channels_by_depth( + parent_recording, channel_ids=channel_ids, dimensions=dimensions, flip=flip + ) reordered_channel_ids = parent_recording.channel_ids[order_f] ChannelSliceRecording.__init__( self, @@ -35,6 +40,7 @@ def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): parent_recording=parent_recording, channel_ids=channel_ids, dimensions=dimensions, + flip=flip, ) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 0f4800c6e8..e6e2836a35 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -18,7 +18,7 @@ def detect_bad_channels( nyquist_threshold=0.8, direction="y", chunk_duration_s=0.3, - num_random_chunks=10, + num_random_chunks=100, welch_window_ms=10.0, highpass_filter_cutoff=300, neighborhood_r2_threshold=0.9, @@ -81,9 +81,10 @@ def detect_bad_channels( highpass_filter_cutoff : float If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300 chunk_duration_s : float - Duration of each chunk, by default 0.3 + Duration of each chunk, by default 0.5 num_random_chunks : int - Number of random chunks, by default 10 + Number of random chunks, by default 100 + Having many chunks is important for reproducibility. welch_window_ms : float Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms neighborhood_r2_threshold : float, default 0.95 @@ -174,20 +175,18 @@ def detect_bad_channels( channel_locations = recording.get_channel_locations() dim = ["x", "y", "z"].index(direction) assert dim < channel_locations.shape[1], f"Direction {direction} is wrong" - locs_depth = channel_locations[:, dim] - if np.array_equal(np.sort(locs_depth), locs_depth): + order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) + if np.all(np.diff(order_f) == 1): + # already ordered order_f = None order_r = None - else: - # sort by x, y to avoid ambiguity - order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) # Create empty channel labels and fill with bad-channel detection estimate for each chunk chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8) for i, random_chunk in enumerate(random_data): - random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk - chunk_channel_labels[:, i] = detect_bad_channels_ibl( + random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk + chunk_labels = detect_bad_channels_ibl( raw=random_chunk_sorted, fs=recording.sampling_frequency, psd_hf_threshold=psd_hf_threshold, @@ -198,11 +197,10 @@ def detect_bad_channels( nyquist_threshold=nyquist_threshold, welch_window_ms=welch_window_ms, ) + chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output. mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False) - if order_r is not None: - mode_channel_labels = mode_channel_labels[order_r] (bad_inds,) = np.where(mode_channel_labels != 0) bad_channel_ids = recording.channel_ids[bad_inds] @@ -213,9 +211,9 @@ def detect_bad_channels( if bad_channel_ids.size > recording.get_num_channels() / 3: warnings.warn( - "Over 1/3 of channels are detected as bad. In the precense of a high" + "Over 1/3 of channels are detected as bad. In the presence of a high" "number of dead / noisy channels, bad channel detection may fail " - "(erroneously label good channels as dead)." + "(good channels may be erroneously labeled as dead)." ) elif method == "neighborhood_r2": diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 51c1fb4ad6..b31088edf7 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -71,10 +71,10 @@ def __init__( ): import scipy.signal - assert filter_mode in ("sos", "ba") + assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" fs = recording.get_sampling_frequency() if coeff is None: - assert btype in ("bandpass", "highpass") + assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" # coefficient # self.coeff is 'sos' or 'ab' style filter_coeff = scipy.signal.iirfilter( @@ -258,7 +258,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): if dtype.kind == "u": raise TypeError( "The notch filter only supports signed types. Use the 'dtype' argument" - "to specify a signed type (e.g. 'int16', 'float32'" + "to specify a signed type (e.g. 'int16', 'float32')" ) BasePreprocessor.__init__(self, recording, dtype=dtype) diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 790279d647..d3a08297c6 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -50,9 +50,9 @@ def __init__( margin_ms=5.0, ): assert HAVE_PYOPENCL, "You need to install pyopencl (and GPU driver!!)" - - assert btype in ("bandpass", "lowpass", "highpass", "bandstop") - assert filter_mode in ("sos",) + btype_modes = ("bandpass", "lowpass", "highpass", "bandstop") + assert btype in btype_modes, f"'btype' must be in {btype_modes}" + assert filter_mode in ("sos",), "'filter_mode' must be 'sos'" # coefficient sf = recording.get_sampling_frequency() @@ -96,8 +96,8 @@ def __init__(self, parent_recording_segment, executor, margin): self.margin = margin def get_traces(self, start_frame, end_frame, channel_indices): - assert start_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" - assert end_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" + assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" + assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" chunk_size = end_frame - start_frame if chunk_size != self.executor.chunk_size: @@ -157,7 +157,7 @@ def process(self, traces): if traces.shape[0] != self.full_size: if self.full_size is not None: - print(f"Warning : chunk_size have change {self.chunk_size} {traces.shape[0]}, need recompile CL!!!") + print(f"Warning : chunk_size has changed {self.chunk_size} {traces.shape[0]}, need to recompile CL!!!") self.create_buffers_and_compile() event = pyopencl.enqueue_copy(self.queue, self.input_cl, traces) diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index aa98410568..4df4a409bc 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -212,7 +212,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces * self.taper[np.newaxis, :] # apply actual HP filter - import scipy + import scipy.signal traces = scipy.signal.sosfiltfilt(self.sos_filter, traces, axis=1) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index e634d55e7f..95ecd0fe52 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index e2ef6e6794..6ab1a9afce 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,7 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json") np.save(folder / "peaks.npy", peaks) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 7d43982853..bd53866b6a 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -68,7 +68,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("pool_channel", "by_channel") + assert mode in ("pool_channel", "by_channel"), "'mode' must be 'pool_channel' or 'by_channel'" random_data = get_random_data_chunks(recording, **random_chunk_kwargs) @@ -260,7 +260,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("median+mad", "mean+std") + assert mode in ("median+mad", "mean+std"), "'mode' must be 'median+mad' or 'mean+std'" # fix dtype dtype_ = fix_dtype(recording, dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 9c8b2589a0..bdba55038d 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -42,7 +42,9 @@ def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=Non assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" sample_shifts = recording.get_property("inter_sample_shift") else: - assert len(inter_sample_shift) == recording.get_num_channels(), "sample " + assert ( + len(inter_sample_shift) == recording.get_num_channels() + ), "the 'inter_sample_shift' must be same size at the num_channels " sample_shifts = np.asarray(inter_sample_shift) margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 3148539165..1eafa48a0b 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -107,8 +107,6 @@ def __init__( time_jitter=0, waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"}, ): - import scipy.interpolate - available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -165,7 +163,9 @@ def __init__( for l in np.unique(labels): assert l in artifacts.keys(), f"Artefacts are provided but label {l} has no value!" else: - assert "ms_before" != None and "ms_after" != None, f"ms_before/after should not be None for mode {mode}" + assert ( + ms_before is not None and ms_after is not None + ), f"ms_before/after should not be None for mode {mode}" sorting = NumpySorting.from_times_labels(list_triggers, list_labels, recording.get_sampling_frequency()) sorting = sorting.save() waveforms_kwargs.update({"ms_before": ms_before, "ms_after": ms_after}) @@ -234,8 +234,6 @@ def __init__( time_pad, sparsity, ): - import scipy.interpolate - BasePreprocessorSegment.__init__(self, parent_recording_segment) self.triggers = np.asarray(triggers, dtype="int64") @@ -283,6 +281,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): elif trig + pad[1] >= end_frame - start_frame: traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: + import scipy.interpolate + for trig in triggers: if pad is None: pre_data_end_idx = trig - 1 diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..d3f875959e 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -499,9 +499,8 @@ def compute_sliding_rp_violations( ) -def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): - """ - Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of +def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): + """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of "synchrony_size" spikes at the exact same sample index. Parameters @@ -510,6 +509,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k The waveform extractor object. synchrony_sizes : list or tuple, default: (2, 4, 8) The synchrony sizes to compute. + unit_ids : list or None, default: None + List of unit ids to compute the synchrony metrics. If None, all units are used. Returns ------- @@ -522,16 +523,20 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k Based on concepts described in [Gruen]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - assert np.all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1" + assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit() sorting = waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) + if unit_ids is None: + unit_ids = sorting.unit_ids + # Pre-allocate synchrony counts synchrony_counts = {} for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) + all_unit_ids = list(sorting.unit_ids) for segment_index in range(sorting.get_num_segments()): spikes_in_segment = spikes[segment_index] @@ -539,20 +544,21 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True) # add counts for this segment - for unit_index in np.arange(len(sorting.unit_ids)): + for unit_id in unit_ids: + unit_index = all_unit_ids.index(unit_id) spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) # add counts for this segment synchrony_metrics_dict = { f"sync_spike_{synchrony_size}": { - unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id] - for unit_index, unit_id in enumerate(sorting.unit_ids) + unit_id: synchrony_counts[synchrony_size][all_unit_ids.index(unit_id)] / spike_counts[unit_id] + for unit_id in unit_ids } for synchrony_size in synchrony_sizes } @@ -563,7 +569,181 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k return synchrony_metrics -_default_params["synchrony_metrics"] = dict(synchrony_sizes=(0, 2, 4)) +_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) + + +def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): + """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution + computed in non-overlapping time bins. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + bin_size_s : float, default: 5 + The size of the bin in seconds. + percentiles : tuple, default: (5, 95) + The percentiles to compute. + unit_ids : list or None + List of unit ids to compute the firing range. If None, all units are used. + + Returns + ------- + firing_ranges : dict + The firing range for each unit. + + Notes + ----- + Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino. + """ + sampling_frequency = waveform_extractor.sampling_frequency + bin_size_samples = int(bin_size_s * sampling_frequency) + sorting = waveform_extractor.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + + if all( + [ + waveform_extractor.get_num_samples(segment_index) < bin_size_samples + for segment_index in range(waveform_extractor.get_num_segments()) + ] + ): + warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. Firing ranges are set to NaN.") + return {unit_id: np.nan for unit_id in unit_ids} + + # for each segment, we compute the firing rate histogram and we concatenate them + firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} + for segment_index in range(waveform_extractor.get_num_segments()): + num_samples = waveform_extractor.get_num_samples(segment_index) + edges = np.arange(0, num_samples + 1, bin_size_samples) + + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_counts, _ = np.histogram(spike_times, bins=edges) + firing_rates = spike_counts / bin_size_s + firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + + # finally we compute the percentiles + firing_ranges = {} + for unit_id in unit_ids: + firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( + firing_rate_histograms[unit_id], percentiles[0] + ) + + return firing_ranges + + +_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95)) + + +def compute_amplitude_cv_metrics( + waveform_extractor, + average_num_spikes_per_bin=50, + percentiles=(5, 95), + min_num_bins=10, + amplitude_extension="spike_amplitudes", + unit_ids=None, +): + """Calculate coefficient of variation of spike amplitudes within defined temporal bins. + From the distribution of coefficient of variations, both the median and the "range" (the distance between the + percentiles defined by `percentiles` parameter) are returned. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + average_num_spikes_per_bin : int, default: 50 + The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate + of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is + 100, then the temporal bin size will be 100/10 Hz = 10 s. + min_num_bins : int, default: 10 + The minimum number of bins to compute the median and range. If the number of bins is less than this then + the median and range are set to NaN. + amplitude_extension : str, default: 'spike_amplitudes' + The name of the extension to load the amplitudes from. 'spike_amplitudes' or 'amplitude_scalings'. + unit_ids : list or None + List of unit ids to compute the amplitude spread. If None, all units are used. + + Returns + ------- + amplitude_cv_median : dict + The median of the CV + amplitude_cv_range : dict + The range of the CV, computed as the distance between the percentiles. + + Notes + ----- + Designed by Simon Musall and Alessio Buccino. + """ + res = namedtuple("amplitude_cv", ["amplitude_cv_median", "amplitude_cv_range"]) + assert amplitude_extension in ( + "spike_amplitudes", + "amplitude_scalings", + ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" + sorting = waveform_extractor.sorting + total_duration = waveform_extractor.get_total_duration() + spikes = sorting.to_spike_vector() + num_spikes = sorting.count_num_spikes_per_unit() + if unit_ids is None: + unit_ids = sorting.unit_ids + + if waveform_extractor.is_extension(amplitude_extension): + sac = waveform_extractor.load_extension(amplitude_extension) + amps = sac.get_data(outputs="concatenated") + if amplitude_extension == "spike_amplitudes": + amps = np.concatenate(amps) + else: + warnings.warn("") + empty_dict = {unit_id: np.nan for unit_id in unit_ids} + return empty_dict + + # precompute segment slice + segment_slices = [] + for segment_index in range(waveform_extractor.get_num_segments()): + i0 = np.searchsorted(spikes["segment_index"], segment_index) + i1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append(slice(i0, i1)) + + all_unit_ids = list(sorting.unit_ids) + amplitude_cv_medians, amplitude_cv_ranges = {}, {} + for unit_id in unit_ids: + firing_rate = num_spikes[unit_id] / total_duration + temporal_bin_size_samples = int( + (average_num_spikes_per_bin / firing_rate) * waveform_extractor.sampling_frequency + ) + + amp_spreads = [] + # bins and amplitude means are computed for each segment + for segment_index in range(waveform_extractor.get_num_segments()): + sample_bin_edges = np.arange( + 0, waveform_extractor.get_num_samples(segment_index) + 1, temporal_bin_size_samples + ) + spikes_in_segment = spikes[segment_slices[segment_index]] + amps_in_segment = amps[segment_slices[segment_index]] + unit_mask = spikes_in_segment["unit_index"] == all_unit_ids.index(unit_id) + spike_indices_unit = spikes_in_segment["sample_index"][unit_mask] + amps_unit = amps_in_segment[unit_mask] + amp_mean = np.abs(np.mean(amps_unit)) + for t0, t1 in zip(sample_bin_edges[:-1], sample_bin_edges[1:]): + i0 = np.searchsorted(spike_indices_unit, t0) + i1 = np.searchsorted(spike_indices_unit, t1) + amp_spreads.append(np.std(amps_unit[i0:i1]) / amp_mean) + + if len(amp_spreads) < min_num_bins: + amplitude_cv_medians[unit_id] = np.nan + amplitude_cv_ranges[unit_id] = np.nan + else: + amplitude_cv_medians[unit_id] = np.median(amp_spreads) + amplitude_cv_ranges[unit_id] = np.percentile(amp_spreads, percentiles[1]) - np.percentile( + amp_spreads, percentiles[0] + ) + + return res(amplitude_cv_medians, amplitude_cv_ranges) + + +_default_params["amplitude_cv"] = dict( + average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes" +) def compute_amplitude_cutoffs( @@ -848,16 +1028,14 @@ def compute_drift_metrics( spike_vector = sorting.to_spike_vector() # retrieve spikes in segment - i0 = np.searchsorted(spike_vector["segment_index"], segment_index) - i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spikes_in_segment = spike_vector[i0:i1] spike_locations_in_segment = spike_locations[i0:i1] # compute median positions (if less than min_spikes_per_interval, median position is 0) median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 59000211d4..ed06f7d738 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] + pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) + :, :, np.isin(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) + :, :, np.isin(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 90dbb47a3a..97f14ec6f4 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -12,6 +12,8 @@ compute_amplitude_medians, compute_drift_metrics, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) from .pca_metrics import ( @@ -40,6 +42,8 @@ "sliding_rp_violation": compute_sliding_rp_violations, "amplitude_cutoff": compute_amplitude_cutoffs, "amplitude_median": compute_amplitude_medians, + "amplitude_cv": compute_amplitude_cv_metrics, "synchrony": compute_synchrony_metrics, + "firing_range": compute_firing_ranges, "drift": compute_drift_metrics, } diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index d927d64c4f..8a32c4cee8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -12,6 +12,7 @@ compute_principal_components, compute_spike_locations, compute_spike_amplitudes, + compute_amplitude_scalings, ) from spikeinterface.qualitymetrics import ( @@ -31,6 +32,8 @@ compute_drift_metrics, compute_amplitude_medians, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) @@ -212,6 +215,16 @@ def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) +def test_calculate_firing_range(waveform_extractor_simple): + we = waveform_extractor_simple + firing_ranges = compute_firing_ranges(we) + print(firing_ranges) + + with pytest.warns(UserWarning) as w: + firing_ranges_nan = compute_firing_ranges(we, bin_size_s=we.get_total_duration() + 1) + assert np.all([np.isnan(f) for f in firing_ranges_nan.values()]) + + def test_calculate_amplitude_cutoff(waveform_extractor_simple): we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) @@ -234,6 +247,24 @@ def test_calculate_amplitude_median(waveform_extractor_simple): # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) +def test_calculate_amplitude_cv_metrics(waveform_extractor_simple): + we = waveform_extractor_simple + spike_amps = compute_spike_amplitudes(we) + amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20) + print(amp_cv_median) + print(amp_cv_range) + + amps_scalings = compute_amplitude_scalings(we) + amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( + we, + average_num_spikes_per_bin=20, + amplitude_extension="amplitude_scalings", + min_num_bins=5, + ) + print(amp_cv_median_scalings) + print(amp_cv_range_scalings) + + def test_calculate_snrs(waveform_extractor_simple): we = waveform_extractor_simple snrs = compute_snrs(we) @@ -351,11 +382,13 @@ def test_calculate_drift_metrics(waveform_extractor_simple): if __name__ == "__main__": sim_data = _simulated_data() we = _waveform_extractor_simple() - we_violations = _waveform_extractor_violations(sim_data) + # we_violations = _waveform_extractor_violations(sim_data) # test_calculate_amplitude_cutoff(we) # test_calculate_presence_ratio(we) # test_calculate_amplitude_median(we) # test_calculate_isi_violations(we) # test_calculate_sliding_rp_violations(we) # test_calculate_drift_metrics(we) - test_synchrony_metrics(we) + # test_synchrony_metrics(we) + test_calculate_firing_range(we) + # test_calculate_amplitude_cv_metrics(we) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4fa65993d1..977beca210 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -261,7 +261,8 @@ def test_nn_metrics(self): we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 ) for metric_name in metrics.columns: - assert np.allclose(metrics[metric_name], metrics_par[metric_name]) + # NaNs are skipped + assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) def test_recordingless(self): we = self.we_long @@ -305,7 +306,7 @@ def test_empty_units(self): test.setUp() # test.test_drift_metrics() # test.test_extension() - # test.test_nn_metrics() + test.test_nn_metrics() # test.test_peak_sign() # test.test_empty_units() - test.test_recordingless() + # test.test_recordingless() diff --git a/src/spikeinterface/sorters/__init__.py b/src/spikeinterface/sorters/__init__.py index a0d437559d..ba663327e8 100644 --- a/src/spikeinterface/sorters/__init__.py +++ b/src/spikeinterface/sorters/__init__.py @@ -1,11 +1,4 @@ from .basesorter import BaseSorter from .sorterlist import * from .runsorter import * - -from .launcher import ( - run_sorters, - run_sorter_by_property, - collect_sorting_outputs, - iter_working_folder, - iter_sorting_output, -) +from .launcher import run_sorter_jobs, run_sorters, run_sorter_by_property diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index ff559cc78d..a956f8c811 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -137,8 +137,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) rec_file = output_folder / "spikeinterface_recording.json" - if recording.check_if_json_serializable(): - recording.dump_to_json(rec_file, relative_to=output_folder) + if recording.check_serializablility("json"): + recording.dump(rec_file, relative_to=output_folder) + elif recording.check_serializablility("pickle"): + recording.dump(output_folder / "spikeinterface_recording.pickle") else: d = {"warning": "The recording is not serializable to json"} rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") @@ -185,6 +187,26 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose): return params + @classmethod + def load_recording_from_folder(cls, output_folder, with_warnings=False): + json_file = output_folder / "spikeinterface_recording.json" + pickle_file = output_folder / "spikeinterface_recording.pickle" + + if json_file.exists(): + with (json_file).open("r", encoding="utf8") as f: + recording_dict = json.load(f) + if "warning" in recording_dict.keys() and with_warnings: + warnings.warn( + "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." + ) + recording = None + else: + recording = load_extractor(json_file, base_folder=output_folder) + elif pickle_file.exists(): + recording = load_extractor(pickle_file) + + return recording + @classmethod def _dump_params(cls, recording, output_folder, sorter_params, verbose): with (output_folder / "spikeinterface_params.json").open(mode="w", encoding="utf8") as f: @@ -271,7 +293,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose): return run_time @classmethod - def get_result_from_folder(cls, output_folder): + def get_result_from_folder(cls, output_folder, register_recording=True, sorting_info=True): output_folder = Path(output_folder) sorter_output_folder = output_folder / "sorter_output" # check errors in log file @@ -294,27 +316,25 @@ def get_result_from_folder(cls, output_folder): # back-compatibility sorting = cls._get_result_from_folder(output_folder) - # register recording to Sorting object - # check if not json serializable - with (output_folder / "spikeinterface_recording.json").open("r", encoding="utf8") as f: - recording_dict = json.load(f) - if "warning" in recording_dict.keys(): - warnings.warn( - "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." - ) - else: - recording = load_extractor(output_folder / "spikeinterface_recording.json", base_folder=output_folder) + if register_recording: + # register recording to Sorting object + recording = cls.load_recording_from_folder(output_folder, with_warnings=False) if recording is not None: - # can be None when not dumpable sorting.register_recording(recording) - # set sorting info to Sorting object - with open(output_folder / "spikeinterface_recording.json", "r") as f: - rec_dict = json.load(f) - with open(output_folder / "spikeinterface_params.json", "r") as f: - params_dict = json.load(f) - with open(output_folder / "spikeinterface_log.json", "r") as f: - log_dict = json.load(f) - sorting.set_sorting_info(rec_dict, params_dict, log_dict) + + if sorting_info: + # set sorting info to Sorting object + if (output_folder / "spikeinterface_recording.json").exists(): + with open(output_folder / "spikeinterface_recording.json", "r") as f: + rec_dict = json.load(f) + else: + rec_dict = None + + with open(output_folder / "spikeinterface_params.json", "r") as f: + params_dict = json.load(f) + with open(output_folder / "spikeinterface_log.json", "r") as f: + log_dict = json.load(f) + sorting.set_sorting_info(rec_dict, params_dict, log_dict) return sorting @@ -411,3 +431,14 @@ def get_job_kwargs(params, verbose): if not verbose: job_kwargs["progress_bar"] = False return job_kwargs + + +def is_log_ok(output_folder): + # log is OK when run_time is not None + if (output_folder / "spikeinterface_log.json").is_file(): + with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: + log = json.load(logfile) + run_time = log.get("run_time", None) + ok = run_time is not None + return ok + return False diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index a8d702ebe9..5180e6f1cc 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -147,9 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: new_api = False - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) p = params diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 69f97fd11c..f6f0b3eaeb 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -89,9 +89,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort4 - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index df6d276bf5..a88c59d688 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -115,9 +115,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - recording: BaseRecording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index 2a41d793d5..1962d56206 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -148,9 +148,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): # saved by setup recording diff --git a/src/spikeinterface/sorters/internal/si_based.py b/src/spikeinterface/sorters/internal/si_based.py index 1496ffbbd1..989fab1258 100644 --- a/src/spikeinterface/sorters/internal/si_based.py +++ b/src/spikeinterface/sorters/internal/si_based.py @@ -1,4 +1,4 @@ -from spikeinterface.core import load_extractor +from spikeinterface.core import load_extractor, NumpyRecording from spikeinterface.sorters import BaseSorter @@ -14,7 +14,6 @@ def is_installed(cls): @classmethod def _setup_recording(cls, recording, output_folder, params, verbose): - # nothing to do here because the spikeinterface_recording.json is here anyway pass @classmethod diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9de2762562..a16b642dd5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -3,11 +3,10 @@ import os import shutil import numpy as np -import os from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter try: import hdbscan @@ -22,17 +21,16 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True}, - "filtering": {"dtype": "float32"}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, + "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, "clustering": {}, "matching": {}, - "registration": {}, "apply_preprocessing": True, - "shared_memory": False, - "job_kwargs": {}, + "shared_memory": True, + "job_kwargs": {"n_jobs": -1}, } @classmethod @@ -54,23 +52,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs["verbose"] = verbose job_kwargs["progress_bar"] = verbose - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + sampling_rate = recording.get_sampling_frequency() num_channels = recording.get_num_channels() ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - # if recording.is_filtered == True: - # print('Looks like the recording is already filtered, check preprocessing!') - recording_f = bandpass_filter(recording, **filtering_params) + recording_f = highpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: recording_f = recording + # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") + noise_levels = np.ones(num_channels, dtype=np.float32) ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() @@ -91,7 +88,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - noise_levels = np.ones(num_channels, dtype=np.float32) selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks( peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params @@ -103,11 +99,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params.update(params["waveforms"]) - clustering_params.update(params["general"]) + clustering_params["waveforms"] = params["waveforms"].copy() + + for k in ["ms_before", "ms_after"]: + clustering_params["waveforms"][k] = params["general"][k] + clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params.update({"noise_levels": noise_levels}) labels, peak_labels = find_cluster_from_peaks( recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params @@ -115,21 +115,25 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We get the labels for our peaks mask = peak_labels > -1 - sorting = NumpySorting.from_times_labels(selected_peaks["sample_index"][mask], peak_labels[mask], sampling_rate) + sorting = NumpySorting.from_times_labels( + selected_peaks["sample_index"][mask], peak_labels[mask].astype(int), sampling_rate + ) clustering_folder = sorter_output_folder / "clustering" if clustering_folder.exists(): shutil.rmtree(clustering_folder) - sorting = sorting.save(folder=clustering_folder) - ## We get the templates our of such a clustering waveforms_params = params["waveforms"].copy() waveforms_params.update(job_kwargs) + for k in ["ms_before", "ms_after"]: + waveforms_params[k] = params["general"][k] + if params["shared_memory"]: mode = "memory" waveforms_folder = None else: + sorting = sorting.save(folder=clustering_folder) mode = "folder" waveforms_folder = sorter_output_folder / "waveforms" @@ -143,10 +147,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in matching_job_params: + matching_job_params.pop(value) + matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( - recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params + recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params ) if verbose: diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 42f51d3a77..e256915fa6 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -1,33 +1,53 @@ +import shutil from .si_based import ComponentsBasedSorter -from spikeinterface.core import load_extractor, BaseRecording, get_noise_levels, extract_waveforms, NumpySorting +from spikeinterface.core import ( + load_extractor, + BaseRecording, + get_noise_levels, + extract_waveforms, + NumpySorting, + get_channel_distances, +) +from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer from spikeinterface.core.job_tools import fix_job_kwargs + from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.core.basesorting import minimum_spike_dtype import numpy as np +import pickle +import json + class Tridesclous2Sorter(ComponentsBasedSorter): sorter_name = "tridesclous2" _default_params = { "apply_preprocessing": True, - "general": {"ms_before": 2.5, "ms_after": 3.5, "radius_um": 100}, - "filtering": {"freq_min": 300, "freq_max": 8000.0}, - "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 0.4}, - "hdbscan_kwargs": { - "min_cluster_size": 25, - "allow_single_cluster": True, - "core_dist_n_jobs": -1, - "cluster_selection_method": "leaf", + "waveforms": { + "ms_before": 0.5, + "ms_after": 1.5, + "radius_um": 120.0, }, - "waveforms": {"max_spikes_per_unit": 300}, + "filtering": {"freq_min": 300.0, "freq_max": 12000.0}, + "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "localization": {"max_distance_um": 1000, "optimizer": "minimize_with_log_penality"}, - "matching": { - "peak_shift_ms": 0.2, + "features": {}, + "svd": {"n_components": 6}, + "clustering": { + "split_radius_um": 40.0, + "merge_radius_um": 40.0, + }, + "templates": { + "ms_before": 1.5, + "ms_after": 2.5, + # "peak_shift_ms": 0.2, }, - "job_kwargs": {}, + "matching": {"peak_shift_ms": 0.2, "radius_um": 100.0}, + "job_kwargs": {"n_jobs": -1}, + "save_array": True, } @classmethod @@ -40,18 +60,27 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = fix_job_kwargs(job_kwargs) job_kwargs["progress_bar"] = verbose - # this is importanted only on demand because numba import are too heavy - from spikeinterface.sortingcomponents.peak_detection import detect_peaks - from spikeinterface.sortingcomponents.peak_localization import localize_peaks - from spikeinterface.sortingcomponents.peak_selection import select_peaks - from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates + from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, + ) + from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive + from spikeinterface.sortingcomponents.peak_selection import select_peaks + from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution + from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection + + from spikeinterface.sortingcomponents.clustering.split import split_clusters + from spikeinterface.sortingcomponents.clustering.merge import merge_clusters + from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse + + from sklearn.decomposition import TruncatedSVD import hdbscan - recording_raw = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) num_chans = recording_raw.get_num_channels() sampling_frequency = recording_raw.get_sampling_frequency() @@ -59,6 +88,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # preprocessing if params["apply_preprocessing"]: recording = bandpass_filter(recording_raw, **params["filtering"]) + # TODO what is the best about zscore>common_reference or the reverse recording = common_reference(recording) recording = zscore(recording, dtype="float32") noise_levels = np.ones(num_chans, dtype="float32") @@ -68,83 +98,258 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # detection detection_params = params["detection"].copy() - detection_params["radius_um"] = params["general"]["radius_um"] detection_params["noise_levels"] = noise_levels - peaks = detect_peaks(recording, method="locally_exclusive", **detection_params, **job_kwargs) + all_peaks = detect_peaks(recording, method="locally_exclusive", **detection_params, **job_kwargs) if verbose: - print("We found %d peaks in total" % len(peaks)) + print("We found %d peaks in total" % len(all_peaks)) # selection selection_params = params["selection"].copy() - selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_chans - selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - selection_params["noise_levels"] = noise_levels - some_peaks = select_peaks( - peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params - ) + n_peaks = params["selection"]["n_peaks_per_channel"] * num_chans + n_peaks = max(selection_params["min_n_peaks"], n_peaks) + peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks) if verbose: - print("We kept %d peaks for clustering" % len(some_peaks)) + print("We kept %d peaks for clustering" % len(peaks)) + + # SVD for time compression + few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) + few_wfs = extract_waveform_at_max_channel(recording, few_peaks, **job_kwargs) + + wfs = few_wfs[:, :, 0] + tsvd = TruncatedSVD(params["svd"]["n_components"]) + tsvd.fit(wfs) + + model_folder = sorter_output_folder / "tsvd_model" + + model_folder.mkdir(exist_ok=True) + with open(model_folder / "pca_model.pkl", "wb") as f: + pickle.dump(tsvd, f) + + ms_before = params["waveforms"]["ms_before"] + ms_after = params["waveforms"]["ms_after"] + model_params = { + "ms_before": ms_before, + "ms_after": ms_after, + "sampling_frequency": float(sampling_frequency), + } + with open(model_folder / "params.json", "w") as f: + json.dump(model_params, f) + + # features + + features_folder = sorter_output_folder / "features" + node0 = PeakRetriever(recording, peaks) + + # node1 = ExtractDenseWaveforms(rec, parents=[node0], return_output=False, + # ms_before=0.5, + # ms_after=1.5, + # ) + + # node2 = LocalizeCenterOfMass(rec, parents=[node0, node1], return_output=True, + # local_radius_um=75.0, + # feature="ptp", ) + + # node2 = LocalizeGridConvolution(rec, parents=[node0, node1], return_output=True, + # local_radius_um=40., + # upsampling_um=5.0, + # ) + + radius_um = params["waveforms"]["radius_um"] + node3 = ExtractSparseWaveforms( + recording, + parents=[node0], + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=radius_um, + ) + + model_folder_path = sorter_output_folder / "tsvd_model" + + node4 = TemporalPCAProjection( + recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder_path + ) + + # pipeline_nodes = [node0, node1, node2, node3, node4] + pipeline_nodes = [node0, node3, node4] + + output = run_node_pipeline( + recording, + pipeline_nodes, + job_kwargs, + gather_mode="npy", + gather_kwargs=dict(exist_ok=True), + folder=features_folder, + names=["sparse_wfs", "sparse_tsvd"], + ) + + # TODO make this generic in GatherNPY ??? + sparse_mask = node3.neighbours_mask + np.save(features_folder / "sparse_mask.npy", sparse_mask) + np.save(features_folder / "peaks.npy", peaks) + + # Clustering: channel index > split > merge + split_radius_um = params["clustering"]["split_radius_um"] + neighbours_mask = get_channel_distances(recording) < split_radius_um - # localization - localization_params = params["localization"].copy() - localization_params["radius_um"] = params["general"]["radius_um"] - peak_locations = localize_peaks( - recording, some_peaks, method="monopolar_triangulation", **localization_params, **job_kwargs + original_labels = peaks["channel_index"] + + min_cluster_size = 50 + + post_split_label, split_count = split_clusters( + original_labels, + recording, + features_folder, + method="local_feature_clustering", + method_kwargs=dict( + # clusterer="hdbscan", + clusterer="isocut5", + feature_name="sparse_tsvd", + # feature_name="sparse_wfs", + neighbours_mask=neighbours_mask, + waveforms_sparse_mask=sparse_mask, + min_size_split=min_cluster_size, + min_cluster_size=min_cluster_size, + min_samples=50, + n_pca_features=3, + ), + recursive=True, + recursive_depth=3, + returns_split_count=True, + **job_kwargs, + ) + + merge_radius_um = params["clustering"]["merge_radius_um"] + + post_merge_label, peak_shifts = merge_clusters( + peaks, + post_split_label, + recording, + features_folder, + radius_um=merge_radius_um, + # method="project_distribution", + # method_kwargs=dict( + # waveforms_sparse_mask=sparse_mask, + # feature_name="sparse_wfs", + # projection="centroid", + # criteria="distrib_overlap", + # threshold_overlap=0.3, + # min_cluster_size=min_cluster_size + 1, + # num_shift=5, + # ), + method="normalized_template_diff", + method_kwargs=dict( + waveforms_sparse_mask=sparse_mask, + threshold_diff=0.2, + min_cluster_size=min_cluster_size + 1, + num_shift=5, + ), + **job_kwargs, ) - # ~ print(peak_locations.dtype) + # sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r") - # features = localisations only - peak_features = np.zeros((peak_locations.size, 3), dtype="float64") - for i, dim in enumerate(["x", "y", "z"]): - peak_features[:, i] = peak_locations[dim] + new_peaks = peaks.copy() + new_peaks["sample_index"] -= peak_shifts - # clusering is hdbscan + # clean very small cluster before peeler + minimum_cluster_size = 25 + labels_set, count = np.unique(post_merge_label, return_counts=True) + to_remove = labels_set[count < minimum_cluster_size] - out = hdbscan.hdbscan(peak_features, **params["hdbscan_kwargs"]) - peak_labels = out[0] + mask = np.isin(post_merge_label, to_remove) + post_merge_label[mask] = -1 - mask = peak_labels >= 0 - labels = np.unique(peak_labels[mask]) + # final label sets + labels_set = np.unique(post_merge_label) + labels_set = labels_set[labels_set >= 0] - # extract waveform for template matching + mask = post_merge_label >= 0 sorting_temp = NumpySorting.from_times_labels( - some_peaks["sample_index"][mask], peak_labels[mask], sampling_frequency + new_peaks["sample_index"][mask], + post_merge_label[mask], + sampling_frequency, + unit_ids=labels_set, ) sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp") - waveforms_params = params["waveforms"].copy() - waveforms_params["ms_before"] = params["general"]["ms_before"] - waveforms_params["ms_after"] = params["general"]["ms_after"] + + ms_before = params["templates"]["ms_before"] + ms_after = params["templates"]["ms_after"] + max_spikes_per_unit = 300 + we = extract_waveforms( - recording, sorting_temp, sorter_output_folder / "waveforms_temp", **waveforms_params, **job_kwargs + recording, + sorting_temp, + sorter_output_folder / "waveforms_temp", + ms_before=ms_before, + ms_after=ms_after, + max_spikes_per_unit=max_spikes_per_unit, + **job_kwargs, ) - ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_params = params["matching"].copy() matching_params["waveform_extractor"] = we matching_params["noise_levels"] = noise_levels matching_params["peak_sign"] = params["detection"]["peak_sign"] matching_params["detect_threshold"] = params["detection"]["detect_threshold"] - matching_params["radius_um"] = params["general"]["radius_um"] - - # TODO: route that params - # ~ 'num_closest' : 5, - # ~ 'sample_shift': 3, - # ~ 'ms_before': 0.8, - # ~ 'ms_after': 1.2, - # ~ 'num_peeler_loop': 2, - # ~ 'num_template_try' : 1, + matching_params["radius_um"] = params["detection"]["radius_um"] spikes = find_spikes_from_templates( recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs ) - if verbose: - print("We found %d spikes" % len(spikes)) + if params["save_array"]: + np.save(sorter_output_folder / "noise_levels.npy", noise_levels) + np.save(sorter_output_folder / "all_peaks.npy", all_peaks) + np.save(sorter_output_folder / "post_split_label.npy", post_split_label) + np.save(sorter_output_folder / "split_count.npy", split_count) + np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label) + np.save(sorter_output_folder / "spikes.npy", spikes) - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) + final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) + final_spikes["sample_index"] = spikes["sample_index"] + final_spikes["unit_index"] = spikes["cluster_index"] + final_spikes["segment_index"] = spikes["segment_index"] + + sorting = NumpySorting(final_spikes, sampling_frequency, labels_set) sorting = sorting.save(folder=sorter_output_folder / "sorting") return sorting + + +def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): + """ + Helper function to extractor waveforms at max channel from a peak list + + + """ + n = rec.get_num_channels() + unit_ids = np.arange(n, dtype="int64") + sparsity_mask = np.eye(n, dtype="bool") + + spikes = np.zeros( + peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + ) + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] + + nbefore = int(ms_before * rec.sampling_frequency / 1000.0) + nafter = int(ms_after * rec.sampling_frequency / 1000.0) + + all_wfs = extract_waveforms_to_single_buffer( + rec, + spikes, + unit_ids, + nbefore, + nafter, + mode="shared_memory", + return_scaled=False, + sparsity_mask=sparsity_mask, + copy=True, + **job_kwargs, + ) + + return all_wfs diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 52098f45cd..704f6843f2 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -4,61 +4,198 @@ from pathlib import Path import shutil import numpy as np -import json import tempfile import os import stat import subprocess import sys +import warnings -from spikeinterface.core import load_extractor, aggregate_units -from spikeinterface.core.core_tools import check_json +from spikeinterface.core import aggregate_units from .sorterlist import sorter_dict -from .runsorter import run_sorter, run_sorter - - -def _run_one(arg_list): - # the multiprocessing python module force to have one unique tuple argument - ( - sorter_name, - recording, - output_folder, - verbose, - sorter_params, - docker_image, - singularity_image, - with_output, - ) = arg_list - - if isinstance(recording, dict): - recording = load_extractor(recording) +from .runsorter import run_sorter +from .basesorter import is_log_ok + +_default_engine_kwargs = dict( + loop=dict(), + joblib=dict(n_jobs=-1, backend="loky"), + processpoolexecutor=dict(max_workers=2, mp_context=None), + dask=dict(client=None), + slurm=dict(tmp_script_folder=None, cpus_per_task=1, mem="1G"), +) + + +_implemented_engine = list(_default_engine_kwargs.keys()) + + +def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=False): + """ + Run several :py:func:`run_sorter()` sequentially or in parallel given a list of jobs. + + For **engine="loop"** this is equivalent to: + + ..code:: + + for job in job_list: + run_sorter(**job) + + The following engines block the I/O: + * "loop" + * "joblib" + * "multiprocessing" + * "dask" + + The following engines are *asynchronous*: + * "slurm" + + Where *blocking* means that this function is blocking until the results are returned. + This is in opposition to *asynchronous*, where the function returns `None` almost immediately (aka non-blocking), + but the results must be retrieved by hand when jobs are finished. No mechanisim is provided here to be know + when jobs are finish. + In this *asynchronous* case, the :py:func:`~spikeinterface.sorters.read_sorter_folder()` helps to retrieve individual results. + + + Parameters + ---------- + job_list: list of dict + A list a dict that are propagated to run_sorter(...) + engine: str "loop", "joblib", "dask", "slurm" + The engine to run the list. + * "loop": a simple loop. This engine is + engine_kwargs: dict + + return_output: bool, dfault False + Return a sortings or None. + This also overwrite kwargs in in run_sorter(with_sorting=True/False) + + Returns + ------- + sortings: None or list of sorting + With engine="loop" or "joblib" you can optional get directly the list of sorting result if return_output=True. + """ + + assert engine in _implemented_engine, f"engine must be in {_implemented_engine}" + + engine_kwargs_ = dict() + engine_kwargs_.update(_default_engine_kwargs[engine]) + engine_kwargs_.update(engine_kwargs) + engine_kwargs = engine_kwargs_ + + if return_output: + assert engine in ( + "loop", + "joblib", + "processpoolexecutor", + ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True." + out = [] + for kwargs in job_list: + kwargs["with_output"] = True else: - recording = recording - - # because this is checks in run_sorters before this call - remove_existing_folder = False - # result is retrieve later - delete_output_folder = False - # because we won't want the loop/worker to break - raise_error = False - - run_sorter( - sorter_name, - recording, - output_folder=output_folder, - remove_existing_folder=remove_existing_folder, - delete_output_folder=delete_output_folder, - verbose=verbose, - raise_error=raise_error, - docker_image=docker_image, - singularity_image=singularity_image, - with_output=with_output, - **sorter_params, - ) + out = None + for kwargs in job_list: + kwargs["with_output"] = False + + if engine == "loop": + # simple loop in main process + for kwargs in job_list: + sorting = run_sorter(**kwargs) + if return_output: + out.append(sorting) + + elif engine == "joblib": + from joblib import Parallel, delayed + + n_jobs = engine_kwargs["n_jobs"] + backend = engine_kwargs["backend"] + sortings = Parallel(n_jobs=n_jobs, backend=backend)(delayed(run_sorter)(**kwargs) for kwargs in job_list) + if return_output: + out.extend(sortings) + + elif engine == "processpoolexecutor": + from concurrent.futures import ProcessPoolExecutor + + max_workers = engine_kwargs["max_workers"] + mp_context = engine_kwargs["mp_context"] + with ProcessPoolExecutor(max_workers=max_workers, mp_context=mp_context) as executor: + futures = [] + for kwargs in job_list: + res = executor.submit(run_sorter, **kwargs) + futures.append(res) + for futur in futures: + sorting = futur.result() + if return_output: + out.append(sorting) -_implemented_engine = ("loop", "joblib", "dask", "slurm") + elif engine == "dask": + client = engine_kwargs["client"] + assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" + + tasks = [] + for kwargs in job_list: + task = client.submit(run_sorter, **kwargs) + tasks.append(task) + + for task in tasks: + task.result() + + elif engine == "slurm": + # generate python script for slurm + tmp_script_folder = engine_kwargs["tmp_script_folder"] + if tmp_script_folder is None: + tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") + tmp_script_folder = Path(tmp_script_folder) + cpus_per_task = engine_kwargs["cpus_per_task"] + mem = engine_kwargs["mem"] + + tmp_script_folder.mkdir(exist_ok=True, parents=True) + + for i, kwargs in enumerate(job_list): + script_name = tmp_script_folder / f"si_script_{i}.py" + with open(script_name, "w") as f: + kwargs_txt = "" + for k, v in kwargs.items(): + kwargs_txt += " " + if k == "recording": + # put None temporally + kwargs_txt += "recording=None" + else: + if isinstance(v, str): + kwargs_txt += f'{k}="{v}"' + elif isinstance(v, Path): + kwargs_txt += f'{k}="{str(v.absolute())}"' + else: + kwargs_txt += f"{k}={v}" + kwargs_txt += ",\n" + + # recording_dict = task_args[1] + recording_dict = kwargs["recording"].to_dict() + slurm_script = _slurm_script.format( + python=sys.executable, recording_dict=recording_dict, kwargs_txt=kwargs_txt + ) + f.write(slurm_script) + os.fchmod(f.fileno(), mode=stat.S_IRWXU) + + subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"]) + + return out + + +_slurm_script = """#! {python} +from numpy import array +from spikeinterface import load_extractor +from spikeinterface.sorters import run_sorter + +rec_dict = {recording_dict} + +kwargs = dict( +{kwargs_txt} +) +kwargs['recording'] = load_extractor(rec_dict) + +run_sorter(**kwargs) +""" def run_sorter_by_property( @@ -66,7 +203,7 @@ def run_sorter_by_property( recording, grouping_property, working_folder, - mode_if_folder_exists="raise", + mode_if_folder_exists=None, engine="loop", engine_kwargs={}, verbose=False, @@ -93,11 +230,10 @@ def run_sorter_by_property( Property to split by before sorting working_folder: str The working directory. - mode_if_folder_exists: {'raise', 'overwrite', 'keep'} - The mode when the subfolder of recording/sorter already exists. - * 'raise' : raise error if subfolder exists - * 'overwrite' : delete and force recompute - * 'keep' : do not compute again if f=subfolder exists and log is OK + mode_if_folder_exists: None + Must be None. This is deprecated. + If not None then a warning is raise. + Will be removed in next release. engine: {'loop', 'joblib', 'dask'} Which engine to use to run sorter. engine_kwargs: dict @@ -127,46 +263,49 @@ def run_sorter_by_property( engine_kwargs={"n_jobs": 4}) """ + if mode_if_folder_exists is not None: + warnings.warn( + "run_sorter_by_property(): mode_if_folder_exists is not used anymore", + DeprecationWarning, + stacklevel=2, + ) + + working_folder = Path(working_folder).absolute() assert grouping_property in recording.get_property_keys(), ( f"The 'grouping_property' {grouping_property} is not " f"a recording property!" ) recording_dict = recording.split_by(grouping_property) - sorting_output = run_sorters( - [sorter_name], - recording_dict, - working_folder, - mode_if_folder_exists=mode_if_folder_exists, - engine=engine, - engine_kwargs=engine_kwargs, - verbose=verbose, - with_output=True, - docker_images={sorter_name: docker_image}, - singularity_images={sorter_name: singularity_image}, - sorter_params={sorter_name: sorter_params}, - ) - grouping_property_values = None - sorting_list = [] - for output_name, sorting in sorting_output.items(): - prop_name, sorter_name = output_name - sorting_list.append(sorting) - if grouping_property_values is None: - grouping_property_values = np.array( - [prop_name] * len(sorting.get_unit_ids()), dtype=np.dtype(type(prop_name)) - ) - else: - grouping_property_values = np.concatenate( - (grouping_property_values, [prop_name] * len(sorting.get_unit_ids())) - ) + job_list = [] + for k, rec in recording_dict.items(): + job = dict( + sorter_name=sorter_name, + recording=rec, + output_folder=working_folder / str(k), + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **sorter_params, + ) + job_list.append(job) + + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=True) + + unit_groups = [] + for sorting, group in zip(sorting_list, recording_dict.keys()): + num_units = sorting.get_unit_ids().size + unit_groups.extend([group] * num_units) + unit_groups = np.array(unit_groups) aggregate_sorting = aggregate_units(sorting_list) - aggregate_sorting.set_property(key=grouping_property, values=grouping_property_values) + aggregate_sorting.set_property(key=grouping_property, values=unit_groups) aggregate_sorting.register_recording(recording) return aggregate_sorting +# This is deprecated and will be removed def run_sorters( sorter_list, recording_dict_or_list, @@ -180,7 +319,9 @@ def run_sorters( docker_images={}, singularity_images={}, ): - """Run several sorter on several recordings. + """ + This function is deprecated and will be removed in version 0.100 + Please use run_sorter_jobs() instead. Parameters ---------- @@ -221,6 +362,13 @@ def run_sorters( results : dict The output is nested dict[(rec_name, sorter_name)] of SortingExtractor. """ + + warnings.warn( + "run_sorters() is deprecated please use run_sorter_jobs() instead. This will be removed in 0.100", + DeprecationWarning, + stacklevel=2, + ) + working_folder = Path(working_folder) mode_if_folder_exists in ("raise", "keep", "overwrite") @@ -247,8 +395,7 @@ def run_sorters( dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0])) assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!" - need_dump = engine != "loop" - task_args_list = [] + job_list = [] for rec_name, recording in recording_dict.items(): for sorter_name in sorter_list: output_folder = working_folder / str(rec_name) / sorter_name @@ -268,181 +415,21 @@ def run_sorters( params = sorter_params.get(sorter_name, {}) docker_image = docker_images.get(sorter_name, None) singularity_image = singularity_images.get(sorter_name, None) - _check_container_images(docker_image, singularity_image, sorter_name) - - if need_dump: - if not recording.check_if_dumpable(): - raise Exception("recording not dumpable call recording.save() before") - recording_arg = recording.to_dict(recursive=True) - else: - recording_arg = recording - - task_args = ( - sorter_name, - recording_arg, - output_folder, - verbose, - params, - docker_image, - singularity_image, - with_output, - ) - task_args_list.append(task_args) - if engine == "loop": - # simple loop in main process - for task_args in task_args_list: - _run_one(task_args) - - elif engine == "joblib": - from joblib import Parallel, delayed - - n_jobs = engine_kwargs.get("n_jobs", -1) - backend = engine_kwargs.get("backend", "loky") - Parallel(n_jobs=n_jobs, backend=backend)(delayed(_run_one)(task_args) for task_args in task_args_list) - - elif engine == "dask": - client = engine_kwargs.get("client", None) - assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" - - tasks = [] - for task_args in task_args_list: - task = client.submit(_run_one, task_args) - tasks.append(task) - - for task in tasks: - task.result() - - elif engine == "slurm": - # generate python script for slurm - tmp_script_folder = engine_kwargs.get("tmp_script_folder", None) - if tmp_script_folder is None: - tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") - tmp_script_folder = Path(tmp_script_folder) - cpus_per_task = engine_kwargs.get("cpus_per_task", 1) - mem = engine_kwargs.get("mem", "1G") - - for i, task_args in enumerate(task_args_list): - script_name = tmp_script_folder / f"si_script_{i}.py" - with open(script_name, "w") as f: - arg_list_txt = "(\n" - for j, arg in enumerate(task_args): - arg_list_txt += "\t" - if j != 1: - if isinstance(arg, str): - arg_list_txt += f'"{arg}"' - elif isinstance(arg, Path): - arg_list_txt += f'"{str(arg.absolute())}"' - else: - arg_list_txt += f"{arg}" - else: - arg_list_txt += "recording" - arg_list_txt += ",\r" - arg_list_txt += ")" - - recording_dict = task_args[1] - slurm_script = _slurm_script.format( - python=sys.executable, recording_dict=recording_dict, arg_list_txt=arg_list_txt - ) - f.write(slurm_script) - os.fchmod(f.fileno(), mode=stat.S_IRWXU) - - print(slurm_script) - - subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"]) + job = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=output_folder, + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **params, + ) + job_list.append(job) - non_blocking_engine = ("loop", "joblib") - if engine in non_blocking_engine: - # dump spikeinterface_job.json - # only for non blocking engine - for rec_name, recording in recording_dict.items(): - for sorter_name in sorter_list: - output_folder = working_folder / str(rec_name) / sorter_name - with open(output_folder / "spikeinterface_job.json", "w") as f: - dump_dict = {"rec_name": rec_name, "sorter_name": sorter_name, "engine": engine} - if engine != "dask": - dump_dict.update({"engine_kwargs": engine_kwargs}) - json.dump(check_json(dump_dict), f) + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=with_output) if with_output: - if engine not in non_blocking_engine: - print( - f'Warning!! With engine="{engine}" you cannot have directly output results\n' - "Use : run_sorters(..., with_output=False)\n" - "And then: results = collect_sorting_outputs(output_folders)" - ) - return - - results = collect_sorting_outputs(working_folder) + keys = [(rec_name, sorter_name) for rec_name in recording_dict for sorter_name in sorter_list] + results = dict(zip(keys, sorting_list)) return results - - -_slurm_script = """#! {python} -from numpy import array -from spikeinterface.sorters.launcher import _run_one - -recording = {recording_dict} - -arg_list = {arg_list_txt} - -_run_one(arg_list) -""" - - -def is_log_ok(output_folder): - # log is OK when run_time is not None - if (output_folder / "spikeinterface_log.json").is_file(): - with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - ok = run_time is not None - return ok - return False - - -def iter_working_folder(working_folder): - working_folder = Path(working_folder) - for rec_folder in working_folder.iterdir(): - if not rec_folder.is_dir(): - continue - for output_folder in rec_folder.iterdir(): - if (output_folder / "spikeinterface_job.json").is_file(): - with open(output_folder / "spikeinterface_job.json", "r") as f: - job_dict = json.load(f) - rec_name = job_dict["rec_name"] - sorter_name = job_dict["sorter_name"] - yield rec_name, sorter_name, output_folder - else: - rec_name = rec_folder.name - sorter_name = output_folder.name - if not output_folder.is_dir(): - continue - if not is_log_ok(output_folder): - continue - yield rec_name, sorter_name, output_folder - - -def iter_sorting_output(working_folder): - """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" - for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): - SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) - yield rec_name, sorter_name, sorting - - -def collect_sorting_outputs(working_folder): - """Collect results in a working_folder. - - The output is a dict with double key access results[(rec_name, sorter_name)] of SortingExtractor. - """ - results = {} - for rec_name, sorter_name, sorting in iter_sorting_output(working_folder): - results[(rec_name, sorter_name)] = sorting - return results - - -def _check_container_images(docker_image, singularity_image, sorter_name): - if docker_image is not None: - assert singularity_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" - if singularity_image is not None: - assert docker_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 6e6ccc0358..a49a605a75 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -91,7 +91,7 @@ def run_sorter( sorter_name: str, recording: BaseRecording, output_folder: Optional[str] = None, - remove_existing_folder: bool = True, + remove_existing_folder: bool = False, delete_output_folder: bool = False, verbose: bool = False, raise_error: bool = True, @@ -514,19 +514,19 @@ def run_sorter_container( res_output = container_client.run_command(cmd) cmd = f"cp -r {si_dev_path_unix} {si_source_folder}" res_output = container_client.run_command(cmd) - cmd = f"pip install {si_source_folder}/spikeinterface[full]" + cmd = f"pip install --user {si_source_folder}/spikeinterface[full]" else: si_source = "remote repository" - cmd = "pip install --upgrade --no-input git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full]" + cmd = "pip install --user --upgrade --no-input git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full]" if verbose: print(f"Installing dev spikeinterface from {si_source}") res_output = container_client.run_command(cmd) - cmd = "pip install --upgrade --no-input https://github.com/NeuralEnsemble/python-neo/archive/master.zip" + cmd = "pip install --user --upgrade --no-input https://github.com/NeuralEnsemble/python-neo/archive/master.zip" res_output = container_client.run_command(cmd) else: if verbose: print(f"Installing spikeinterface=={si_version} in {container_image}") - cmd = f"pip install --upgrade --no-input spikeinterface[full]=={si_version}" + cmd = f"pip install --user --upgrade --no-input spikeinterface[full]=={si_version}" res_output = container_client.run_command(cmd) else: # TODO version checking @@ -540,7 +540,7 @@ def run_sorter_container( if extra_requirements: if verbose: print(f"Installing extra requirements: {extra_requirements}") - cmd = f"pip install --upgrade --no-input {' '.join(extra_requirements)}" + cmd = f"pip install --user --upgrade --no-input {' '.join(extra_requirements)}" res_output = container_client.run_command(cmd) # run sorter on folder @@ -624,10 +624,20 @@ def run_sorter_container( ) -def read_sorter_folder(output_folder, raise_error=True): +def read_sorter_folder(output_folder, register_recording=True, sorting_info=True, raise_error=True): """ Load a sorting object from a spike sorting output folder. The 'output_folder' must contain a valid 'spikeinterface_log.json' file + + + Parameters + ---------- + output_folder: Pth or str + The sorter folder + register_recording: bool, default: True + Attach recording (when json or pickle) to the sorting + sorting_info: bool, default: True + Attach sorting info to the sorting. """ output_folder = Path(output_folder) log_file = output_folder / "spikeinterface_log.json" @@ -647,7 +657,9 @@ def read_sorter_folder(output_folder, raise_error=True): sorter_name = log["sorter_name"] SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) + sorting = SorterClass.get_result_from_folder( + output_folder, register_recording=register_recording, sorting_info=sorting_info + ) return sorting diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index cd8bc0fa5d..fdadf533f5 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -1,4 +1,5 @@ import os +import sys import shutil import time @@ -6,8 +7,10 @@ from pathlib import Path from spikeinterface.core import load_extractor -from spikeinterface.extractors import toy_example -from spikeinterface.sorters import run_sorters, run_sorter_by_property, collect_sorting_outputs + +# from spikeinterface.extractors import toy_example +from spikeinterface import generate_ground_truth_recording +from spikeinterface.sorters import run_sorter_jobs, run_sorters, run_sorter_by_property if hasattr(pytest, "global_test_folder"): @@ -15,10 +18,17 @@ else: cache_folder = Path("cache_folder") / "sorters" +base_output = cache_folder / "sorter_output" + +# no need to have many +num_recordings = 2 +sorters = ["tridesclous2"] + def setup_module(): - rec, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1) - for i in range(4): + base_seed = 42 + for i in range(num_recordings): + rec, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=base_seed + i) rec_folder = cache_folder / f"toy_rec_{i}" if rec_folder.is_dir(): shutil.rmtree(rec_folder) @@ -31,19 +41,106 @@ def setup_module(): rec.save(folder=rec_folder) -def test_run_sorters_with_list(): - working_folder = cache_folder / "test_run_sorters_list" +def get_job_list(): + jobs = [] + for i in range(num_recordings): + for sorter_name in sorters: + recording = load_extractor(cache_folder / f"toy_rec_{i}") + kwargs = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=base_output / f"{sorter_name}_rec{i}", + verbose=True, + raise_error=False, + ) + jobs.append(kwargs) + + return jobs + + +@pytest.fixture(scope="module") +def job_list(): + return get_job_list() + + +def test_run_sorter_jobs_loop(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs(job_list, engine="loop", return_output=True) + print(sortings) + + +def test_run_sorter_jobs_joblib(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="joblib", engine_kwargs=dict(n_jobs=2, backend="loky"), return_output=True + ) + print(sortings) + + +def test_run_sorter_jobs_processpoolexecutor(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="processpoolexecutor", engine_kwargs=dict(max_workers=2), return_output=True + ) + print(sortings) + + +@pytest.mark.skipif(True, reason="This is tested locally") +def test_run_sorter_jobs_dask(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + # create a dask Client for a slurm queue + from dask.distributed import Client + + test_mode = "local" + # test_mode = "client_slurm" + + if test_mode == "local": + client = Client() + elif test_mode == "client_slurm": + from dask_jobqueue import SLURMCluster + + cluster = SLURMCluster( + processes=1, + cores=1, + memory="12GB", + python=sys.executable, + walltime="12:00:00", + ) + cluster.scale(2) + client = Client(cluster) + + # dask + t0 = time.perf_counter() + run_sorter_jobs(job_list, engine="dask", engine_kwargs=dict(client=client)) + t1 = time.perf_counter() + print(t1 - t0) + + +@pytest.mark.skip("Slurm launcher need a machine with slurm") +def test_run_sorter_jobs_slurm(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + working_folder = cache_folder / "test_run_sorters_slurm" if working_folder.is_dir(): shutil.rmtree(working_folder) - # make dumpable - rec0 = load_extractor(cache_folder / "toy_rec_0") - rec1 = load_extractor(cache_folder / "toy_rec_1") - - recording_list = [rec0, rec1] - sorter_list = ["tridesclous"] + tmp_script_folder = working_folder / "slurm_scripts" - run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + run_sorter_jobs( + job_list, + engine="slurm", + engine_kwargs=dict( + tmp_script_folder=tmp_script_folder, + cpus_per_task=32, + mem="32G", + ), + ) def test_run_sorter_by_property(): @@ -59,7 +156,7 @@ def test_run_sorter_by_property(): rec0_by = rec0.split_by("group") group_names0 = list(rec0_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting0 = run_sorter_by_property(sorter_name, rec0, "group", working_folder1, engine="loop", verbose=False) assert "group" in sorting0.get_property_keys() assert all([g in group_names0 for g in sorting0.get_property("group")]) @@ -68,12 +165,31 @@ def test_run_sorter_by_property(): rec1_by = rec1.split_by("group") group_names1 = list(rec1_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting1 = run_sorter_by_property(sorter_name, rec1, "group", working_folder2, engine="loop", verbose=False) assert "group" in sorting1.get_property_keys() assert all([g in group_names1 for g in sorting1.get_property("group")]) +# run_sorters is deprecated +# This will test will be removed in next release +def test_run_sorters_with_list(): + working_folder = cache_folder / "test_run_sorters_list" + if working_folder.is_dir(): + shutil.rmtree(working_folder) + + # make serializable + rec0 = load_extractor(cache_folder / "toy_rec_0") + rec1 = load_extractor(cache_folder / "toy_rec_1") + + recording_list = [rec0, rec1] + sorter_list = ["tridesclous2"] + + run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + + +# run_sorters is deprecated +# This will test will be removed in next release def test_run_sorters_with_dict(): working_folder = cache_folder / "test_run_sorters_dict" if working_folder.is_dir(): @@ -84,9 +200,9 @@ def test_run_sorters_with_dict(): recording_dict = {"toy_tetrode": rec0, "toy_octotrode": rec1} - sorter_list = ["tridesclous", "tridesclous2"] + sorter_list = ["tridesclous2"] - sorter_params = {"tridesclous": dict(detect_threshold=5.6), "tridesclous2": dict()} + sorter_params = {"tridesclous2": dict()} # simple loop t0 = time.perf_counter() @@ -116,143 +232,19 @@ def test_run_sorters_with_dict(): ) -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_joblib(): - working_folder = cache_folder / "test_run_sorters_joblib" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "tridesclous", - ] - - # joblib - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder / "with_joblib", - engine="joblib", - engine_kwargs={"n_jobs": 4}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_dask(): - working_folder = cache_folder / "test_run_sorters_dask" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "tridesclous", - ] - - # create a dask Client for a slurm queue - from dask.distributed import Client - from dask_jobqueue import SLURMCluster - - python = "/home/samuel.garcia/.virtualenvs/py36/bin/python3.6" - cluster = SLURMCluster( - processes=1, - cores=1, - memory="12GB", - python=python, - walltime="12:00:00", - ) - cluster.scale(5) - client = Client(cluster) - - # dask - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="dask", - engine_kwargs={"client": client}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_slurm(): - working_folder = cache_folder / "test_run_sorters_slurm" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - # create recording - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "spykingcircus2", - "tridesclous2", - ] - - tmp_script_folder = working_folder / "slurm_scripts" - tmp_script_folder.mkdir(parents=True) - - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="slurm", - engine_kwargs={ - "tmp_script_folder": tmp_script_folder, - "cpus_per_task": 32, - "mem": "32G", - }, - with_output=False, - mode_if_folder_exists="keep", - verbose=True, - ) - - -def test_collect_sorting_outputs(): - working_folder = cache_folder / "test_run_sorters_dict" - results = collect_sorting_outputs(working_folder) - print(results) - - -def test_sorter_installation(): - # This import is to get error on github when import fails - import tridesclous - - # import circus - - if __name__ == "__main__": setup_module() - # pass - # test_run_sorters_with_list() - - # test_run_sorter_by_property() + job_list = get_job_list() - test_run_sorters_with_dict() + test_run_sorter_jobs_loop(job_list) + # test_run_sorter_jobs_joblib(job_list) + # test_run_sorter_jobs_processpoolexecutor(job_list) + # test_run_sorter_jobs_multiprocessing(job_list) + # test_run_sorter_jobs_dask(job_list) + # test_run_sorter_jobs_slurm(job_list) - # test_run_sorters_joblib() - - # test_run_sorters_dask() - - # test_run_sorters_slurm() + # test_run_sorter_by_property() - # test_collect_sorting_outputs() + # this deprecated + # test_run_sorters_with_list() + # test_run_sorters_with_dict() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index d68b8e5449..bd413417bf 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -524,7 +524,7 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) template_real = template_real.reshape(template_real.size, 1).T if metric == "cosine": - dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real, metric).flatten().tolist() + dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real).flatten().tolist() else: dist = sklearn.metrics.pairwise_distances(template, template_real, metric).flatten().tolist() res += dist diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..4efabbc9c5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] @@ -600,29 +600,38 @@ def plot_comparison_matching( else: ax = axs[j] comp1, comp2 = comp_per_method[method1], comp_per_method[method2] - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - if j == 0: - ax.set_ylabel(f"{method1}") - else: - ax.set_yticks([]) - if i == num_methods - 1: - ax.set_xlabel(f"{method2}") + if i <= j: + for performance, color in zip(performance_names, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.plot(perf2, perf1, ".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + if j == i: + ax.set_ylabel(f"{method1}") + else: + ax.set_yticks([]) + if i == j: + ax.set_xlabel(f"{method2}") + else: + ax.set_xticks([]) + if i == num_methods - 1 and j == num_methods - 1: + patches = [] + for color, name in zip(colors, performance_names): + patches.append(mpatches.Patch(color=color, label=name)) + ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) else: + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) + ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) return fig, axs diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index dd35670abd..abf40b2da6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -487,7 +487,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - axes[0].plot(benchmark.temporal_bins, mean_error, label=benchmark.title, color=c) + axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) parts = axes[1].violinplot(mean_error, [count], showmeans=True) if c is not None: for pc in parts["bodies"]: @@ -500,8 +500,8 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) ax0 = ax = axes[0] - ax.set_xlabel("time [s]") - ax.set_ylabel("error [um]") + ax.set_xlabel("Time [s]") + ax.set_ylabel("Error [μm]") if show_legend: ax.legend() _simpleaxis(ax) @@ -514,7 +514,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo ax2 = axes[2] ax2.set_yticks([]) - ax2.set_xlabel("depth [um]") + ax2.set_xlabel("Depth [μm]") # ax.set_ylabel('error') channel_positions = benchmark.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() @@ -584,23 +584,28 @@ def plot_motions_several_benchmarks(benchmarks): _simpleaxis(ax) -def plot_speed_several_benchmarks(benchmarks, ax=None, colors=None): +def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): if ax is None: fig, ax = plt.subplots(figsize=(5, 5)) for count, benchmark in enumerate(benchmarks): color = colors[count] if colors is not None else None - bottom = 0 - i = 0 - patterns = ["/", "\\", "|", "*"] - for key, value in benchmark.run_times.items(): - if count == 0: - label = key.replace("_", " ") - else: - label = None - ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) - bottom += value - i += 1 + + if detailed: + bottom = 0 + i = 0 + patterns = ["/", "\\", "|", "*"] + for key, value in benchmark.run_times.items(): + if count == 0: + label = key.replace("_", " ") + else: + label = None + ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) + bottom += value + i += 1 + else: + total_run_time = np.sum([value for key, value in benchmark.run_times.items()]) + ax.bar([count], [total_run_time], color=color, edgecolor="black") # ax.legend() ax.set_ylabel("speed (s)") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 13a64e8168..b28b29f17c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -9,7 +9,7 @@ from spikeinterface.extractors import read_mearec from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten -from spikeinterface.sorters import run_sorter +from spikeinterface.sorters import run_sorter, read_sorter_folder from spikeinterface.widgets import plot_unit_waveforms, plot_gt_performances from spikeinterface.comparison import GroundTruthComparison @@ -184,7 +184,7 @@ def extract_waveforms(self): we.run_extract_waveforms(seed=22051977, **self.job_kwargs) self.waveforms[key] = we - def run_sorters(self): + def run_sorters(self, skip_already_done=True): for case in self.sorter_cases: label = case["label"] print("run sorter", label) @@ -192,9 +192,17 @@ def run_sorters(self): sorter_params = case["sorter_params"] recording = self.recordings[case["recording"]] output_folder = self.folder / f"tmp_sortings_{label}" - sorting = run_sorter( - sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder - ) + if output_folder.exists() and skip_already_done: + print("already done") + sorting = read_sorter_folder(output_folder) + else: + sorting = run_sorter( + sorter_name, + recording, + output_folder, + **sorter_params, + delete_output_folder=self.delete_output_folder, + ) self.sortings[label] = sorting def compute_distances_to_static(self, force=False): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1514a63dd4..73497a59fd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) + garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) garbage_channels = self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py new file mode 100644 index 0000000000..cbded0c49f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -0,0 +1,45 @@ +import numpy as np + +from .tools import FeaturesLoader, compute_template_from_sparse + +# This is work in progress ... + + +def clean_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + peak_sign="neg", +): + total_channels = recording.get_num_channels() + + if isinstance(features_dict_or_folder, dict): + features = features_dict_or_folder + else: + features = FeaturesLoader(features_dict_or_folder) + + clean_labels = peak_labels.copy() + + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + + count = np.zeros(n, dtype="int64") + for i, label in enumerate(labels_set): + count[i] = np.sum(peak_labels == label) + print(count) + + templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) + + if peak_sign == "both": + max_values = np.max(np.abs(templates), axis=(1, 2)) + elif peak_sign == "neg": + max_values = -np.min(templates, axis=(1, 2)) + elif peak_sign == "pos": + max_values = np.max(templates, axis=(1, 2)) + print(max_values) + + return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..b4938717f8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[np.isin(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1 active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, np.isin(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, np.isin(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] @@ -536,10 +536,10 @@ def remove_duplicates_via_matching( waveform_extractor, noise_levels, peak_labels, - sparsify_threshold=1, method_kwargs={}, job_kwargs={}, tmp_folder=None, + method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels @@ -547,11 +547,14 @@ def remove_duplicates_via_matching( from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms from spikeinterface.core import get_global_tmp_folder - from spikeinterface.sortingcomponents.matching.circus import get_scipy_shape import string, random, shutil, os from pathlib import Path job_kwargs = fix_job_kwargs(job_kwargs) + + if waveform_extractor.is_sparse(): + sparsity = waveform_extractor.sparsity.mask + templates = waveform_extractor.get_all_templates(mode="median").copy() nb_templates = len(templates) duration = waveform_extractor.nbefore + waveform_extractor.nafter @@ -559,9 +562,9 @@ def remove_duplicates_via_matching( fs = waveform_extractor.recording.get_sampling_frequency() num_chans = waveform_extractor.recording.get_num_channels() - for t in range(nb_templates): - is_silent = templates[t].ptp(0) < sparsify_threshold - templates[t, :, is_silent] = 0 + if waveform_extractor.is_sparse(): + for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): + templates[count][:, ~sparsity[count]] = 0 zdata = templates.reshape(nb_templates, -1) @@ -571,6 +574,8 @@ def remove_duplicates_via_matching( if tmp_folder is None: tmp_folder = get_global_tmp_folder() + tmp_folder.mkdir(parents=True, exist_ok=True) + tmp_filename = tmp_folder / "tmp.raw" f = open(tmp_filename, "wb") @@ -580,6 +585,7 @@ def remove_duplicates_via_matching( f.close() recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") + recording = recording.set_probe(waveform_extractor.recording.get_probe()) recording.annotate(is_filtered=True) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) @@ -587,44 +593,55 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - dummy_filter = np.empty((num_chans, duration), dtype=np.float32) - dummy_traces = np.empty((num_chans, chunk_size), dtype=np.float32) - - fshape, axes = get_scipy_shape(dummy_filter, dummy_traces, axes=1) + local_params = method_kwargs.copy() - method_kwargs.update( + local_params.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "sparsify_threshold": sparsify_threshold, - "omp_min_sps": 0.1, - "templates": None, - "overlaps": None, + "omp_min_sps": 0.05, } ) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) + indices = np.argsort(counts) + ignore_ids = [] similar_templates = [[], []] - for i in range(nb_templates): + for i in indices: t_start = padding + i * duration t_stop = padding + (i + 1) * duration sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - - method_kwargs.update({"ignored_ids": ignore_ids + [i]}) + local_params.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs - ) - method_kwargs.update( - { - "overlaps": computed["overlaps"], - "templates": computed["templates"], - "norms": computed["norms"], - "sparsities": computed["sparsities"], - } + sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs ) + if method == "circus-omp-svd": + local_params.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "temporal": computed["temporal"], + "spatial": computed["spatial"], + "singular": computed["singular"], + "units_overlaps": computed["units_overlaps"], + "unit_overlaps_indices": computed["unit_overlaps_indices"], + "sparsity_mask": computed["sparsity_mask"], + } + ) + elif method == "circus-omp": + local_params.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "sparsities": computed["sparsities"], + } + ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) if np.sum(valid) > 0: if np.sum(valid) == 1: @@ -649,7 +666,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording + del recording, sub_recording, local_params, waveform_extractor os.remove(tmp_filename) return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py new file mode 100644 index 0000000000..d35b562298 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -0,0 +1,705 @@ +from pathlib import Path +from multiprocessing import get_context +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +import scipy.spatial +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from hdbscan import HDBSCAN + +import numpy as np +import networkx as nx + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + + +from .isocut5 import isocut5 + +from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse + + +def merge_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs, +): + """ + Merge cluster using differents methods. + + Parameters + ---------- + peaks: numpy.ndarray 1d + detected peaks (or a subset) + peak_labels: numpy.ndarray 1d + original label before merge + peak_labels.size == peaks.size + recording: Recording object + A recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method used + method_kwargs: dict + Option for the method. + Returns + ------- + merge_peak_labels: numpy.ndarray 1d + New vectors label after merges. + peak_shifts: numpy.ndarray 1d + A vector of sample shift to be reverse applied on original sample_index on peak detection + Negative shift means too early. + Posituve shift means too late. + So the correction must be applied like this externaly: + final_peaks = peaks.copy() + final_peaks['sample_index'] -= peak_shifts + + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + + features = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set, pair_mask, pair_shift, pair_values = find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=radius_um, + method=method, + method_kwargs=method_kwargs, + **job_kwargs, + ) + + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.matshow(pair_values) + + pair_values[~pair_mask] = 20 + + import hdbscan + + fig, ax = plt.subplots() + clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) + clusterer.fit(pair_values) + print(clusterer.labels_) + clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True) + # ~ fig, ax = plt.subplots() + # ~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + # ~ edge_alpha=0.6, + # ~ node_size=80, + # ~ edge_linewidth=2) + + graph = clusterer.single_linkage_tree_.to_networkx() + + import scipy.cluster + + fig, ax = plt.subplots() + scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) + + import networkx as nx + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + plt.show() + + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + + group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) + + # apply final label and shift + merge_peak_labels = peak_labels.copy() + peak_shifts = np.zeros(peak_labels.size, dtype="int64") + for merge, shifts in zip(merges, group_shifts): + label0 = merge[0] + mask = np.in1d(peak_labels, merge) + merge_peak_labels[mask] = label0 + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + peak_shifts[peak_labels == label1] = shifts[l] + + return merge_peak_labels, peak_shifts + + +def resolve_final_shifts(labels_set, merges, pair_mask, pair_shift): + labels_set = list(labels_set) + + group_shifts = [] + for merge in merges: + shifts = np.zeros(len(merge), dtype="int64") + + label_inds = [labels_set.index(label) for label in merge] + + label0 = merge[0] + ind0 = label_inds[0] + + # First find relative shift to label0 (l=0) in the subgraph + local_pair_mask = pair_mask[label_inds, :][:, label_inds] + local_pair_shift = None + G = None + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + ind1 = label_inds[l] + if local_pair_mask[0, l]: + # easy case the pair label0<>label1 was existing + shift = pair_shift[ind0, ind1] + else: + # more complicated case need to find intermediate label and propagate the shift!! + if G is None: + # the the graph only once and only if needed + G = nx.from_numpy_array(local_pair_mask | local_pair_mask.T) + local_pair_shift = pair_shift[label_inds, :][:, label_inds] + local_pair_shift += local_pair_shift.T + + shift_chain = nx.shortest_path(G, source=l, target=0) + shift = 0 + for i in range(len(shift_chain) - 1): + shift += local_pair_shift[shift_chain[i + 1], shift_chain[i]] + shifts[l] = shift + + group_shifts.append(shifts) + + return group_shifts + + +def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"): + """ + Agglomerate merge pairs into final merge groups. + + The merges are ordered by label. + + """ + + labels_set = np.array(labels_set) + + merges = [] + + graph = nx.from_numpy_array(pair_mask | pair_mask.T) + # put real nodes names for debugging + maps = dict(zip(np.arange(labels_set.size), labels_set)) + graph = nx.relabel_nodes(graph, maps) + + groups = list(nx.connected_components(graph)) + for group in groups: + if len(group) == 1: + continue + sub_graph = graph.subgraph(group) + # print(group, sub_graph) + cliques = list(nx.find_cliques(sub_graph)) + if len(cliques) == 1 and len(cliques[0]) == len(group): + # the sub graph is full connected: no ambiguity + # merges.append(labels_set[cliques[0]]) + merges.append(cliques[0]) + elif len(cliques) > 1: + # the subgraph is not fully connected + if connection_mode == "full": + # node merge + pass + elif connection_mode == "partial": + group = list(group) + # merges.append(labels_set[group]) + merges.append(group) + elif connection_mode == "clique": + raise NotImplementedError + else: + raise ValueError + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(sub_graph) + plt.show() + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + # ensure ordered label + merges = [np.sort(merge) for merge in merges] + + return merges + + +def find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=70, + method="project_distribution", + method_kwargs={}, + **job_kwargs + # n_jobs=1, + # mp_context="fork", + # max_threads_per_process=1, + # progress_bar=True, +): + """ + Searh some possible merge 2 by 2. + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + # features_dict_or_folder = Path(features_dict_or_folder) + + # peaks = features_dict_or_folder['peaks'] + total_channels = recording.get_num_channels() + + # sparse_wfs = features['sparse_wfs'] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + pair_mask = np.triu(np.ones((n, n), dtype="bool")) & ~np.eye(n, dtype="bool") + pair_shift = np.zeros((n, n), dtype="int64") + pair_values = np.zeros((n, n), dtype="float64") + + # compute template (no shift at this step) + + templates = compute_template_from_sparse( + peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels, peak_shifts=None + ) + + max_chans = np.argmax(np.max(np.abs(templates), axis=1), axis=1) + + channel_locs = recording.get_channel_locations() + template_locs = channel_locs[max_chans, :] + template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean") + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + progress_bar = job_kwargs["progress_bar"] + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=find_pair_worker_init, + mp_context=get_context(mp_context), + initargs=( + recording, + features_dict_or_folder, + peak_labels, + labels_set, + templates, + method, + method_kwargs, + max_threads_per_process, + ), + ) as pool: + jobs = [] + for ind0, ind1 in zip(indices0, indices1): + label0 = labels_set[ind0] + label1 = labels_set[ind1] + jobs.append(pool.submit(find_pair_function_wrapper, label0, label1)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"find_merge_pairs with {method}", total=len(jobs)) + else: + iterator = jobs + + for res in iterator: + is_merge, label0, label1, shift, merge_value = res.result() + ind0 = labels_set.index(label0) + ind1 = labels_set.index(label1) + + pair_mask[ind0, ind1] = is_merge + if is_merge: + pair_shift[ind0, ind1] = shift + pair_values[ind0, ind1] = merge_value + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + return labels_set, pair_mask, pair_shift, pair_values + + +def find_pair_worker_init( + recording, + features_dict_or_folder, + original_labels, + labels_set, + templates, + method, + method_kwargs, + max_threads_per_process, +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + _ctx["original_labels"] = original_labels + _ctx["labels_set"] = labels_set + _ctx["templates"] = templates + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = find_pair_method_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + + # if isinstance(features_dict_or_folder, dict): + # _ctx["features"] = features_dict_or_folder + # else: + # _ctx["features"] = FeaturesLoader(features_dict_or_folder) + + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def find_pair_function_wrapper(label0, label1): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( + label0, + label1, + _ctx["labels_set"], + _ctx["templates"], + _ctx["original_labels"], + _ctx["peaks"], + _ctx["features"], + **_ctx["method_kwargs"], + ) + + return is_merge, label0, label1, shift, merge_value + + +class ProjectDistribution: + """ + This method is a refactorized mix between: + * old tridesclous code + * some ideas by Charlie Windolf in spikespvae + + The idea is : + * project the waveform (or features) samples on a 1d axis (using LDA for instance). + * check that it is the same or not distribution (diptest, distrib_overlap, ...) + + + """ + + name = "project_distribution" + + @staticmethod + def merge( + label0, + label1, + labels_set, + templates, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + feature_name="sparse_tsvd", + projection="centroid", + criteria="diptest", + threshold_diptest=0.5, + threshold_percentile=80.0, + threshold_overlap=0.4, + min_cluster_size=50, + num_shift=2, + ): + if num_shift > 0: + assert feature_name == "sparse_wfs" + sparse_wfs = features[feature_name] + + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + if inds0.size < min_cluster_size or inds1.size < min_cluster_size: + is_merge = False + merge_value = 0 + final_shift = 0 + return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) + + inds = np.concatenate([inds0, inds1]) + labels = np.zeros(inds.size, dtype="int") + labels[inds0.size :] = 1 + wfs, out = aggregate_sparse_features(peaks, inds, sparse_wfs, waveforms_sparse_mask, target_chans) + wfs = wfs[~out] + labels = labels[~out] + + cut = np.searchsorted(labels, 1) + wfs0_ = wfs[:cut, :, :] + wfs1_ = wfs[cut:, :, :] + + template0_ = np.mean(wfs0_, axis=0) + template1_ = np.mean(wfs1_, axis=0) + num_samples = template0_.shape[0] + + template0 = template0_[num_shift : num_samples - num_shift, :] + + wfs0 = wfs0_[:, num_shift : num_samples - num_shift, :] + + # best shift strategy 1 = max cosine + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # norm = np.linalg.norm(template0.flatten()) * np.linalg.norm(template1.flatten()) + # value = np.sum(template0.flatten() * template1.flatten()) / norm + # values.append(value) + # best_shift = np.argmax(values) + + # best shift strategy 2 = min dist**2 + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # value = np.sum((template1 - template0)**2) + # values.append(value) + # best_shift = np.argmin(values) + + # best shift strategy 3 : average delta argmin between channels + channel_shift = np.argmax(np.abs(template1_), axis=0) - np.argmax(np.abs(template0_), axis=0) + mask = np.abs(channel_shift) <= num_shift + channel_shift = channel_shift[mask] + if channel_shift.size > 0: + best_shift = int(np.round(np.mean(channel_shift))) + num_shift + else: + best_shift = num_shift + + wfs1 = wfs1_[:, best_shift : best_shift + template0.shape[0], :] + template1 = template1_[best_shift : best_shift + template0.shape[0], :] + + if projection == "lda": + wfs_0_1 = np.concatenate([wfs0, wfs1], axis=0) + flat_wfs = wfs_0_1.reshape(wfs_0_1.shape[0], -1) + feat = LinearDiscriminantAnalysis(n_components=1).fit_transform(flat_wfs, labels) + feat = feat[:, 0] + feat0 = feat[:cut] + feat1 = feat[cut:] + + elif projection == "centroid": + vector_0_1 = template1 - template0 + vector_0_1 /= np.sum(vector_0_1**2) + feat0 = np.sum((wfs0 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat1 = np.sum((wfs1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + # feat = np.sum((wfs_0_1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat = np.concatenate([feat0, feat1], axis=0) + + else: + raise ValueError(f"bad projection {projection}") + + if criteria == "diptest": + dipscore, cutpoint = isocut5(feat) + is_merge = dipscore < threshold_diptest + merge_value = dipscore + elif criteria == "percentile": + l0 = np.percentile(feat0, threshold_percentile) + l1 = np.percentile(feat1, 100.0 - threshold_percentile) + is_merge = l0 >= l1 + merge_value = l0 - l1 + elif criteria == "distrib_overlap": + lim0 = min(np.min(feat0), np.min(feat1)) + lim1 = max(np.max(feat0), np.max(feat1)) + bin_size = (lim1 - lim0) / 200.0 + bins = np.arange(lim0, lim1, bin_size) + + pdf0, _ = np.histogram(feat0, bins=bins, density=True) + pdf1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 *= bin_size + pdf1 *= bin_size + overlap = np.sum(np.minimum(pdf0, pdf1)) + + is_merge = overlap >= threshold_overlap + + merge_value = 1 - overlap + + else: + raise ValueError(f"bad criteria {criteria}") + + if is_merge: + final_shift = best_shift - num_shift + else: + final_shift = 0 + + # DEBUG = True + DEBUG = False + + # if DEBUG and is_merge: + # if DEBUG and (overlap > 0.1 and overlap <0.3): + if DEBUG: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: + import matplotlib.pyplot as plt + + flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) + flatten_wfs1 = wfs1.swapaxes(1, 2).reshape(wfs1.shape[0], -1) + + fig, axs = plt.subplots(ncols=2) + ax = axs[0] + ax.plot(flatten_wfs0.T, color="C0", alpha=0.01) + ax.plot(flatten_wfs1.T, color="C1", alpha=0.01) + m0 = np.mean(flatten_wfs0, axis=0) + m1 = np.mean(flatten_wfs1, axis=0) + ax.plot(m0, color="C0", alpha=1, lw=4, label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", alpha=1, lw=4, label=f"{label1} {inds1.size}") + + ax.legend() + + bins = np.linspace(np.percentile(feat, 1), np.percentile(feat, 99), 100) + bin_size = bins[1] - bins[0] + count0, _ = np.histogram(feat0, bins=bins, density=True) + count1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 = count0 * bin_size + pdf1 = count1 * bin_size + + ax = axs[1] + ax.plot(bins[:-1], pdf0, color="C0") + ax.plot(bins[:-1], pdf1, color="C1") + + if criteria == "diptest": + ax.set_title(f"{dipscore:.4f} {is_merge}") + elif criteria == "percentile": + ax.set_title(f"{l0:.4f} {l1:.4f} {is_merge}") + ax.axvline(l0, color="C0") + ax.axvline(l1, color="C1") + elif criteria == "distrib_overlap": + print( + lim0, + lim1, + ) + ax.set_title(f"{overlap:.4f} {is_merge}") + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls="--", color="k") + + plt.show() + + return is_merge, label0, label1, final_shift, merge_value + + +class NormalizedTemplateDiff: + """ + Compute the normalized (some kind of) template differences. + And merge if below a threhold. + Do this at several shift. + + """ + + name = "normalized_template_diff" + + @staticmethod + def merge( + label0, + label1, + labels_set, + templates, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + threshold_diff=0.05, + min_cluster_size=50, + num_shift=5, + ): + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + # if inds0.size < min_cluster_size or inds1.size < min_cluster_size: + # is_merge = False + # merge_value = 0 + # final_shift = 0 + # return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) + union_chans = np.union1d(target_chans0, target_chans1) + + ind0 = list(labels_set).index(label0) + template0 = templates[ind0][:, target_chans] + + ind1 = list(labels_set).index(label1) + template1 = templates[ind1][:, target_chans] + + num_samples = template0.shape[0] + # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) + norm = np.mean(np.abs(template0) + np.abs(template1)) + all_shift_diff = [] + for shift in range(-num_shift, num_shift + 1): + temp0 = template0[num_shift : num_samples - num_shift, :] + temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] + d = np.mean(np.abs(temp0 - temp1)) / (norm) + all_shift_diff.append(d) + normed_diff = np.min(all_shift_diff) + + is_merge = normed_diff < threshold_diff + if is_merge: + merge_value = normed_diff + final_shift = np.argmin(all_shift_diff) - num_shift + else: + final_shift = 0 + merge_value = np.nan + + # DEBUG = False + DEBUG = True + if DEBUG and normed_diff < 0.2: + # if DEBUG: + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + + m0 = template0.flatten() + m1 = template1.flatten() + + ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") + + ax.set_title( + f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}" + ) + ax.legend() + plt.show() + + return is_merge, label0, label1, final_shift, merge_value + + +find_pair_method_list = [ + ProjectDistribution, + NormalizedTemplateDiff, +] +find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index fcbcac097f..72acd49f4f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -18,7 +18,14 @@ from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms -from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature +from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser +from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) class RandomProjectionClustering: @@ -34,18 +41,18 @@ class RandomProjectionClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, + "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, - "max_spikes_per_unit": 200, "selection_method": "closest_to_centroid", - "nb_projections": {"ptp": 8, "energy": 2}, - "ms_before": 1.5, - "ms_after": 1.5, + "nb_projections": 10, + "ms_before": 1, + "ms_after": 1, "random_seed": 42, - "cleaning_method": "matching", - "shared_memory": False, - "min_values": {"ptp": 0, "energy": 0}, + "noise_levels": None, + "smoothing_kwargs": {"window_length_ms": 0.25}, + "shared_memory": True, "tmp_folder": None, - "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @classmethod @@ -71,60 +78,93 @@ def main_function(cls, recording, peaks, params): num_samples = nbefore + nafter num_chans = recording.get_num_channels() - noise_levels = get_noise_levels(recording, return_scaled=False) + if d["noise_levels"] is None: + noise_levels = get_noise_levels(recording, return_scaled=False) + else: + noise_levels = d["noise_levels"] np.random.seed(d["random_seed"]) - features_params = {} - features_list = [] - - noise_snippets = None - - for proj_type in ["ptp", "energy"]: - if d["nb_projections"][proj_type] > 0: - features_list += [f"random_projections_{proj_type}"] - - if d["min_values"][proj_type] == "auto": - if noise_snippets is None: - num_segments = recording.get_num_segments() - num_chunks = 3 * d["max_spikes_per_unit"] // num_segments - noise_snippets = get_random_data_chunks( - recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42 - ) - noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans) - - if proj_type == "energy": - data = np.linalg.norm(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif proj_type == "ptp": - data = np.ptp(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif d["min_values"][proj_type] > 0: - min_values = d["min_values"][proj_type] - else: - min_values = None - - projections = np.random.randn(num_chans, d["nb_projections"][proj_type]) - features_params[f"random_projections_{proj_type}"] = { - "radius_um": params["radius_um"], - "projections": projections, - "min_values": min_values, - } - - features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]).absolute() + + tmp_folder.mkdir(parents=True, exist_ok=True) + + node0 = PeakRetriever(recording, peaks) + node1 = ExtractSparseWaveforms( + recording, + parents=[node0], + return_output=False, + ms_before=params["ms_before"], + ms_after=params["ms_after"], + radius_um=params["radius_um"], ) - if len(features_data) > 1: - hdbscan_data = np.hstack((features_data[0], features_data[1])) - else: - hdbscan_data = features_data[0] + node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) + + projections = np.random.randn(num_chans, d["nb_projections"]) + projections -= projections.mean(0) + projections /= projections.std(0) + + nbefore = int(params["ms_before"] * fs / 1000) + nafter = int(params["ms_after"] * fs / 1000) + nsamples = nbefore + nafter + + import scipy + + x = np.random.randn(100, nsamples, num_chans).astype(np.float32) + x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) + + ptps = np.ptp(x, axis=1) + a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) + ydata = np.cumsum(a) / a.sum() + xdata = b[1:] + + from scipy.optimize import curve_fit + + def sigmoid(x, L, x0, k, b): + y = L / (1 + np.exp(-k * (x - x0))) + b + return y + + p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess + popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) + + node3 = RandomProjectionsFeature( + recording, + parents=[node0, node2], + return_output=True, + projections=projections, + radius_um=params["radius_um"], + sigmoid=None, + sparse=True, + ) + + pipeline_nodes = [node0, node1, node2, node3] + + hdbscan_data = run_node_pipeline( + recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features" + ) import sklearn clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) peak_labels = clustering[0] + # peak_labels = -1 * np.ones(len(peaks), dtype=int) + # nb_clusters = 0 + # for c in np.unique(peaks['channel_index']): + # mask = peaks['channel_index'] == c + # clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs']) + # local_labels = clustering[0] + # valid_clusters = local_labels > -1 + # if np.sum(valid_clusters) > 0: + # local_labels[valid_clusters] += nb_clusters + # peak_labels[mask] = local_labels + # nb_clusters += len(np.unique(local_labels[valid_clusters])) + labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -133,7 +173,7 @@ def main_function(cls, recording, peaks, params): all_indices = np.arange(0, peak_labels.size) - max_spikes = params["max_spikes_per_unit"] + max_spikes = params["waveforms"]["max_spikes_per_unit"] selection_method = params["selection_method"] for unit_ind in labels: @@ -160,84 +200,53 @@ def main_function(cls, recording, peaks, params): spikes["segment_index"] = peaks[mask]["segment_index"] spikes["unit_index"] = peak_labels[mask] - cleaning_method = params["cleaning_method"] - if verbose: - print("We found %d raw clusters, starting to clean with %s..." % (len(labels), cleaning_method)) - - if cleaning_method == "cosine": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=True, - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates( - wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] - ) - - elif cleaning_method == "dip": - wfs_arrays = {} - for label in labels: - mask = label == peak_labels - wfs_arrays[label] = hdbscan_data[mask] - - labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels, **params["cleaning_kwargs"]) - - elif cleaning_method == "matching": - # create a tmp folder - if params["tmp_folder"] is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = get_global_tmp_folder() / name - else: - tmp_folder = Path(params["tmp_folder"]) - - if params["shared_memory"]: - waveform_folder = None - mode = "memory" - else: - waveform_folder = tmp_folder / "waveforms" - mode = "folder" - - sorting_folder = tmp_folder / "sorting" - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) + print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + + if params["shared_memory"]: + waveform_folder = None + mode = "memory" + else: + waveform_folder = tmp_folder / "waveforms" + mode = "folder" sorting = sorting.save(folder=sorting_folder) - we = extract_waveforms( - recording, - sorting, - waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - return_scaled=False, - mode=mode, - ) - - cleaning_matching_params = params["job_kwargs"].copy() - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["verbose"] = False - cleaning_matching_params["progress_bar"] = False - - cleaning_params = params["cleaning_kwargs"].copy() - cleaning_params["tmp_folder"] = tmp_folder - - labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params - ) - - if params["tmp_folder"] is None: - shutil.rmtree(tmp_folder) - else: + + we = extract_waveforms( + recording, + sorting, + waveform_folder, + **params["job_kwargs"], + **params["waveforms"], + return_scaled=False, + mode=mode, + ) + + cleaning_matching_params = params["job_kwargs"].copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in cleaning_matching_params: + cleaning_matching_params.pop(value) + cleaning_matching_params["chunk_duration"] = "100ms" + cleaning_matching_params["n_jobs"] = 1 + cleaning_matching_params["verbose"] = False + cleaning_matching_params["progress_bar"] = False + + cleaning_params = params["cleaning_kwargs"].copy() + cleaning_params["tmp_folder"] = tmp_folder + + labels, peak_labels = remove_duplicates_via_matching( + we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + ) + + del we, sorting + + if params["tmp_folder"] is None: + shutil.rmtree(tmp_folder) + else: + if not params["shared_memory"]: shutil.rmtree(tmp_folder / "waveforms") shutil.rmtree(tmp_folder / "sorting") diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index aeec14158f..08ce9f6791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(np.isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(np.in1d(local_chan_inds, wf_chans)) + assert np.all(np.isin(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(np.in1d(wanted_chans, wf_chans)) + assert np.all(np.isin(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py new file mode 100644 index 0000000000..a31e7d62fc --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -0,0 +1,280 @@ +from multiprocessing import get_context +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +from sklearn.decomposition import TruncatedSVD +from hdbscan import HDBSCAN + +import numpy as np + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + +from .tools import aggregate_sparse_features, FeaturesLoader +from .isocut5 import isocut5 + + +# important all DEBUG and matplotlib are left in the code intentionally + + +def split_clusters( + peak_labels, + recording, + features_dict_or_folder, + method="hdbscan_on_local_pca", + method_kwargs={}, + recursive=False, + recursive_depth=None, + returns_split_count=False, + **job_kwargs, +): + """ + Run recusrsively (or not) in a multi process pool a local split method. + + Parameters + ---------- + peak_labels: numpy.array + Peak label before split + recording: Recording + Recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method name + method_kwargs: dict + The method option + recursive: bool Default False + Reccursive or not. + recursive_depth: None or int + If recursive=True, then this is the max split per spikes. + returns_split_count: bool + Optionally return the split count vector. Same size as labels. + + Returns + ------- + new_labels: numpy.ndarray + The labels of peaks after split. + split_count: numpy.ndarray + Optionally returned + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs.get("mp_context", None) + progress_bar = job_kwargs["progress_bar"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + + original_labels = peak_labels + peak_labels = peak_labels.copy() + split_count = np.zeros(peak_labels.size, dtype=int) + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=split_worker_init, + mp_context=get_context(method=mp_context), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + labels_set = np.setdiff1d(peak_labels, [-1]) + current_max_label = np.max(labels_set) + 1 + + jobs = [] + for label in labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) + else: + iterator = jobs + + for res in iterator: + is_split, local_labels, peak_indices = res.result() + if not is_split: + continue + + mask = local_labels >= 0 + peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label + peak_labels[peak_indices[~mask]] = local_labels[~mask] + + split_count[peak_indices] += 1 + + current_max_label += np.max(local_labels[mask]) + 1 + + if recursive: + if recursive_depth is not None: + # stop reccursivity when recursive_depth is reach + extra_ball = np.max(split_count[peak_indices]) < recursive_depth + else: + # reccurssive always + extra_ball = True + + if extra_ball: + new_labels_set = np.setdiff1d(peak_labels[peak_indices], [-1]) + for label in new_labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + if progress_bar: + iterator.total += 1 + + if returns_split_count: + return peak_labels, split_count + else: + return peak_labels + + +global _ctx + + +def split_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + features_dict_or_folder + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = split_methods_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def split_function_wrapper(peak_indices): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_split, local_labels = _ctx["method_class"].split( + peak_indices, _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_split, local_labels, peak_indices + + +class LocalFeatureClustering: + """ + This method is a refactorized mix between: + * old tridesclous code + * "herding_split()" in DART/spikepsvae by Charlie Windolf + + The idea simple : + * agregate features (svd or even waveforms) with sparse channel. + * run a local feature reduction (pca or svd) + * try a new split (hdscan or isocut5) + """ + + name = "local_feature_clustering" + + @staticmethod + def split( + peak_indices, + peaks, + features, + clusterer="hdbscan", + feature_name="sparse_tsvd", + neighbours_mask=None, + waveforms_sparse_mask=None, + min_size_split=25, + min_cluster_size=25, + min_samples=25, + n_pca_features=2, + minimum_common_channels=2, + ): + local_labels = np.zeros(peak_indices.size, dtype=np.int64) + + # can be sparse_tsvd or sparse_wfs + sparse_features = features[feature_name] + + assert waveforms_sparse_mask is not None + + # target channel subset is done intersect local channels + neighbours + local_chans = np.unique(peaks["channel_index"][peak_indices]) + + target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + + # TODO fix this a better way, this when cluster have too few overlapping channels + if target_channels.size < minimum_common_channels: + return False, None + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels + ) + + local_labels[dont_have_channels] = -2 + kept = np.flatnonzero(~dont_have_channels) + + if kept.size < min_size_split: + return False, None + + aligned_wfs = aligned_wfs[kept, :, :] + + flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) + + # final_features = PCA(n_pca_features, whiten=True).fit_transform(flatten_features) + # final_features = PCA(n_pca_features, whiten=False).fit_transform(flatten_features) + final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) + + if clusterer == "hdbscan": + clust = HDBSCAN( + min_cluster_size=min_cluster_size, + min_samples=min_samples, + allow_single_cluster=True, + cluster_selection_method="leaf", + ) + clust.fit(final_features) + possible_labels = clust.labels_ + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + elif clusterer == "isocut5": + dipscore, cutpoint = isocut5(final_features[:, 0]) + possible_labels = np.zeros(final_features.shape[0]) + if dipscore > 1.5: + mask = final_features[:, 0] > cutpoint + if np.sum(mask) > min_cluster_size and np.sum(~mask): + possible_labels[mask] = 1 + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + else: + is_split = False + else: + raise ValueError(f"wrong clusterer {clusterer}") + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + labels_set = np.setdiff1d(possible_labels, [-1]) + colors = plt.get_cmap("tab10", len(labels_set)) + colors = {k: colors(i) for i, k in enumerate(labels_set)} + colors[-1] = "k" + fix, axs = plt.subplots(nrows=2) + + flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1) + + sl = slice(None, None, 10) + for k in np.unique(possible_labels): + mask = possible_labels == k + ax = axs[0] + ax.scatter(final_features[:, 0][mask][sl], final_features[:, 1][mask][sl], s=5, color=colors[k]) + + ax = axs[1] + ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) + + axs[0].set_title(f"{clusterer} {is_split}") + + plt.show() + + if not is_split: + return is_split, None + + local_labels[kept] = possible_labels + + return is_split, local_labels + + +split_methods_list = [ + LocalFeatureClustering, +] +split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py new file mode 100644 index 0000000000..8e25c9cb7f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -0,0 +1,196 @@ +from pathlib import Path +from typing import Any +import numpy as np + + +# TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....) + + +class FeaturesLoader: + """ + Feature can be computed in memory or in a folder contaning npy files. + + This class read the folder and behave like a dict of array lazily. + + Parameters + ---------- + feature_folder + + preload + + """ + + def __init__(self, feature_folder, preload=["peaks"]): + self.feature_folder = Path(feature_folder) + + self.file_feature = {} + self.loaded_features = {} + for file in self.feature_folder.glob("*.npy"): + name = file.stem + if name in preload: + self.loaded_features[name] = np.load(file) + else: + self.file_feature[name] = file + + def __getitem__(self, name): + if name in self.loaded_features: + return self.loaded_features[name] + else: + return np.load(self.file_feature[name], mmap_mode="r") + + @staticmethod + def from_dict_or_folder(features_dict_or_folder): + if isinstance(features_dict_or_folder, dict): + return features_dict_or_folder + else: + return FeaturesLoader(features_dict_or_folder) + + +def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, target_channels): + """ + Aggregate sparse features that have unaligned channels and realigned then on target_channels. + + This is usefull to aligned back peaks waveform or pca or tsvd when detected a differents channels. + + + Parameters + ---------- + peaks + + peak_indices + + sparse_feature + + sparse_mask + + target_channels + + Returns + ------- + aligned_features: numpy.array + Aligned features. shape is (local_peaks.size, sparse_feature.shape[1], target_channels.size) + dont_have_channels: numpy.array + Boolean vector to indicate spikes that do not have all target channels to be taken in account + shape is peak_indices.size + """ + local_peaks = peaks[peak_indices] + + aligned_features = np.zeros( + (local_peaks.size, sparse_feature.shape[1], target_channels.size), dtype=sparse_feature.dtype + ) + dont_have_channels = np.zeros(peak_indices.size, dtype=bool) + + for chan in np.unique(local_peaks["channel_index"]): + sparse_chans = np.flatnonzero(sparse_mask[chan, :]) + peak_inds = np.flatnonzero(local_peaks["channel_index"] == chan) + if np.all(np.isin(target_channels, sparse_chans)): + # peaks feature channel have all target_channels + source_chans = np.flatnonzero(np.in1d(sparse_chans, target_channels)) + aligned_features[peak_inds, :, :] = sparse_feature[peak_indices[peak_inds], :, :][:, :, source_chans] + else: + # some channel are missing, peak are not removde + dont_have_channels[peak_inds] = True + + return aligned_features, dont_have_channels + + +def compute_template_from_sparse( + peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None +): + """ + Compute template average from single sparse waveforms buffer. + + Parameters + ---------- + peaks + + labels + + labels_set + + sparse_waveforms + + sparse_mask + + total_channels + + peak_shifts + + Returns + ------- + templates: numpy.array + Templates shape : (len(labels_set), num_samples, total_channels) + """ + n = len(labels_set) + + templates = np.zeros((n, sparse_waveforms.shape[1], total_channels), dtype=sparse_waveforms.dtype) + + for i, label in enumerate(labels_set): + peak_indices = np.flatnonzero(labels == label) + + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(sparse_mask[local_chans, :], axis=0)) + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_waveforms, sparse_mask, target_channels + ) + + if peak_shifts is not None: + apply_waveforms_shift(aligned_wfs, peak_shifts[peak_indices], inplace=True) + + templates[i, :, :][:, target_channels] = np.mean(aligned_wfs[~dont_have_channels], axis=0) + + return templates + + +def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): + """ + Apply a shift a spike level to realign waveforms buffers. + + This is usefull to compute template after merge when to cluster are shifted. + + A negative shift need the waveforms to be moved toward the right because the trough was too early. + A positive shift need the waveforms to be moved toward the left because the trough was too late. + + Note the border sample are left as before without move. + + Parameters + ---------- + + waveforms + + peak_shifts + + inplace + + Returns + ------- + aligned_waveforms + + + """ + + print("apply_waveforms_shift") + + if inplace: + aligned_waveforms = waveforms + else: + aligned_waveforms = waveforms.copy() + + shift_set = np.unique(peak_shifts) + assert max(np.abs(shift_set)) < aligned_waveforms.shape[1] + + for shift in shift_set: + if shift == 0: + continue + mask = peak_shifts == shift + wfs = waveforms[mask] + + if shift > 0: + aligned_waveforms[mask, :-shift, :] = wfs[:, shift:, :] + else: + aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] + + print("apply_waveforms_shift DONE") + + return aligned_waveforms diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index bd82ffa0a6..06d22181cb 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -184,41 +184,49 @@ def __init__( return_output=True, parents=None, projections=None, - radius_um=150.0, - min_values=None, + sigmoid=None, + radius_um=None, + sparse=True, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.projections = projections - self.radius_um = radius_um - self.min_values = min_values - + self.sigmoid = sigmoid self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance < radius_um - - self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values)) - + self.radius_um = radius_um + self.sparse = sparse + self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse)) self._dtype = recording.get_dtype() def get_dtype(self): return self._dtype + def _sigmoid(self, x): + L, x0, k, b = self.sigmoid + y = L / (1 + np.exp(-k * (x - x0))) + b + return y + def compute(self, traces, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) + for main_chan in np.unique(peaks["channel_index"]): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] - wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1) + if self.sparse: + wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1) + else: + wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) - if self.min_values is not None: - wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4 + if self.sigmoid is not None: + wf_ptp *= self._sigmoid(wf_ptp) denom = np.sum(wf_ptp, axis=1) mask = denom != 0 - all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis]) + return all_projections diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 2196320378..ea36b75847 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -5,7 +5,6 @@ import scipy.spatial -from tqdm import tqdm import scipy try: @@ -16,7 +15,8 @@ except ImportError: HAVE_SKLEARN = False -from spikeinterface.core import get_noise_levels, get_random_data_chunks + +from spikeinterface.core import get_noise_levels, get_random_data_chunks, compute_sparsity from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) @@ -33,9 +33,6 @@ from .main import BaseTemplateMatchingEngine -################# -# Circus peeler # -################# from scipy.fft._helper import _init_nd_shape_and_axes @@ -131,6 +128,38 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret +def compute_overlaps(templates, num_samples, num_channels, sparsities): + num_templates = len(templates) + + dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) + for i in range(num_templates): + dense_templates[i, :, sparsities[i]] = templates[i].T + + size = 2 * num_samples - 1 + + all_delays = list(range(0, num_samples + 1)) + + overlaps = {} + + for delay in all_delays: + source = dense_templates[:, :delay, :].reshape(num_templates, -1) + target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) + + overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) + + if delay < num_samples: + overlaps[size - delay + 1] = overlaps[delay].T.tocsr() + + new_overlaps = [] + + for i in range(num_templates): + data = [overlaps[j][i, :].T for j in range(size)] + data = scipy.sparse.hstack(data) + new_overlaps += [data] + + return new_overlaps + + class CircusOMPPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -152,21 +181,18 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float Stopping criteria of the OMP algorithm, in percentage of the norm - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain. ptp limit for considering a channel as silent - smoothing_factor: float - Templates are smoothed via Spline Interpolation noise_levels: array The noise levels, for every channels. If None, they will be automatically computed random_chunk_kwargs: dict Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. ----- """ _default_params = { - "sparsify_threshold": 1, "amplitudes": [0.6, 2], "omp_min_sps": 0.1, "waveform_extractor": None, @@ -175,36 +201,21 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, - "smoothing_factor": 0.25, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], + "vicinity": 0, } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold): - is_silent = template.ptp(0) < sparsify_threshold - template[:, is_silent] = 0 - (active_channels,) = np.where(np.logical_not(is_silent)) - - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): waveform_extractor = d["waveform_extractor"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] num_templates = len(d["waveform_extractor"].sorting.unit_ids) + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + templates = waveform_extractor.get_all_templates(mode="median").copy() d["sparsities"] = {} @@ -212,52 +223,10 @@ def _prepare_templates(cls, d): d["norms"] = np.zeros(num_templates, dtype=np.float32) for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - if d["smoothing_factor"] > 0: - template = cls._regularize_template(templates[count], d["smoothing_factor"]) - else: - template = templates[count] - template, active_channels = cls._sparsify_template(template, d["sparsify_threshold"]) - d["sparsities"][count] = active_channels + template = templates[count][:, sparsity[count]] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) d["norms"][count] = np.linalg.norm(template) - d["templates"][count] = template[:, active_channels] / d["norms"][count] - - return d - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - sparsities = d["sparsities"] - - dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) - for i in range(num_templates): - dense_templates[i, :, sparsities[i]] = templates[i].T - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay + 1] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + d["templates"][count] = template / d["norms"][count] return d @@ -276,6 +245,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["nbefore"] = d["waveform_extractor"].nbefore d["nafter"] = d["waveform_extractor"].nafter d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] if d["noise_levels"] is None: print("CircusOMPPeeler : noise should be computed outside") @@ -290,15 +260,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["num_templates"] = len(d["templates"]) if d["overlaps"] is None: - d = cls._prepare_overlaps(d) + d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) d["ignored_ids"] = np.array(d["ignored_ids"]) omp_min_sps = d["omp_min_sps"] - norms = d["norms"] - sparsities = d["sparsities"] - - nb_active_channels = np.array([len(sparsities[i]) for i in range(d["num_templates"])]) + # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) return d @@ -336,6 +303,7 @@ def main_function(cls, traces, d): sparsities = d["sparsities"] ignored_ids = d["ignored_ids"] stop_criteria = d["stop_criteria"] + vicinity = d["vicinity"] if "cached_fft_kernels" not in d: d["cached_fft_kernels"] = {"fshape": 0} @@ -381,7 +349,7 @@ def main_function(cls, traces, d): spikes = np.empty(scalar_products.size, dtype=spike_dtype) idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - M = np.zeros((num_peaks, num_peaks), dtype=np.float32) + M = np.zeros((100, 100), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -393,6 +361,8 @@ def main_function(cls, traces, d): cached_overlaps = {} is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) while np.any(is_valid): best_amplitude_ind = scalar_products[is_valid].argmax() @@ -412,20 +382,39 @@ def main_function(cls, traces, d): M = Z M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 else: M[0, 0] = 1 @@ -435,9 +424,16 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - - all_amplitudes /= norms[selection[0]] + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] @@ -479,6 +475,367 @@ def main_function(cls, traces, d): return spikes +class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): + """ + Orthogonal Matching Pursuit inspired from Spyking Circus sorter + + https://elifesciences.org/articles/34518 + + This is an Orthogonal Template Matching algorithm. For speed and + memory optimization, templates are automatically sparsified. Signal + is convolved with the templates, and as long as some scalar products + are higher than a given threshold, we use a Cholesky decomposition + to compute the optimal amplitudes needed to reconstruct the signal. + + IMPORTANT NOTE: small chunks are more efficient for such Peeler, + consider using 100ms chunk + + Parameters + ---------- + amplitude: tuple + (Minimal, Maximal) amplitudes allowed for every template + omp_min_sps: float + Stopping criteria of the OMP algorithm, in percentage of the norm + noise_levels: array + The noise levels, for every channels. If None, they will be automatically + computed + random_chunk_kwargs: dict + Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. + ----- + """ + + _default_params = { + "amplitudes": [0.6, 2], + "omp_min_sps": 0.1, + "waveform_extractor": None, + "random_chunk_kwargs": {}, + "noise_levels": None, + "rank": 5, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, + "ignored_ids": [], + "vicinity": 0, + } + + @classmethod + def _prepare_templates(cls, d): + waveform_extractor = d["waveform_extractor"] + num_templates = len(d["waveform_extractor"].sorting.unit_ids) + + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + + d["sparsity_mask"] = sparsity + units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) + d["units_overlaps"] = units_overlaps > 0 + d["unit_overlaps_indices"] = {} + for i in range(num_templates): + (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) + + templates = waveform_extractor.get_all_templates(mode="median").copy() + + # First, we set masked channels to 0 + for count in range(num_templates): + templates[count][:, ~d["sparsity_mask"][count]] = 0 + + # Then we keep only the strongest components + rank = d["rank"] + temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) + d["temporal"] = temporal[:, :, :rank] + d["singular"] = singular[:, :rank] + d["spatial"] = spatial[:, :rank, :] + + # We reconstruct the approximated templates + templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) + + d["templates"] = {} + d["norms"] = np.zeros(num_templates, dtype=np.float32) + + # And get the norms, saving compressed templates for CC matrix + for count in range(num_templates): + template = templates[count][:, d["sparsity_mask"][count]] + d["norms"][count] = np.linalg.norm(template) + d["templates"][count] = template / d["norms"][count] + + d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] + d["temporal"] = np.flip(d["temporal"], axis=1) + + d["overlaps"] = [] + for i in range(num_templates): + num_overlaps = np.sum(d["units_overlaps"][i]) + overlapping_units = np.where(d["units_overlaps"][i])[0] + + # Reconstruct unit template from SVD Matrices + data = d["temporal"][i] * d["singular"][i][np.newaxis, :] + template_i = np.matmul(data, d["spatial"][i, :, :]) + template_i = np.flipud(template_i) + + unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) + + for count, j in enumerate(overlapping_units): + overlapped_channels = d["sparsity_mask"][j] + visible_i = template_i[:, overlapped_channels] + + spatial_filters = d["spatial"][j, :, overlapped_channels] + spatially_filtered_template = np.matmul(visible_i, spatial_filters) + visible_i = spatially_filtered_template * d["singular"][j] + + for rank in range(visible_i.shape[1]): + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") + + d["overlaps"].append(unit_overlaps) + + d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) + d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) + d["singular"] = d["singular"].T[:, :, np.newaxis] + + return d + + @classmethod + def initialize_and_check_kwargs(cls, recording, kwargs): + d = cls._default_params.copy() + d.update(kwargs) + + # assert isinstance(d['waveform_extractor'], WaveformExtractor) + + for v in ["omp_min_sps"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = d["waveform_extractor"].nsamples + d["nbefore"] = d["waveform_extractor"].nbefore + d["nafter"] = d["waveform_extractor"].nafter + d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] + + if d["noise_levels"] is None: + print("CircusOMPPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + + if "templates" not in d: + d = cls._prepare_templates(d) + else: + for key in [ + "norms", + "temporal", + "spatial", + "singular", + "units_overlaps", + "sparsity_mask", + "unit_overlaps_indices", + ]: + assert d[key] is not None, "If templates are provided, %d should also be there" % key + + d["num_templates"] = len(d["templates"]) + d["ignored_ids"] = np.array(d["ignored_ids"]) + + d["unit_overlaps_tables"] = {} + for i in range(d["num_templates"]): + d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) + d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) + + omp_min_sps = d["omp_min_sps"] + # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) + + return d + + @classmethod + def serialize_method_kwargs(cls, kwargs): + kwargs = dict(kwargs) + # remove waveform_extractor + kwargs.pop("waveform_extractor") + return kwargs + + @classmethod + def unserialize_in_worker(cls, kwargs): + return kwargs + + @classmethod + def get_margin(cls, recording, kwargs): + margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + return margin + + @classmethod + def main_function(cls, traces, d): + templates = d["templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + num_samples = d["num_samples"] + overlaps = d["overlaps"] + norms = d["norms"] + nbefore = d["nbefore"] + nafter = d["nafter"] + omp_tol = np.finfo(np.float32).eps + num_samples = d["nafter"] + d["nbefore"] + neighbor_window = num_samples - 1 + min_amplitude, max_amplitude = d["amplitudes"] + ignored_ids = d["ignored_ids"] + stop_criteria = d["stop_criteria"][:, np.newaxis] + vicinity = d["vicinity"] + rank = d["rank"] + + num_timesteps = len(traces) + + num_peaks = num_timesteps - num_samples + 1 + conv_shape = (num_templates, num_peaks) + scalar_products = np.zeros(conv_shape, dtype=np.float32) + + # Filter using overlap-and-add convolution + if len(ignored_ids) > 0: + mask = ~np.isin(np.arange(num_templates), ignored_ids) + spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] + objective_by_rank = scipy.signal.oaconvolve( + scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid" + ) + scalar_products[mask] += np.sum(objective_by_rank, axis=0) + scalar_products[ignored_ids] = -np.inf + else: + spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"] + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") + scalar_products += np.sum(objective_by_rank, axis=0) + + num_spikes = 0 + + spikes = np.empty(scalar_products.size, dtype=spike_dtype) + idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + + M = np.zeros((num_templates, num_templates), dtype=np.float32) + + all_selections = np.empty((2, scalar_products.size), dtype=np.int32) + final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) + num_selection = 0 + + full_sps = scalar_products.copy() + + neighbors = {} + cached_overlaps = {} + + is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) + + while np.any(is_valid): + best_amplitude_ind = scalar_products[is_valid].argmax() + best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + + if num_selection > 0: + delta_t = selection[1] - peak_index + idx = np.where((delta_t < num_samples) & (delta_t > -num_samples))[0] + myline = neighbor_window + delta_t[idx] + myindices = selection[0, idx] + + local_overlaps = overlaps[best_cluster_ind] + overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] + table = d["unit_overlaps_tables"][best_cluster_ind] + + if num_selection == M.shape[0]: + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z[:num_selection, :num_selection] = M + M = Z + + mask = np.isin(myindices, overlapping_templates) + a, b = myindices[mask], myline[mask] + M[num_selection, idx[mask]] = local_overlaps[table[a], b] + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 + else: + M[0, 0] = 1 + + all_selections[:, num_selection] = [best_cluster_ind, peak_index] + num_selection += 1 + + selection = all_selections[:, :num_selection] + res_sps = full_sps[selection[0], selection[1]] + + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + + diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] + modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] + final_amplitudes[selection[0], selection[1]] = all_amplitudes + + for i in modified: + tmp_best, tmp_peak = selection[:, i] + diff_amp = diff_amplitudes[i] * norms[tmp_best] + + local_overlaps = overlaps[tmp_best] + overlapping_templates = d["units_overlaps"][tmp_best] + + if not tmp_peak in neighbors.keys(): + idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] + tdx = [neighbor_window + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak - 1] + neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + + idx = neighbors[tmp_peak]["idx"] + tdx = neighbors[tmp_peak]["tdx"] + + to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] + scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add + + is_valid = scalar_products > stop_criteria + + is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) + valid_indices = np.where(is_valid) + + num_spikes = len(valid_indices[0]) + spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["channel_index"][:num_spikes] = 0 + spikes["cluster_index"][:num_spikes] = valid_indices[0] + spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + + spikes = spikes[:num_spikes] + order = np.argsort(spikes["sample_index"]) + spikes = spikes[order] + + return spikes + + class CircusPeeler(BaseTemplateMatchingEngine): """ @@ -515,14 +872,12 @@ class CircusPeeler(BaseTemplateMatchingEngine): Maximal amplitude allowed for every template min_amplitude: float Minimal amplitude allowed for every template - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain a given fraction of the total norm use_sparse_matrix_threshold: float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) - progress_bar_steps: bool - In order to display or not steps from the algorithm + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. ----- @@ -535,68 +890,40 @@ class CircusPeeler(BaseTemplateMatchingEngine): "detect_threshold": 5, "noise_levels": None, "random_chunk_kwargs": {}, - "sparsify_threshold": 0.99, "max_amplitude": 1.5, "min_amplitude": 0.5, "use_sparse_matrix_threshold": 0.25, - "progess_bar_steps": False, "waveform_extractor": None, - "smoothing_factor": 0.25, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold, noise_levels): - is_silent = template.std(0) < 0.1 * noise_levels - - template[:, is_silent] = 0 - - channel_norms = np.linalg.norm(template, axis=0) ** 2 - total_norm = np.linalg.norm(template) ** 2 - - idx = np.argsort(channel_norms)[::-1] - explained_norms = np.cumsum(channel_norms[idx] / total_norm) - channel = np.searchsorted(explained_norms, sparsify_threshold) - active_channels = np.sort(idx[:channel]) - template[:, idx[channel:]] = 0 - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): - parameters = d - waveform_extractor = parameters["waveform_extractor"] - num_samples = parameters["num_samples"] - num_channels = parameters["num_channels"] - num_templates = parameters["num_templates"] - max_amplitude = parameters["max_amplitude"] - min_amplitude = parameters["min_amplitude"] - use_sparse_matrix_threshold = parameters["use_sparse_matrix_threshold"] + waveform_extractor = d["waveform_extractor"] + num_samples = d["num_samples"] + num_channels = d["num_channels"] + num_templates = d["num_templates"] + use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] - parameters["norms"] = np.zeros(num_templates, dtype=np.float32) + d["norms"] = np.zeros(num_templates, dtype=np.float32) - all_units = list(parameters["waveform_extractor"].sorting.unit_ids) + all_units = list(d["waveform_extractor"].sorting.unit_ids) + + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask templates = waveform_extractor.get_all_templates(mode="median").copy() + d["sparsities"] = {} + d["circus_templates"] = {} for count, unit_id in enumerate(all_units): - if parameters["smoothing_factor"] > 0: - templates[count] = cls._regularize_template(templates[count], parameters["smoothing_factor"]) - - templates[count], _ = cls._sparsify_template( - templates[count], parameters["sparsify_threshold"], parameters["noise_levels"] - ) - parameters["norms"][count] = np.linalg.norm(templates[count]) - templates[count] /= parameters["norms"][count] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) + templates[count][:, ~sparsity[count]] = 0 + d["norms"][count] = np.linalg.norm(templates[count]) + templates[count] /= d["norms"][count] + d["circus_templates"][count] = templates[count][:, sparsity[count]] templates = templates.reshape(num_templates, -1) @@ -604,54 +931,11 @@ def _prepare_templates(cls, d): if nnz <= use_sparse_matrix_threshold: templates = scipy.sparse.csr_matrix(templates) print(f"Templates are automatically sparsified (sparsity level is {nnz})") - parameters["is_dense"] = False + d["is_dense"] = False else: - parameters["is_dense"] = True + d["is_dense"] = True - parameters["templates"] = templates - - return parameters - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - is_dense = d["is_dense"] - - if not is_dense: - dense_templates = templates.toarray() - else: - dense_templates = templates - - dense_templates = dense_templates.reshape(num_templates, num_samples, num_channels) - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - if d["progess_bar_steps"]: - all_delays = tqdm(all_delays, desc="[1] compute overlaps") - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + d["templates"] = templates return d @@ -687,15 +971,13 @@ def _optimize_amplitudes(cls, noise_snippets, d): alpha = 0.5 norms = parameters["norms"] all_units = list(waveform_extractor.sorting.unit_ids) - if parameters["progess_bar_steps"]: - all_units = tqdm(all_units, desc="[2] compute amplitudes") parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) noise = templates.dot(noise_snippets) / norms[:, np.newaxis] all_amps = {} for count, unit_id in enumerate(all_units): - waveform = waveform_extractor.get_waveforms(unit_id) + waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) snippets = waveform.reshape(waveform.shape[0], -1).T amps = templates.dot(snippets) / norms[:, np.newaxis] good = amps[count, :].flatten() @@ -708,16 +990,6 @@ def _optimize_amplitudes(cls, noise_snippets, d): res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) parameters["amplitudes"][count] = res.x - # import pylab as plt - # plt.hist(good, 100, alpha=0.5) - # plt.hist(bad, 100, alpha=0.5) - # plt.hist(noise[count], 100, alpha=0.5) - # ymin, ymax = plt.ylim() - # plt.plot([res.x[0], res.x[0]], [ymin, ymax], 'k--') - # plt.plot([res.x[1], res.x[1]], [ymin, ymax], 'k--') - # plt.savefig('test_%d.png' %count) - # plt.close() - return d @classmethod @@ -727,8 +999,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters.update(kwargs) # assert isinstance(d['waveform_extractor'], WaveformExtractor) - - for v in ["sparsify_threshold", "use_sparse_matrix_threshold"]: + for v in ["use_sparse_matrix_threshold"]: assert (default_parameters[v] >= 0) and (default_parameters[v] <= 1), f"{v} should be in [0, 1]" default_parameters["num_channels"] = default_parameters["waveform_extractor"].recording.get_num_channels() @@ -746,7 +1017,13 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ) default_parameters = cls._prepare_templates(default_parameters) - default_parameters = cls._prepare_overlaps(default_parameters) + + default_parameters["overlaps"] = compute_overlaps( + default_parameters["circus_templates"], + default_parameters["num_samples"], + default_parameters["num_channels"], + default_parameters["sparsities"], + ) default_parameters["exclude_sweep_size"] = int( default_parameters["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 @@ -817,31 +1094,31 @@ def main_function(cls, traces, d): sym_patch = d["sym_patch"] peak_traces = traces[margin // 2 : -margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakByChannel.detect_peaks( + peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( peak_traces, peak_sign, abs_threholds, exclude_sweep_size ) if jitter > 0: - jittered_peaks = peak_sample_ind[:, np.newaxis] + np.arange(-jitter, jitter) + jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * jitter) mask = (jittered_peaks > 0) & (jittered_peaks < len(peak_traces)) jittered_peaks = jittered_peaks[mask] jittered_channels = jittered_channels[mask] - peak_sample_ind, unique_idx = np.unique(jittered_peaks, return_index=True) + peak_sample_index, unique_idx = np.unique(jittered_peaks, return_index=True) peak_chan_ind = jittered_channels[unique_idx] else: - peak_sample_ind, unique_idx = np.unique(peak_sample_ind, return_index=True) + peak_sample_index, unique_idx = np.unique(peak_sample_index, return_index=True) peak_chan_ind = peak_chan_ind[unique_idx] - num_peaks = len(peak_sample_ind) + num_peaks = len(peak_sample_index) if sym_patch: - snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_ind] - peak_sample_ind += margin // 2 + snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_index] + peak_sample_index += margin // 2 else: - peak_sample_ind += margin // 2 + peak_sample_index += margin // 2 snippet_window = np.arange(-d["nbefore"], d["nafter"]) - snippets = traces[peak_sample_ind[:, np.newaxis] + snippet_window] + snippets = traces[peak_sample_index[:, np.newaxis] + snippet_window] if num_peaks > 0: snippets = snippets.reshape(num_peaks, -1) @@ -865,10 +1142,10 @@ def main_function(cls, traces, d): best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) best_amplitude = scalar_products[best_cluster_ind, peak_index] - best_peak_sample_ind = peak_sample_ind[peak_index] + best_peak_sample_index = peak_sample_index[peak_index] best_peak_chan_ind = peak_chan_ind[peak_index] - peak_data = peak_sample_ind - peak_sample_ind[peak_index] + peak_data = peak_sample_index - peak_sample_index[peak_index] is_valid_nn = np.searchsorted(peak_data, [-neighbor_window, neighbor_window + 1]) idx_neighbor = peak_data[is_valid_nn[0] : is_valid_nn[1]] + neighbor_window @@ -880,7 +1157,7 @@ def main_function(cls, traces, d): scalar_products[:, is_valid_nn[0] : is_valid_nn[1]] += to_add scalar_products[best_cluster_ind, is_valid_nn[0] : is_valid_nn[1]] = -np.inf - spikes["sample_index"][num_spikes] = best_peak_sample_ind + spikes["sample_index"][num_spikes] = best_peak_sample_index spikes["channel_index"][num_spikes] = best_peak_chan_ind spikes["cluster_index"][num_spikes] = best_cluster_ind spikes["amplitude"][num_spikes] = best_amplitude diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index bedc04a9d5..d982943126 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,6 +1,6 @@ from .naive import NaiveMatching from .tdc import TridesclousPeeler -from .circus import CircusPeeler, CircusOMPPeeler +from .circus import CircusPeeler, CircusOMPPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { @@ -8,5 +8,6 @@ "tridesclous": TridesclousPeeler, "circus": CircusPeeler, "circus-omp": CircusOMPPeeler, + "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, } diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py new file mode 100644 index 0000000000..6b3ea2a901 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -0,0 +1,14 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + + +def test_merge(): + pass + + +if __name__ == "__main__": + test_merge() diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py new file mode 100644 index 0000000000..5953f74e24 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -0,0 +1,14 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + + +def test_split(): + pass + + +if __name__ == "__main__": + test_split() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c0dcd7ea6e..c10c78cbfc 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,8 +1,3 @@ -# basics -# from .timeseries import plot_timeseries, TracesWidget -from .rasters import plot_rasters, RasterWidget -from .probemap import plot_probe_map, ProbeMapWidget - # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget @@ -15,9 +10,6 @@ # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# comparison related -from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget -from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, @@ -41,22 +33,6 @@ from .sortingperformance import plot_sorting_performance -# ground truth study (=comparison over sorter) -from .gtstudy import ( - StudyComparisonRunTimesWidget, - plot_gt_study_run_times, - StudyComparisonUnitCountsWidget, - StudyComparisonUnitCountsAveragesWidget, - plot_gt_study_unit_counts, - plot_gt_study_unit_counts_averages, - plot_gt_study_performances, - plot_gt_study_performances_averages, - StudyComparisonPerformancesWidget, - StudyComparisonPerformancesAveragesWidget, - plot_gt_study_performances_by_template_similarity, - StudyComparisonPerformancesByTemplateSimilarity, -) - # ground truth comparions (=comparison over sorter) from .gtcomparison import ( plot_gt_performances, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py index 939475c17d..9715b7ea87 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py @@ -95,8 +95,7 @@ def plot(self): num_frames = int(duration / self.bin_duration_s) def animate_func(i): - i0 = np.searchsorted(peaks["sample_index"], bin_size * i) - i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1)) + i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) return artists diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index 6d981e1fd4..468b96ff3b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -1,7 +1,6 @@ import numpy as np from .basewidget import BaseWidget -from spikeinterface.comparison.collisioncomparison import CollisionGTComparison class ComparisonCollisionPairByPairWidget(BaseWidget): @@ -44,6 +43,8 @@ def plot(self): self._do_plot() def _do_plot(self): + from matplotlib import pyplot as plt + fig = self.figure for ax in fig.axes: @@ -178,6 +179,8 @@ def plot(self): def _do_plot(self): import sklearn + import matplotlib.pyplot as plt + import matplotlib # compute similarity # take index of template (respect unit_ids order) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py deleted file mode 100644 index 573221f528..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py +++ /dev/null @@ -1,574 +0,0 @@ -""" -Various widgets on top of GroundTruthStudy to summary results: - * run times - * performances - * count units -""" -import numpy as np - - -from .basewidget import BaseWidget - - -class StudyComparisonRunTimesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - color: - - - """ - - def __init__(self, study, color="#F7DC6F", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.color = color - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - all_run_times = study.aggregate_run_times() - av_run_times = all_run_times.reset_index().groupby("sorter_name")["run_time"].mean() - - if len(study.rec_names) == 1: - # no errors bars - yerr = None - else: - # errors bars across recording - yerr = all_run_times.reset_index().groupby("sorter_name")["run_time"].std() - - sorter_names = av_run_times.index - - x = np.arange(sorter_names.size) + 1 - ax.bar(x, av_run_times.values, width=0.8, color=self.color, yerr=yerr) - ax.set_ylabel("run times (s)") - ax.set_xticks(x) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_xlim(0, sorter_names.size + 1) - - -def plot_gt_study_run_times(*args, **kwargs): - W = StudyComparisonRunTimesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_run_times.__doc__ = StudyComparisonRunTimesWidget.__doc__ - - -class StudyComparisonUnitCountsAveragesWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - count_units = study.aggregate_count_units() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - ncol = len(columns) - - df = count_units.reset_index() - - m = df.groupby("sorter_name")[columns].mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = df.groupby("sorter_name")[columns].std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values, yerr=yerr, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - ax.legend() - if self.log_scale: - ax.set_yscale("log") - - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -class StudyComparisonUnitCountsWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self): - study = self.study - ax = self.ax - - import seaborn as sns - - study = self.study - - count_units = study.aggregate_count_units() - count_units = count_units.reset_index() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - - ncol = len(columns) - cmap = plt.get_cmap(self.cmap_name, 4) - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = count_units.loc[count_units["rec_name"] == rec_name, :] - m = df.groupby("sorter_name")[columns].mean() - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - ax.bar(x, m[col].values, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - if r == 0: - ax.legend() - - if self.log_scale: - ax.set_yscale("log") - - if r == len(study.rec_names) - 1: - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -def plot_gt_study_unit_counts_averages(*args, **kwargs): - W = StudyComparisonUnitCountsAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts_averages.__doc__ = StudyComparisonUnitCountsAveragesWidget.__doc__ - - -def plot_gt_study_unit_counts(*args, **kwargs): - W = StudyComparisonUnitCountsWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts.__doc__ = StudyComparisonUnitCountsWidget.__doc__ - - -class StudyComparisonPerformancesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, palette="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.palette = palette - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self, average=False): - import seaborn as sns - - study = self.study - - sns.set_palette(sns.color_palette(self.palette)) - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = perf_by_units.loc[perf_by_units["rec_name"] == rec_name, :] - df = pd.melt( - df, - id_vars="sorter_name", - var_name="Metric", - value_name="Score", - value_vars=("accuracy", "precision", "recall"), - ).sort_values("sorter_name") - sns.swarmplot( - data=df, x="sorter_name", y="Score", hue="Metric", dodge=True, s=3, ax=ax - ) # order=sorter_list, - # ~ ax.set_xticklabels(sorter_names_short, rotation=30, ha='center') - # ~ ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5) - - ax.set_ylim(0, 1.05) - ax.set_ylabel(f"Perfs for {rec_name}") - if r < len(study.rec_names) - 1: - ax.set_xlabel("") - ax.set(xticklabels=[]) - - -class StudyComparisonTemplateSimilarityWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesAveragesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesByTemplateSimilarity(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None, ylim=(0.6, 1), show_legend=True): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.show_legend = show_legend - self.ylim = ylim - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import sklearn - - cmap = plt.get_cmap(self.cmap_name, len(self.study.sorter_names)) - colors = [cmap(i) for i in range(len(self.study.sorter_names))] - - flat_templates_gt = {} - for rec_name in self.study.rec_names: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_GroundTruth_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name) - - templates = self.study.get_templates(rec_name) - flat_templates_gt[rec_name] = templates.reshape(templates.shape[0], -1) - - all_results = {} - - for sorter_name in self.study.sorter_names: - all_results[sorter_name] = {"similarity": [], "accuracy": []} - - for rec_name in self.study.rec_names: - try: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_{sorter_name}_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name, sorter_name) - templates = self.study.get_templates(rec_name, sorter_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix = sklearn.metrics.pairwise.cosine_similarity( - flat_templates_gt[rec_name], flat_templates - ) - - comp = self.study.comparisons[(rec_name, sorter_name)] - - for i, u1 in enumerate(comp.sorting1.unit_ids): - u2 = comp.best_match_12[u1] - if u2 != -1: - all_results[sorter_name]["similarity"] += [ - similarity_matrix[comp.sorting1.id_to_index(u1), comp.sorting2.id_to_index(u2)] - ] - all_results[sorter_name]["accuracy"] += [comp.agreement_scores.at[u1, u2]] - except Exception: - pass - - all_results[sorter_name]["similarity"] = np.array(all_results[sorter_name]["similarity"]) - all_results[sorter_name]["accuracy"] = np.array(all_results[sorter_name]["accuracy"]) - - from matplotlib.patches import Ellipse - - similarity_means = [all_results[sorter_name]["similarity"].mean() for sorter_name in self.study.sorter_names] - similarity_stds = [all_results[sorter_name]["similarity"].std() for sorter_name in self.study.sorter_names] - - accuracy_means = [all_results[sorter_name]["accuracy"].mean() for sorter_name in self.study.sorter_names] - accuracy_stds = [all_results[sorter_name]["accuracy"].std() for sorter_name in self.study.sorter_names] - - scount = 0 - for x, y, i, j in zip(similarity_means, accuracy_means, similarity_stds, accuracy_stds): - e = Ellipse((x, y), i, j) - e.set_alpha(0.2) - e.set_facecolor(colors[scount]) - self.ax.add_artist(e) - self.ax.scatter([x], [y], c=colors[scount], label=self.study.sorter_names[scount]) - scount += 1 - - self.ax.set_ylabel("accuracy") - self.ax.set_xlabel("cosine similarity") - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - if self.show_legend: - self.ax.legend() - - -def plot_gt_study_performances(*args, **kwargs): - W = StudyComparisonPerformancesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances.__doc__ = StudyComparisonPerformancesWidget.__doc__ - - -def plot_gt_study_performances_averages(*args, **kwargs): - W = StudyComparisonPerformancesAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances_averages.__doc__ = StudyComparisonPerformancesAveragesWidget.__doc__ - - -def plot_gt_study_performances_by_template_similarity(*args, **kwargs): - W = StudyComparisonPerformancesByTemplateSimilarity(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances_by_template_similarity.__doc__ = StudyComparisonPerformancesByTemplateSimilarity.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py deleted file mode 100644 index 6e6578a4c4..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ProbeMapWidget(BaseWidget): - """ - Plot the probe of a recording. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - channel_ids: list - The channel ids to display - with_channel_ids: bool False default - Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__(self, recording, channel_ids=None, with_channel_ids=False, figure=None, ax=None, **plot_probe_kwargs): - import matplotlib.pylab as plt - from probeinterface.plotting import plot_probe, get_auto_lims - - BaseWidget.__init__(self, figure, ax) - - if channel_ids is not None: - recording = recording.channel_slice(channel_ids) - self._recording = recording - self._probegroup = recording.get_probegroup() - self.with_channel_ids = with_channel_ids - self._plot_probe_kwargs = plot_probe_kwargs - - def plot(self): - self._do_plot() - - def _do_plot(self): - from probeinterface.plotting import get_auto_lims - - xlims, ylims, zlims = get_auto_lims(self._probegroup.probes[0]) - for i, probe in enumerate(self._probegroup.probes): - xlims2, ylims2, _ = get_auto_lims(probe) - xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) - ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) - - self._plot_probe_kwargs["title"] = False - pos = 0 - text_on_contact = None - for i, probe in enumerate(self._probegroup.probes): - n = probe.get_contact_count() - if self.with_channel_ids: - text_on_contact = self._recording.channel_ids[pos : pos + n] - pos += n - from probeinterface.plotting import plot_probe - - plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **self._plot_probe_kwargs) - - self.ax.set_xlim(*xlims) - self.ax.set_ylim(*ylims) - - -def plot_probe_map(*args, **kwargs): - W = ProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_probe_map.__doc__ = ProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py deleted file mode 100644 index d05373103e..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class RasterWidget(BaseWidget): - """ - Plots spike train rasters. - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - segment_index: None or int - The segment index. - unit_ids: list - List of unit ids - time_range: list - List with start time and end time - color: matplotlib color - The color to be used - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: RasterWidget - The output widget - """ - - def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._sorting = sorting - - if segment_index is None: - nseg = sorting.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - else: - segment_index = 0 - self.segment_index = segment_index - - self._unit_ids = unit_ids - self._figure = None - self._sampling_frequency = sorting.get_sampling_frequency() - self._color = color - self._max_frame = 0 - for unit_id in self._sorting.get_unit_ids(): - spike_train = self._sorting.get_unit_spike_train(unit_id, segment_index=self.segment_index) - if len(spike_train) > 0: - curr_max_frame = np.max(spike_train) - if curr_max_frame > self._max_frame: - self._max_frame = curr_max_frame - self._visible_trange = time_range - if self._visible_trange is None: - self._visible_trange = [0, self._max_frame] - else: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - self._visible_trange = [int(t * self._sampling_frequency) for t in time_range] - - self._visible_trange = self._fix_trange(self._visible_trange) - self.name = "Raster" - - def plot(self): - self._do_plot() - - def _do_plot(self): - units_ids = self._unit_ids - if units_ids is None: - units_ids = self._sorting.get_unit_ids() - import matplotlib.pyplot as plt - - with plt.rc_context({"axes.edgecolor": "gray"}): - for u_i, unit_id in enumerate(units_ids): - spiketrain = self._sorting.get_unit_spike_train( - unit_id, - start_frame=self._visible_trange[0], - end_frame=self._visible_trange[1], - segment_index=self.segment_index, - ) - spiketimes = spiketrain / float(self._sampling_frequency) - self.ax.plot( - spiketimes, - u_i * np.ones_like(spiketimes), - marker="|", - mew=1, - markersize=3, - ls="", - color=self._color, - ) - visible_start_frame = self._visible_trange[0] / self._sampling_frequency - visible_end_frame = self._visible_trange[1] / self._sampling_frequency - self.ax.set_yticks(np.arange(len(units_ids))) - self.ax.set_yticklabels(units_ids) - self.ax.set_xlim(visible_start_frame, visible_end_frame) - self.ax.set_xlabel("time (s)") - - def _fix_trange(self, trange): - if trange[1] > self._max_frame: - # trange[0] += max_t - trange[1] - trange[1] = self._max_frame - if trange[0] < 0: - # trange[1] += -trange[0] - trange[0] = 0 - # trange[0] = np.maximum(0, trange[0]) - # trange[1] = np.minimum(max_t, trange[1]) - return trange - - -def plot_rasters(*args, **kwargs): - W = RasterWidget(*args, **kwargs) - W.plot() - return W - - -plot_rasters.__doc__ = RasterWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 5004765251..8814e0131a 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -32,10 +32,10 @@ def setUp(self): self.num_units = len(self._sorting.get_unit_ids()) #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True) - if (cache_folder / "mearec_test").is_dir(): - self._we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_old_api").is_dir(): + self._we = load_waveforms(cache_folder / "mearec_test_old_api") else: - self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test") + self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False) self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit") self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting) @@ -43,44 +43,6 @@ def setUp(self): def tearDown(self): pass - # def test_timeseries(self): - # sw.plot_timeseries(self._rec, mode='auto') - # sw.plot_timeseries(self._rec, mode='line', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True, order_channel_by_depth=True) - - def test_rasters(self): - sw.plot_rasters(self._sorting) - - def test_plot_probe_map(self): - sw.plot_probe_map(self._rec) - sw.plot_probe_map(self._rec, with_channel_ids=True) - - # TODO - # def test_spectrum(self): - # sw.plot_spectrum(self._rec) - - # TODO - # def test_spectrogram(self): - # sw.plot_spectrogram(self._rec, channel=0) - - # def test_unitwaveforms(self): - # w = sw.plot_unit_waveforms(self._we) - # unit_ids = self._sorting.unit_ids[:6] - # sw.plot_unit_waveforms(self._we, max_channels=5, unit_ids=unit_ids) - # sw.plot_unit_waveforms(self._we, radius_um=60, unit_ids=unit_ids) - - # def test_plot_unit_waveform_density_map(self): - # unit_ids = self._sorting.unit_ids[:3] - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=4) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=50) - # - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=25, same_axis=True) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=2, same_axis=True) - - # def test_unittemplates(self): - # sw.plot_unit_templates(self._we) - def test_plot_unit_probe_map(self): sw.plot_unit_probe_map(self._we, with_channel_ids=True) sw.plot_unit_probe_map(self._we, animated=True) @@ -120,12 +82,6 @@ def test_plot_peak_activity_map(self): sw.plot_peak_activity_map(self._rec, with_channel_ids=True) sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) - def test_confusion(self): - sw.plot_confusion_matrix(self._gt_comp, count_text=True) - - def test_agreement(self): - sw.plot_agreement_matrix(self._gt_comp, count_text=True) - def test_multicomp_graph(self): msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) sw.plot_multicomp_graph(msc, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False) @@ -150,8 +106,6 @@ def test_sorting_performance(self): mytest.setUp() # ~ mytest.test_timeseries() - # ~ mytest.test_rasters() - mytest.test_plot_probe_map() # ~ mytest.test_unitwaveforms() # ~ mytest.test_plot_unit_waveform_density_map() # mytest.test_unittemplates() @@ -169,8 +123,6 @@ def test_sorting_performance(self): # ~ mytest.test_plot_drift_over_time() # ~ mytest.test_plot_peak_activity_map() - # mytest.test_confusion() - # mytest.test_agreement() # ~ mytest.test_multicomp_graph() #  mytest.test_sorting_performance() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py deleted file mode 100644 index ab6fa2ace5..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ /dev/null @@ -1,233 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from matplotlib.ticker import MaxNLocator -from .basewidget import BaseWidget - -import scipy.spatial - - -class TracesWidget(BaseWidget): - """ - Plots recording timeseries. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - segment_index: None or int - The segment index. - channel_ids: list - The channel ids to display. - order_channel_by_depth: boolean - Reorder channel by depth. - time_range: list - List with start time and end time - mode: 'line' or 'map' or 'auto' - 2 possible mode: - * 'line' : classical for low channel count - * 'map' : for high channel count use color heat map - * 'auto' : auto switch depending the channel count <32ch - cmap: str default 'RdBu' - matplotlib colormap used in mode 'map' - show_channel_ids: bool - Set yticks with channel ids - color_groups: bool - If True groups are plotted with different colors - color: matplotlib color, default: None - The color used to draw the traces. - clim: None or tupple - When mode='map' this control color lims - with_colorbar: bool default True - When mode='map' add colorbar - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: TracesWidget - The output widget - """ - - def __init__( - self, - recording, - segment_index=None, - channel_ids=None, - order_channel_by_depth=False, - time_range=None, - mode="auto", - cmap="RdBu", - show_channel_ids=False, - color_groups=False, - color=None, - clim=None, - with_colorbar=True, - figure=None, - ax=None, - **plot_kwargs, - ): - BaseWidget.__init__(self, figure, ax) - self.recording = recording - self._sampling_frequency = recording.get_sampling_frequency() - self.visible_channel_ids = channel_ids - self._plot_kwargs = plot_kwargs - - if segment_index is None: - nseg = recording.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 - self.segment_index = segment_index - - if self.visible_channel_ids is None: - self.visible_channel_ids = recording.get_channel_ids() - - if order_channel_by_depth: - locations = self.recording.get_channel_locations() - channel_inds = self.recording.ids_to_indices(self.visible_channel_ids) - locations = locations[channel_inds, :] - origin = np.array([np.max(locations[:, 0]), np.min(locations[:, 1])])[None, :] - dist = scipy.spatial.distance.cdist(locations, origin, metric="euclidean") - dist = dist[:, 0] - self.order = np.argsort(dist) - else: - self.order = None - - if channel_ids is None: - channel_ids = recording.get_channel_ids() - - fs = recording.get_sampling_frequency() - if time_range is None: - time_range = (0, 1.0) - time_range = np.array(time_range) - - assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" - if mode == "auto": - if len(channel_ids) <= 64: - mode = "line" - else: - mode = "map" - self.mode = mode - self.cmap = cmap - - self.show_channel_ids = show_channel_ids - - self._frame_range = (time_range * fs).astype("int64") - a_max = self.recording.get_num_frames(segment_index=self.segment_index) - self._frame_range = np.clip(self._frame_range, 0, a_max) - self._time_range = [e / fs for e in self._frame_range] - - self.clim = clim - self.with_colorbar = with_colorbar - - self._initialize_stats() - - # self._vspacing = self._mean_channel_std * 20 - self._vspacing = self._max_channel_amp * 1.5 - - if recording.get_channel_groups() is None: - color_groups = False - - self._color_groups = color_groups - self._color = color - if color_groups: - self._colors = [] - self._group_color_map = {} - all_groups = recording.get_channel_groups() - groups = np.unique(all_groups) - N = len(groups) - import colorsys - - HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)] - self._colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)) - color_idx = 0 - for group in groups: - self._group_color_map[group] = color_idx - color_idx += 1 - self.name = "TimeSeries" - - def plot(self): - self._do_plot() - - def _do_plot(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - if self.order is not None: - chunk0 = chunk0[:, self.order] - self.visible_channel_ids = np.array(self.visible_channel_ids)[self.order] - - ax = self.ax - - n = len(self.visible_channel_ids) - - if self.mode == "line": - ax.set_xlim( - self._frame_range[0] / self._sampling_frequency, self._frame_range[1] / self._sampling_frequency - ) - ax.set_ylim(-self._vspacing, self._vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.get_yaxis().set_ticks([]) - ax.set_xlabel("time (s)") - - self._plots = {} - self._plot_offsets = {} - offset0 = self._vspacing * (n - 1) - times = np.arange(self._frame_range[0], self._frame_range[1]) / self._sampling_frequency - for im, m in enumerate(self.visible_channel_ids): - self._plot_offsets[m] = offset0 - if self._color_groups: - group = self.recording.get_channel_groups(channel_ids=[m])[0] - group_color_idx = self._group_color_map[group] - color = self._colors[group_color_idx] - else: - color = self._color - self._plots[m] = ax.plot(times, self._plot_offsets[m] + chunk0[:, im], color=color, **self._plot_kwargs) - offset0 = offset0 - self._vspacing - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) * self._vspacing) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - elif self.mode == "map": - extent = (self._time_range[0], self._time_range[1], 0, self.recording.get_num_channels()) - im = ax.imshow( - chunk0.T, interpolation="nearest", origin="upper", aspect="auto", extent=extent, cmap=self.cmap - ) - - if self.clim is None: - im.set_clim(-self._max_channel_amp, self._max_channel_amp) - else: - im.set_clim(*self.clim) - - if self.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) + 0.5) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - def _initialize_stats(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - - self._mean_channel_std = np.mean(np.std(chunk0, axis=0)) - self._max_channel_amp = np.max(np.max(np.abs(chunk0), axis=0)) - - -def plot_timeseries(*args, **kwargs): - W = TracesWidget(*args, **kwargs) - W.plot() - return W - - -plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py b/src/spikeinterface/widgets/agreement_matrix.py similarity index 53% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py rename to src/spikeinterface/widgets/agreement_matrix.py index 369746e99b..ec6ea1c87c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py +++ b/src/spikeinterface/widgets/agreement_matrix.py @@ -1,11 +1,13 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class AgreementMatrixWidget(BaseWidget): """ - Plots sorting comparison confusion matrix. + Plots sorting comparison agreement matrix. Parameters ---------- @@ -19,31 +21,34 @@ class AgreementMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created + """ - def __init__(self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__( + self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs + ): + plot_data = dict( + sorting_comparison=sorting_comparison, + ordered=ordered, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) - BaseWidget.__init__(self, figure, ax) - self._sc = sorting_comparison - self._ordered = ordered - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def plot(self): - self._do_plot() + comp = dp.sorting_comparison - def _do_plot(self): - # a dataframe - if self._ordered: - scores = self._sc.get_ordered_agreement_scores() + if dp.ordered: + scores = comp.get_ordered_agreement_scores() else: - scores = self._sc.agreement_scores + scores = comp.agreement_scores N1 = scores.shape[0] N2 = scores.shape[1] @@ -54,9 +59,9 @@ def _do_plot(self): # Using matshow here just because it sets the ticks up nicely. imshow is faster. self.ax.matshow(scores.values, cmap="Greens") - if self._count_text: + if dp.count_text: for i, u1 in enumerate(unit_ids1): - u2 = self._sc.best_match_12[u1] + u2 = comp.best_match_12[u1] if u2 != -1: j = np.where(unit_ids2 == u2)[0][0] @@ -68,24 +73,15 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(scores.index, fontsize=12) self.ax.set_xticklabels(scores.columns, fontsize=12) - self.ax.set_xlabel(self._sc.name_list[1], fontsize=20) - self.ax.set_ylabel(self._sc.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 - 0.5) self.ax.set_ylim( N1 - 0.5, -0.5, ) - - -def plot_agreement_matrix(*args, **kwargs): - W = AgreementMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_agreement_matrix.__doc__ = AgreementMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 7ef6e0ff61..6b6496a577 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -147,13 +147,16 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: bins = dp.bins ax_hist = self.axes.flatten()[1] - ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + # this is super slow, using plot and np.histogram is really much faster (and nicer!) + # ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + count, bins = np.histogram(amps, bins=bins) + ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8) if dp.plot_histograms: ax_hist = self.axes.flatten()[1] ax_hist.set_ylim(scatter_ax.get_ylim()) ax_hist.axis("off") - self.figure.tight_layout() + # self.figure.tight_layout() if dp.plot_legend: if hasattr(self, "legend") and self.legend is not None: @@ -171,9 +174,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - import ipywidgets.widgets as widgets + + # import ipywidgets.widgets as widgets + import ipywidgets.widgets as W from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -188,60 +193,63 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ratios = [0.15, 0.85] with plt.ioff(): - output = widgets.Output() + output = W.Output() with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(we.unit_ids) + self.unit_selector.value = list(we.unit_ids)[:1] - plot_histograms = widgets.Checkbox( + self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], - description="plot histograms", - disabled=False, + description="hist", ) - footer = plot_histograms - - self.controller = {"plot_histograms": plot_histograms} - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + left_sidebar = W.VBox( + children=[ + self.unit_selector, + self.checkbox_histograms, + ], + layout=W.Layout(align_items="center", width="4cm", height="100%"), + ) - self.widget = widgets.AppLayout( + self.widget = W.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=left_sidebar, pane_widths=ratios + [0], - footer=footer, ) # a first update - self._update_ipywidget(None) + self._full_update_plot() + + self.unit_selector.observe(self._update_plot, names="value", type="change") + self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _full_update_plot(self, change=None): self.figure.clear() + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False - unit_ids = self.controller["unit_ids"].value - plot_histograms = self.controller["plot_histograms"].value + backend_kwargs = dict(figure=self.figure, axes=None, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + self._update_plot() - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_histograms"] = plot_histograms + def _update_plot(self, change=None): + for ax in self.axes.flatten(): + ax.clear() - backend_kwargs = {} - # backend_kwargs["figure"] = self.fig - backend_kwargs["figure"] = self.figure - backend_kwargs["axes"] = None - backend_kwargs["ax"] = None + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False + backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index dea46b8f51..9fc7b73707 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -38,13 +38,16 @@ def set_default_plotter_backend(backend): "width_cm": "Width of the figure in cm (default 10)", "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", + # "controllers": "" }, + "ephyviewer": {}, } default_backend_kwargs = { "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, - "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, + "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None}, + "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py b/src/spikeinterface/widgets/confusion_matrix.py similarity index 62% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py rename to src/spikeinterface/widgets/confusion_matrix.py index 942b613fbf..8eb58f30b2 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py +++ b/src/spikeinterface/widgets/confusion_matrix.py @@ -1,6 +1,8 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class ConfusionMatrixWidget(BaseWidget): @@ -15,40 +17,35 @@ class ConfusionMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ConfusionMatrixWidget - The output widget + """ - def __init__(self, gt_comparison, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs): + plot_data = dict( + gt_comparison=gt_comparison, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - BaseWidget.__init__(self, figure, ax) - self._gtcomp = gt_comparison - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + dp = to_attr(data_plot) - def plot(self): - self._do_plot() + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def _do_plot(self): - # a dataframe - confusion_matrix = self._gtcomp.get_confusion_matrix() + comp = dp.gt_comparison + confusion_matrix = comp.get_confusion_matrix() N1 = confusion_matrix.shape[0] - 1 N2 = confusion_matrix.shape[1] - 1 # Using matshow here just because it sets the ticks up nicely. imshow is faster. self.ax.matshow(confusion_matrix.values, cmap="Greens") - if self._count_text: + if dp.count_text: for (i, j), z in np.ndenumerate(confusion_matrix.values): if z != 0: if z > np.max(confusion_matrix.values) / 2.0: @@ -65,27 +62,18 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) else: self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) - self.ax.set_xlabel(self._gtcomp.name_list[1], fontsize=20) - self.ax.set_ylabel(self._gtcomp.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 + 0.5) self.ax.set_ylim( N1 + 0.5, -0.5, ) - - -def plot_confusion_matrix(*args, **kwargs): - W = ConfusionMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_confusion_matrix.__doc__ = ConfusionMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py new file mode 100644 index 0000000000..6a27b78dec --- /dev/null +++ b/src/spikeinterface/widgets/gtstudy.py @@ -0,0 +1,253 @@ +import numpy as np + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + +from ..core import ChannelSparsity +from ..core.waveform_extractor import WaveformExtractor +from ..core.basesorting import BaseSorting + + +class StudyRunTimesWidget(BaseWidget): + """ + Plot sorter run times for a GroundTruthStudy + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + run_times=study.get_run_times(case_keys), + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for i, key in enumerate(dp.case_keys): + label = dp.study.cases[key]["label"] + rt = dp.run_times.loc[key] + self.ax.bar(i, rt, width=0.8, label=label) + + self.ax.legend() + + +# TODO : plot optionally average on some levels using group by +class StudyUnitCountsWidget(BaseWidget): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + count_units=study.get_count_units(case_keys=case_keys), + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + columns = dp.count_units.columns.tolist() + columns.remove("num_gt") + columns.remove("num_sorter") + + ncol = len(columns) + + colors = get_some_colors(columns, color_engine="auto", map_name="hot") + colors["num_well_detected"] = "green" + + xticklabels = [] + for i, key in enumerate(dp.case_keys): + for c, col in enumerate(columns): + x = i + 1 + c / (ncol + 1) + y = dp.count_units.loc[key, col] + if not "well_detected" in col: + y = -y + + if i == 0: + label = col.replace("num_", "").replace("_", " ").title() + else: + label = None + + self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) + + xticklabels.append(dp.study.cases[key]["label"]) + + self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1) + self.ax.set_xticklabels(xticklabels) + self.ax.legend() + + +# TODO : plot optionally average on some levels using group by +class StudyPerformances(BaseWidget): + """ + Plot performances over case for a study. + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + mode: str + Which mode in "swarm" + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + mode="swarm", + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + perfs=study.get_performance_by_unit(case_keys=case_keys), + mode=mode, + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + import pandas as pd + import seaborn as sns + + dp = to_attr(data_plot) + perfs = dp.perfs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + if dp.mode == "swarm": + levels = perfs.index.names + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=("accuracy", "precision", "recall"), + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) + + +class StudyPerformancesVsMetrics(BaseWidget): + """ + Plot performances vs a metrics (snr for instance) over case for a study. + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + mode: str + Which mode in "swarm" + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + metric_name="snr", + performance_name="accuracy", + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + metric_name=metric_name, + performance_name=performance_name, + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + dp = to_attr(data_plot) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + study = dp.study + perfs = study.get_performance_by_unit(case_keys=dp.case_keys) + + max_metric = 0 + for key in dp.case_keys: + x = study.get_metrics(key)[dp.metric_name].values + y = perfs.xs(key)[dp.performance_name].values + label = dp.study.cases[key]["label"] + self.ax.scatter(x, y, label=label) + max_metric = max(max_metric, np.max(x)) + + self.ax.legend() + self.ax.set_xlim(0, max_metric * 1.05) + self.ax.set_ylim(0, 1.05) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 9dc51f522e..c7b701c8b0 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -128,7 +128,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -147,34 +147,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - if data_plot["unit_ids"] is None: - data_plot["unit_ids"] = [] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids) + self.unit_selector.value = [] self.widget = widgets.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") + if backend_kwargs["display"]: display(self.widget) def _update_ipywidget(self, change): from matplotlib.lines import Line2D - unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value unit_colors = self.data_plot["unit_colors"] # matplotlib next_data_plot dict update at each call @@ -198,6 +192,7 @@ def _update_ipywidget(self, change): self.plot_matplotlib(self.data_plot, **backend_kwargs) if len(unit_ids) > 0: + # TODO later make option to control legend or not for l in self.figure.legends: l.remove() handles = [ diff --git a/src/spikeinterface/widgets/probe_map.py b/src/spikeinterface/widgets/probe_map.py new file mode 100644 index 0000000000..7fb74abd7c --- /dev/null +++ b/src/spikeinterface/widgets/probe_map.py @@ -0,0 +1,75 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs +from .utils import get_unit_colors + + +class ProbeMapWidget(BaseWidget): + """ + Plot the probe of a recording. + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object + channel_ids: list + The channel ids to display + with_channel_ids: bool False default + Add channel ids text on the probe + **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function + + """ + + def __init__( + self, recording, channel_ids=None, with_channel_ids=False, backend=None, **backend_or_plot_probe_kwargs + ): + # split backend_or_plot_probe_kwargs + backend_kwargs = dict() + plot_probe_kwargs = dict() + backend = self.check_backend(backend) + for k, v in backend_or_plot_probe_kwargs.items(): + if k in default_backend_kwargs[backend]: + backend_kwargs[k] = v + else: + plot_probe_kwargs[k] = v + + plot_data = dict( + recording=recording, + channel_ids=channel_ids, + with_channel_ids=with_channel_ids, + plot_probe_kwargs=plot_probe_kwargs, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import get_auto_lims, plot_probe + + dp = to_attr(data_plot) + + plot_probe_kwargs = dp.plot_probe_kwargs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + probegroup = dp.recording.get_probegroup() + + xlims, ylims, zlims = get_auto_lims(probegroup.probes[0]) + for i, probe in enumerate(probegroup.probes): + xlims2, ylims2, _ = get_auto_lims(probe) + xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) + ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) + + plot_probe_kwargs["title"] = False + pos = 0 + text_on_contact = None + for i, probe in enumerate(probegroup.probes): + n = probe.get_contact_count() + if dp.with_channel_ids: + text_on_contact = dp.recording.channel_ids[pos : pos + n] + pos += n + plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **plot_probe_kwargs) + + self.ax.set_xlim(*xlims) + self.ax.set_ylim(*ylims) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py new file mode 100644 index 0000000000..4a1d76279f --- /dev/null +++ b/src/spikeinterface/widgets/rasters.py @@ -0,0 +1,84 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs + + +class RasterWidget(BaseWidget): + """ + Plots spike train rasters. + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + segment_index: None or int + The segment index. + unit_ids: list + List of unit ids + time_range: list + List with start time and end time + color: matplotlib color + The color to be used + """ + + def __init__( + self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs + ): + if segment_index is None: + if sorting.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 + + if time_range is None: + frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]] + time_range = [f / sorting.sampling_frequency for f in frame_range] + else: + assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + frame_range = [int(t * sorting.sampling_frequency) for t in time_range] + + plot_data = dict( + sorting=sorting, + segment_index=segment_index, + unit_ids=unit_ids, + color=color, + frame_range=frame_range, + time_range=time_range, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + sorting = dp.sorting + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + units_ids = dp.unit_ids + if units_ids is None: + units_ids = sorting.unit_ids + + with plt.rc_context({"axes.edgecolor": "gray"}): + for unit_index, unit_id in enumerate(units_ids): + spiketrain = sorting.get_unit_spike_train( + unit_id, + start_frame=dp.frame_range[0], + end_frame=dp.frame_range[1], + segment_index=dp.segment_index, + ) + spiketimes = spiketrain / float(sorting.sampling_frequency) + self.ax.plot( + spiketimes, + unit_index * np.ones_like(spiketimes), + marker="|", + mew=1, + markersize=3, + ls="", + color=dp.color, + ) + self.ax.set_yticks(np.arange(len(units_ids))) + self.ax.set_yticklabels(units_ids) + self.ax.set_xlim(*dp.time_range) + self.ax.set_xlabel("time (s)") diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 9771b2c0e9..fda2356105 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -191,7 +191,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -210,48 +210,36 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], - list(data_plot["unit_colors"].keys()), - ratios[0] * width_cm, - height_cm, - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index ae036d1ba1..b68efc3f8a 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -149,23 +149,19 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = we.sorting # first plot time series - ts_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + traces_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + self.ax = traces_widget.ax + self.axes = traces_widget.axes + self.figure = traces_widget.figure ax = self.ax - frame_range = ts_widget.data_plot["frame_range"] - segment_index = ts_widget.data_plot["segment_index"] - min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) - max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) + frame_range = traces_widget.data_plot["frame_range"] + segment_index = traces_widget.data_plot["segment_index"] + min_y = np.min(traces_widget.data_plot["channel_locations"][:, 1]) + max_y = np.max(traces_widget.data_plot["channel_locations"][:, 1]) - n = len(ts_widget.data_plot["channel_ids"]) - order = ts_widget.data_plot["order"] - - if order is None: - order = np.arange(n) + n = len(traces_widget.data_plot["channel_ids"]) if ax.get_legend() is not None: ax.get_legend().remove() @@ -210,21 +206,21 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # construct waveforms label_set = False if len(spike_frames_to_plot) > 0: - vspacing = ts_widget.data_plot["vspacing"] - traces = ts_widget.data_plot["list_traces"][0] + vspacing = traces_widget.data_plot["vspacing"] + traces = traces_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) + waveform_idxs = np.clip(waveform_idxs, 0, len(traces_widget.data_plot["times"]) - 1) - times = ts_widget.data_plot["times"][waveform_idxs] + times = traces_widget.data_plot["times"][waveform_idxs] # discontinuity times[:, -1] = np.nan times_r = times.reshape(times.shape[0] * times.shape[1]) - waveforms = traces[waveform_idxs] # [:, :, order] + waveforms = traces[waveform_idxs] waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): + for i, chan_id in enumerate(traces_widget.data_plot["channel_ids"]): offset = vspacing * i if chan_id in chan_ids: l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) @@ -232,13 +228,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): handles.append(l[0]) labels.append(unit) label_set = True - ax.legend(handles, labels) + # ax.legend(handles, labels) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -256,37 +252,56 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - ts_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + self._traces_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self.ax = self._traces_widget.ax + self.axes = self._traces_widget.axes + self.figure = self._traces_widget.figure - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.sampling_frequency = self._traces_widget.rec0.sampling_frequency - self.controller = dict() - self.controller.update(ts_widget.controller) - self.controller.update(unit_controller) + self.time_slider = self._traces_widget.time_slider - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) + self.widget = widgets.AppLayout( + center=self._traces_widget.widget, left_sidebar=self.unit_selector, pane_widths=ratios + [0] + ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + # remove callback from traces_widget + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.time_slider.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.channel_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.scaler.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value + # TODO later: this is still a bit buggy because it make double refresh one from _traces_widget and one internal + + unit_ids = self.unit_selector.value + start_frame, end_frame, segment_index = self._traces_widget.time_slider.value + channel_ids = self._traces_widget.channel_selector.value + mode = self._traces_widget.mode_selector.value data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids + data_plot["options"].update( + dict( + channel_ids=channel_ids, + segment_index=segment_index, + # frame_range=(start_frame, end_frame), + time_range=np.array([start_frame, end_frame]) / self.sampling_frequency, + mode=mode, + with_colorbar=False, + ) + ) backend_kwargs = {} backend_kwargs["ax"] = self.ax diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index a5f75ebf50..1a2fdf38d9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -48,31 +48,32 @@ def setUpClass(cls): cls.sorting = se.MEArecSortingExtractor(local_path) cls.num_units = len(cls.sorting.get_unit_ids()) - if (cache_folder / "mearec_test").is_dir(): - cls.we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_dense").is_dir(): + cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense") else: - cls.we = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test") + cls.we_dense = extract_waveforms( + cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False + ) + metric_names = ["snr", "isi_violation", "num_spikes"] + _ = compute_spike_amplitudes(cls.we_dense) + _ = compute_unit_locations(cls.we_dense) + _ = compute_spike_locations(cls.we_dense) + _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) + _ = compute_template_metrics(cls.we_dense) + _ = compute_correlograms(cls.we_dense) + _ = compute_template_similarity(cls.we_dense) sw.set_default_plotter_backend("matplotlib") - metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we) - _ = compute_unit_locations(cls.we) - _ = compute_spike_locations(cls.we) - _ = compute_quality_metrics(cls.we, metric_names=metric_names) - _ = compute_template_metrics(cls.we) - _ = compute_correlograms(cls.we) - _ = compute_template_similarity(cls.we) - # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) - cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) + cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) + cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) if (cache_folder / "mearec_test_sparse").is_dir(): cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") else: - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + cls.we_sparse = cls.we_dense.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) - cls.skip_backends = ["ipywidgets"] + cls.skip_backends = ["ipywidgets", "ephyviewer"] if ON_GITHUB and not KACHERY_CLOUD_SET: cls.skip_backends.append("sortingview") @@ -124,17 +125,17 @@ def test_plot_unit_waveforms(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_waveforms(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, @@ -148,10 +149,10 @@ def test_plot_unit_templates(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_templates( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, @@ -171,7 +172,7 @@ def test_plot_unit_waveforms_density_map(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): @@ -180,7 +181,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, + self.we_dense, sparsity=self.sparsity_radius, same_axis=False, unit_ids=unit_ids, @@ -234,11 +235,15 @@ def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.we.unit_ids[:4] - sw.plot_amplitudes(self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) + sw.plot_amplitudes(self.we_dense, backend=backend, **self.backend_kwargs[backend]) + unit_ids = self.we_dense.unit_ids[:4] + sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] + self.we_dense, + unit_ids=unit_ids, + plot_histograms=True, + backend=backend, + **self.backend_kwargs[backend], ) sw.plot_amplitudes( self.we_sparse, @@ -252,9 +257,9 @@ def test_plot_all_amplitudes_distributions(self): possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - unit_ids = self.we.unit_ids[:4] + unit_ids = self.we_dense.unit_ids[:4] sw.plot_all_amplitudes_distributions( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_all_amplitudes_distributions( self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] @@ -264,7 +269,9 @@ def test_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_unit_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -273,7 +280,9 @@ def test_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_spike_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_spike_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -282,28 +291,28 @@ def test_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): @@ -311,19 +320,43 @@ def test_plot_unit_summary(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.we, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_dense, self.we_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_summary( - self.we_sparse, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) def test_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + def test_plot_agreement_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_agreement_matrix(self.gt_comp) + + def test_plot_confusion_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_confusion_matrix(self.gt_comp) + + def test_plot_probe_map(self): + possible_backends = list(sw.ProbeMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_probe_map(self.recording, with_channel_ids=True, with_contact_id=True) + + def test_plot_rasters(self): + possible_backends = list(sw.RasterWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_rasters(self.sorting) + if __name__ == "__main__": # unittest.main() @@ -344,7 +377,11 @@ def test_sorting_summary(self): # mytest.test_unit_locations() # mytest.test_quality_metrics() # mytest.test_template_metrics() - mytest.test_amplitudes() + # mytest.test_amplitudes() + mytest.test_plot_agreement_matrix() + # mytest.test_plot_confusion_matrix() + # mytest.test_plot_probe_map() + # mytest.test_plot_rasters() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index e025f779c1..fc8b30eb05 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,26 +88,32 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - layer_keys = list(recordings.keys()) + if rec0.has_channel_location(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None - if segment_index is None: - if rec0.get_num_segments() != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the good order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) if channel_ids is None: channel_ids = rec0.channel_ids - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None + layer_keys = list(recordings.keys()) - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None + if segment_index is None: + if rec0.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +130,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -138,9 +144,10 @@ def __init__( # colors is a nested dict by layer and channels # lets first create black for all channels and layer + # all color are generated for ipywidgets colors = {} for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} + colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids} if color_groups: channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) @@ -149,7 +156,7 @@ def __init__( group_colors = get_some_colors(groups, color_engine="auto") channel_colors = {} - for i, chan_id in enumerate(channel_ids): + for i, chan_id in enumerate(rec0.channel_ids): group = channel_groups[i] channel_colors[chan_id] = group_colors[group] @@ -159,12 +166,12 @@ def __init__( elif color is not None: # old behavior one color for all channel # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids} elif color is None and len(recordings) > 1: # several layer layer_colors = get_some_colors(layer_keys) for k in layer_keys: - colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids} else: # color is None unique layer : all channels black pass @@ -201,7 +208,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -276,22 +282,26 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display + import ipywidgets.widgets as W from .utils_ipywidgets import ( check_ipywidget_backend, - make_timeseries_controller, - make_channel_controller, - make_scale_controller, + # make_timeseries_controller, + # make_channel_controller, + # make_scale_controller, + TimeSlider, + ChannelSelector, + ScaleWidget, ) check_ipywidget_backend() self.next_data_plot = data_plot.copy() - self.next_data_plot["add_legend"] = False - recordings = data_plot["recordings"] + self.recordings = data_plot["recordings"] # first layer - rec0 = recordings[data_plot["layer_keys"][0]] + # rec0 = recordings[data_plot["layer_keys"][0]] + rec0 = self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] cm = 1 / 2.54 @@ -305,182 +315,153 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure, self.ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) plt.show() - t_start = 0.0 - t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() - - ts_widget, ts_controller = make_timeseries_controller( - t_start, - t_stop, - data_plot["layer_keys"], - rec0.get_num_segments(), - data_plot["time_range"], - data_plot["mode"], - False, - width_cm, + # some widgets + self.time_slider = TimeSlider( + durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], + sampling_frequency=rec0.sampling_frequency, + # layout=W.Layout(height="2cm"), ) - ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) + start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency) + end_frame = int(data_plot["time_range"][1] * rec0.sampling_frequency) - scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) + self.time_slider.value = start_frame, end_frame, data_plot["segment_index"] - self.controller = ts_controller - self.controller.update(ch_controller) - self.controller.update(scale_controller) + _layer_keys = data_plot["layer_keys"] + if len(_layer_keys) > 1: + _layer_keys = ["ALL"] + _layer_keys + self.layer_selector = W.Dropdown( + options=_layer_keys, + layout=W.Layout(width="95%"), + ) + self.mode_selector = W.Dropdown( + options=["line", "map"], + value=data_plot["mode"], + # layout=W.Layout(width="5cm"), + layout=W.Layout(width="95%"), + ) + self.scaler = ScaleWidget() + self.channel_selector = ChannelSelector(self.rec0.channel_ids) + self.channel_selector.value = list(data_plot["channel_ids"]) + + left_sidebar = W.VBox( + children=[ + W.Label(value="layer"), + self.layer_selector, + W.Label(value="mode"), + self.mode_selector, + self.scaler, + # self.channel_selector, + ], + layout=W.Layout(width="3.5cm"), + align_items="center", + ) - self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] - self.list_traces = None - self.actual_segment_index = self.controller["segment_index"].value - - self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] - self.t_stops = [ - self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() - for seg_index in range(self.rec0.get_num_segments()) - ] - - for w in self.controller.values(): - if isinstance(w, widgets.Button): - w.on_click(self._update_ipywidget) - else: - w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, - footer=ts_widget, - left_sidebar=scale_widget, - right_sidebar=ch_widget, + footer=self.time_slider, + left_sidebar=left_sidebar, + right_sidebar=self.channel_selector, pane_heights=[0, 6, 1], pane_widths=ratios, ) # a first update - self._update_ipywidget(None) + self._retrieve_traces() + self._update_plot() + + # callbacks: + # some widgets generate a full retrieve + refresh + self.time_slider.observe(self._retrieve_traces, names="value", type="change") + self.layer_selector.observe(self._retrieve_traces, names="value", type="change") + self.channel_selector.observe(self._retrieve_traces, names="value", type="change") + # other widgets only refresh + self.scaler.observe(self._update_plot, names="value", type="change") + # map is a special case because needs to check layer also + self.mode_selector.observe(self._mode_changed, names="value", type="change") if backend_kwargs["display"]: # self.check_backend() display(self.widget) - def _update_ipywidget(self, change): - import ipywidgets.widgets as widgets + def _get_layers(self): + layer = self.layer_selector.value + if layer == "ALL": + layer_keys = self.data_plot["layer_keys"] + else: + layer_keys = [layer] + if self.mode_selector.value == "map": + layer_keys = layer_keys[:1] + return layer_keys + + def _mode_changed(self, change=None): + if self.mode_selector.value == "map" and self.layer_selector.value == "ALL": + self.layer_selector.value = self.data_plot["layer_keys"][0] + else: + self._update_plot() - # if changing the layer_key, no need to retrieve and process traces - retrieve_traces = True - scale_up = False - scale_down = False - if change is not None: - for cname, c in self.controller.items(): - if isinstance(change, dict): - if change["owner"] is c and cname == "layer_key": - retrieve_traces = False - elif isinstance(change, widgets.Button): - if change is c and cname == "plus": - scale_up = True - if change is c and cname == "minus": - scale_down = True - - t_start = self.controller["t_start"].value - window = self.controller["window"].value - layer_key = self.controller["layer_key"].value - segment_index = self.controller["segment_index"].value - mode = self.controller["mode"].value - chan_start, chan_stop = self.controller["channel_inds"].value + def _retrieve_traces(self, change=None): + channel_ids = np.array(self.channel_selector.value) - if mode == "line": - self.controller["all_layers"].layout.visibility = "visible" - all_layers = self.controller["all_layers"].value - elif mode == "map": - self.controller["all_layers"].layout.visibility = "hidden" - all_layers = False + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + # order = None - if all_layers: - self.controller["layer_key"].layout.visibility = "hidden" - else: - self.controller["layer_key"].layout.visibility = "visible" + start_frame, end_frame, segment_index = self.time_slider.value + time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency - if chan_start == chan_stop: - chan_stop += 1 - channel_indices = np.arange(chan_start, chan_stop) + self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} + times, list_traces, frame_range, channel_ids = _get_trace_list( + self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled + ) - t_stop = self.t_stops[segment_index] - if self.actual_segment_index != segment_index: - # change time_slider limits - self.controller["t_start"].max = t_stop - self.actual_segment_index = segment_index + self._channel_ids = channel_ids + self._list_traces = list_traces + self._times = times + self._time_range = time_range + self._frame_range = (start_frame, end_frame) + self._segment_index = segment_index - # protect limits - if t_start >= t_stop - window: - t_start = t_stop - window + self._update_plot() - time_range = np.array([t_start, t_start + window]) + def _update_plot(self, change=None): data_plot = self.next_data_plot - if retrieve_traces: - all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - if self.data_plot["order"] is not None: - all_channel_ids = all_channel_ids[self.data_plot["order"]] - channel_ids = all_channel_ids[channel_indices] - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None - times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled - ) - self.list_traces = list_traces - else: - times = data_plot["times"] - list_traces = data_plot["list_traces"] - frame_range = data_plot["frame_range"] - channel_ids = data_plot["channel_ids"] + # matplotlib next_data_plot dict update at each call + mode = self.mode_selector.value + layer_keys = self._get_layers() - if all_layers: - layer_keys = self.data_plot["layer_keys"] - recordings = self.recordings - list_traces_plot = self.list_traces - else: - layer_keys = [layer_key] - recordings = {layer_key: self.recordings[layer_key]} - list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] - - if scale_up: - if mode == "line": - data_plot["vspacing"] *= 0.8 - elif mode == "map": - data_plot["clims"] = { - layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() - } - if scale_down: - if mode == "line": - data_plot["vspacing"] *= 1.2 - elif mode == "map": - data_plot["clims"] = { - layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() - } - - self.next_data_plot["vspacing"] = data_plot["vspacing"] - self.next_data_plot["clims"] = data_plot["clims"] + data_plot["mode"] = mode + data_plot["frame_range"] = self._frame_range + data_plot["time_range"] = self._time_range + data_plot["with_colorbar"] = False + data_plot["recordings"] = self._selected_recordings + data_plot["add_legend"] = False if mode == "line": clims = None elif mode == "map": - clims = {layer_key: self.data_plot["clims"][layer_key]} + clims = {k: self.data_plot["clims"][k] for k in layer_keys} - # matplotlib next_data_plot dict update at each call - data_plot["mode"] = mode - data_plot["frame_range"] = frame_range - data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - data_plot["recordings"] = recordings - data_plot["layer_keys"] = layer_keys - data_plot["list_traces"] = list_traces_plot - data_plot["times"] = times data_plot["clims"] = clims - data_plot["channel_ids"] = channel_ids + data_plot["channel_ids"] = self._channel_ids + + data_plot["layer_keys"] = layer_keys + data_plot["colors"] = {k: self.data_plot["colors"][k] for k in layer_keys} + + list_traces = [traces * self.scaler.value for traces in self._list_traces] + data_plot["list_traces"] = list_traces + data_plot["times"] = self._times backend_kwargs = {} backend_kwargs["ax"] = self.ax + self.ax.clear() self.plot_matplotlib(data_plot, **backend_kwargs) + self.ax.set_title("") fig = self.ax.figure fig.canvas.draw() @@ -524,8 +505,32 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **backend_kwargs) + def plot_ephyviewer(self, data_plot, **backend_kwargs): + import ephyviewer + from ..preprocessing import depth_order + + dp = to_attr(data_plot) + + app = ephyviewer.mkQApp() + win = ephyviewer.MainViewer(debug=False, show_auto_scale=True) + + for k, rec in dp.recordings.items(): + if dp.order_channel_by_depth: + rec = depth_order(rec, flip=True) + + sig_source = ephyviewer.SpikeInterfaceRecordingSource(recording=rec) + view = ephyviewer.TraceViewer(source=sig_source, name=k) + view.params["scale_mode"] = "by_channel" + if dp.show_channel_ids: + view.params["display_labels"] = True + view.auto_scale() + win.add_view(view) -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): + win.show() + app.exec() + + +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -552,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] - return times, list_traces, frame_range, channel_ids diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 42267e711f..b41ee3508b 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -167,7 +167,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -186,42 +186,35 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO later add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index e64765b44b..8ffc931bf2 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -250,7 +250,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -274,44 +274,32 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.fig_probe, self.ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] - same_axis_button = widgets.Checkbox( + self.same_axis_button = widgets.Checkbox( value=False, description="same axis", disabled=False, ) - plot_templates_button = widgets.Checkbox( + self.plot_templates_button = widgets.Checkbox( value=True, description="plot templates", disabled=False, ) - hide_axis_button = widgets.Checkbox( + self.hide_axis_button = widgets.Checkbox( value=True, description="hide axis", disabled=False, ) - footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) - - self.controller = { - "same_axis": same_axis_button, - "plot_templates": plot_templates_button, - "hide_axis": hide_axis_button, - } - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + footer = widgets.HBox([self.same_axis_button, self.plot_templates_button, self.hide_axis_button]) self.widget = widgets.AppLayout( center=self.fig_wf.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, right_sidebar=self.fig_probe.canvas, pane_widths=ratios, footer=footer, @@ -320,6 +308,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") + for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button: + w.observe(self._update_ipywidget, names="value", type="change") + if backend_kwargs["display"]: display(self.widget) @@ -327,10 +319,15 @@ def _update_ipywidget(self, change): self.fig_wf.clear() self.ax_probe.clear() - unit_ids = self.controller["unit_ids"].value - same_axis = self.controller["same_axis"].value - plot_templates = self.controller["plot_templates"].value - hide_axis = self.controller["hide_axis"].value + # unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value + # same_axis = self.controller["same_axis"].value + # plot_templates = self.controller["plot_templates"].value + # hide_axis = self.controller["hide_axis"].value + + same_axis = self.same_axis_button.value + plot_templates = self.plot_templates_button.value + hide_axis = self.hide_axis_button.value # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot @@ -342,6 +339,8 @@ def _update_ipywidget(self, change): if data_plot["plot_waveforms"]: data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + # TODO option for plot_legend + backend_kwargs = {} if same_axis: @@ -369,6 +368,7 @@ def _update_ipywidget(self, change): self.ax_probe.axis("off") self.ax_probe.axis("equal") + # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] self.ax_probe.plot( diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index e8a6868e92..b3391c0712 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -103,7 +103,7 @@ def __init__( if same_axis and not np.array_equal(chan_inds, shared_chan_inds): # add more channels if necessary wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float) - mask = np.in1d(shared_chan_inds, chan_inds) + mask = np.isin(shared_chan_inds, chan_inds) wfs_[:, :, mask] = wfs wfs_[:, :, ~mask] = np.nan wfs = wfs_ diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index a7c571d1f0..58dd5c7f32 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -1,4 +1,6 @@ -import ipywidgets.widgets as widgets +import ipywidgets.widgets as W +import traitlets + import numpy as np @@ -9,96 +11,349 @@ def check_ipywidget_backend(): assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" -def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = widgets.FloatSlider( - orientation="horizontal", - description="time:", - value=time_range[0], - min=t_start, - max=t_stop, - continuous_update=False, - layout=widgets.Layout(width=f"{width_cm}cm"), - ) - layer_selector = widgets.Dropdown(description="layer", options=layer_keys) - segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) - - controller = { - "layer_key": layer_selector, - "segment_index": segment_selector, - "window": window_sizer, - "t_start": time_slider, - "mode": mode_selector, - "all_layers": all_layers, - } - widget = widgets.VBox( - [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] - ) - - return widget, controller - - -def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = widgets.Label(value="units:") - - unit_selector = widgets.SelectMultiple( - options=all_unit_ids, - value=list(unit_ids), - disabled=False, - layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"unit_ids": unit_selector} - widget = widgets.VBox([unit_label, unit_selector]) - - return widget, controller - - -def make_channel_controller(recording, width_cm, height_cm): - channel_label = widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) - channel_selector = widgets.IntRangeSlider( - value=[0, recording.get_num_channels()], - min=0, - max=recording.get_num_channels(), - step=1, - disabled=False, - continuous_update=False, - orientation="vertical", - readout=True, - readout_format="d", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"channel_inds": channel_selector} - widget = widgets.VBox([channel_label, channel_selector]) - - return widget, controller - - -def make_scale_controller(width_cm, height_cm): - scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) - - plus_selector = widgets.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Increase scale", - icon="arrow-up", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) - - minus_selector = widgets.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Decrease scale", - icon="arrow-down", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) - - controller = {"plus": plus_selector, "minus": minus_selector} - widget = widgets.VBox([scale_label, plus_selector, minus_selector]) - - return widget, controller +class TimeSlider(W.HBox): + value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) + + def __init__(self, durations, sampling_frequency, time_range=(0, 1.0), **kwargs): + self.num_segments = len(durations) + self.frame_limits = [int(sampling_frequency * d) for d in durations] + self.sampling_frequency = sampling_frequency + start_frame = int(time_range[0] * sampling_frequency) + end_frame = int(time_range[1] * sampling_frequency) + + self.frame_range = (start_frame, end_frame) + + self.segment_index = 0 + self.value = (start_frame, end_frame, self.segment_index) + + layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") + but_left = W.Button(description="", disabled=False, button_style="", icon="arrow-left", layout=layout) + but_right = W.Button(description="", disabled=False, button_style="", icon="arrow-right", layout=layout) + + but_left.on_click(self.move_left) + but_right.on_click(self.move_right) + + self.move_size = W.Dropdown( + options=[ + "10 ms", + "100 ms", + "1 s", + "10 s", + "1 m", + "30 m", + "1 h", + ], # '6 h', '24 h' + value="1 s", + description="", + layout=W.Layout(width="2cm"), + ) + + # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) + self.time_label = W.Text( + value=f"{time_range[0]}", description="", disabled=False, layout=W.Layout(width="2.5cm") + ) + self.time_label.observe(self.time_label_changed, names="value", type="change") + + self.slider = W.IntSlider( + orientation="horizontal", + # description='time:', + value=start_frame, + min=0, + max=self.frame_limits[self.segment_index] - 1, + readout=False, + continuous_update=False, + layout=W.Layout(width=f"70%"), + ) + + self.slider.observe(self.slider_moved, names="value", type="change") + + delta_s = np.diff(self.frame_range) / sampling_frequency + + self.window_sizer = W.BoundedFloatText( + value=delta_s, + step=1, + min=0.01, + max=30.0, + description="win (s)", + layout=W.Layout(width="auto") + # layout=W.Layout(width=f'10%') + ) + self.window_sizer.observe(self.win_size_changed, names="value", type="change") + + self.segment_selector = W.Dropdown(description="segment", options=list(range(self.num_segments))) + self.segment_selector.observe(self.segment_changed, names="value", type="change") + + super(W.HBox, self).__init__( + children=[ + self.segment_selector, + but_left, + self.move_size, + but_right, + self.slider, + self.time_label, + self.window_sizer, + ], + layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) + + self.observe(self.value_changed, names=["value"], type="change") + + def value_changed(self, change=None): + self.unobserve(self.value_changed, names=["value"], type="change") + + start, stop, seg_index = self.value + if seg_index < 0 or seg_index >= self.num_segments: + self.value = change["old"] + return + if start < 0 or stop < 0: + self.value = change["old"] + return + if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: + self.value = change["old"] + return + + self.segment_selector.value = seg_index + self.update_time(new_frame=start, update_slider=True, update_label=True) + delta_s = (stop - start) / self.sampling_frequency + self.window_sizer.value = delta_s + + self.observe(self.value_changed, names=["value"], type="change") + + def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): + if new_frame is None and new_time is None: + start_frame = self.slider.value + elif new_frame is None: + start_frame = int(new_time * self.sampling_frequency) + else: + start_frame = new_frame + delta_s = self.window_sizer.value + delta = int(delta_s * self.sampling_frequency) + + # clip + start_frame = min(self.frame_limits[self.segment_index] - delta, start_frame) + start_frame = max(0, start_frame) + end_frame = start_frame + delta + + end_frame = min(self.frame_limits[self.segment_index], end_frame) + + start_time = start_frame / self.sampling_frequency + + if update_label: + self.time_label.unobserve(self.time_label_changed, names="value", type="change") + self.time_label.value = f"{start_time}" + self.time_label.observe(self.time_label_changed, names="value", type="change") + + if update_slider: + self.slider.unobserve(self.slider_moved, names="value", type="change") + self.slider.value = start_frame + self.slider.observe(self.slider_moved, names="value", type="change") + + self.frame_range = (start_frame, end_frame) + self.value = (start_frame, end_frame, self.segment_index) + + def time_label_changed(self, change=None): + try: + new_time = float(self.time_label.value) + except: + new_time = None + if new_time is not None: + self.update_time(new_time=new_time, update_slider=True) + + def win_size_changed(self, change=None): + self.update_time() + + def slider_moved(self, change=None): + new_frame = self.slider.value + self.update_time(new_frame=new_frame, update_label=True) + + def move(self, sign): + value, units = self.move_size.value.split(" ") + value = int(value) + delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, "s") + delta_sample = int(delta_s * self.sampling_frequency) + + new_frame = self.frame_range[0] + delta_sample + self.slider.value = new_frame + + def move_left(self, change=None): + self.move(-1) + + def move_right(self, change=None): + self.move(+1) + + def segment_changed(self, change=None): + self.segment_index = self.segment_selector.value + + self.slider.unobserve(self.slider_moved, names="value", type="change") + # self.slider.value = 0 + self.slider.max = self.frame_limits[self.segment_index] - 1 + self.slider.observe(self.slider_moved, names="value", type="change") + + self.update_time(new_frame=0, update_slider=True, update_label=True) + + +class ChannelSelector(W.VBox): + value = traitlets.List() + + def __init__(self, channel_ids, **kwargs): + self.channel_ids = list(channel_ids) + self.value = self.channel_ids + + channel_label = W.Label("Channels", layout=W.Layout(justify_content="center")) + n = len(channel_ids) + self.slider = W.IntRangeSlider( + value=[0, n], + min=0, + max=n, + step=1, + disabled=False, + continuous_update=False, + orientation="vertical", + readout=True, + readout_format="d", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%"), + ) + + # first channel are bottom: need reverse + self.selector = W.SelectMultiple( + options=self.channel_ids[::-1], + value=self.channel_ids[::-1], + disabled=False, + # layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%", width="2cm"), + ) + hbox = W.HBox(children=[self.slider, self.selector]) + + super(W.VBox, self).__init__( + children=[channel_label, hbox], + layout=W.Layout(align_items="center"), + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + self.observe(self.value_changed, names=["value"], type="change") + + def on_slider_changed(self, change=None): + i0, i1 = self.slider.value + + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = self.channel_ids[i0:i1][::-1] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + self.value = self.channel_ids[i0:i1] + + def on_selector_changed(self, change=None): + channel_ids = self.selector.value + channel_ids = channel_ids[::-1] + + if len(channel_ids) > 0: + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + + self.value = channel_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + channel_ids = self.selector.value + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + + +class ScaleWidget(W.VBox): + value = traitlets.Float() + + def __init__(self, value=1.0, factor=1.2, **kwargs): + assert factor > 1.0 + self.factor = factor + + self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width="95%"), justify_content="center")) + + self.plus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase scale", + icon="arrow-up", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width="60%", align_self="center"), + ) + + self.minus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease scale", + icon="arrow-down", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width="60%", align_self="center"), + ) + + self.plus_selector.on_click(self.plus_clicked) + self.minus_selector.on_click(self.minus_clicked) + + self.value = 1.0 + super(W.VBox, self).__init__( + children=[self.plus_selector, self.scale_label, self.minus_selector], + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) + + self.update_label() + self.observe(self.value_changed, names=["value"], type="change") + + def update_label(self): + self.scale_label.value = f"Scale: {self.value:0.2f}" + + def plus_clicked(self, change=None): + self.value = self.value * self.factor + + def minus_clicked(self, change=None): + self.value = self.value / self.factor + + def value_changed(self, change=None): + self.update_label() + + +class UnitSelector(W.VBox): + value = traitlets.List() + + def __init__(self, unit_ids, **kwargs): + self.unit_ids = list(unit_ids) + self.value = self.unit_ids + + label = W.Label("Units", layout=W.Layout(justify_content="center")) + + self.selector = W.SelectMultiple( + options=self.unit_ids, + value=self.unit_ids, + disabled=False, + layout=W.Layout(height="100%", width="2cm"), + ) + + super(W.VBox, self).__init__(children=[label, self.selector], layout=W.Layout(align_items="center"), **kwargs) + + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + self.observe(self.value_changed, names=["value"], type="change") + + def on_selector_changed(self, change=None): + unit_ids = self.selector.value + self.value = unit_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 9c89b3981e..51e7208080 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -2,12 +2,16 @@ from .base import backend_kwargs_desc +from .agreement_matrix import AgreementMatrixWidget from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget +from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .motion import MotionWidget +from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget +from .rasters import RasterWidget from .sorting_summary import SortingSummaryWidget from .spike_locations import SpikeLocationsWidget from .spikes_on_traces import SpikesOnTracesWidget @@ -20,15 +24,20 @@ from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget +from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics widget_list = [ + AgreementMatrixWidget, AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, + ConfusionMatrixWidget, CrossCorrelogramsWidget, MotionWidget, + ProbeMapWidget, QualityMetricsWidget, + RasterWidget, SortingSummaryWidget, SpikeLocationsWidget, SpikesOnTracesWidget, @@ -41,6 +50,10 @@ UnitTemplatesWidget, UnitWaveformDensityMapWidget, UnitWaveformsWidget, + StudyRunTimesWidget, + StudyUnitCountsWidget, + StudyPerformances, + StudyPerformancesVsMetrics, ] @@ -76,12 +89,16 @@ # make function for all widgets +plot_agreement_matrix = AgreementMatrixWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget +plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_motion = MotionWidget +plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget +plot_rasters = RasterWidget plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget plot_spikes_on_traces = SpikesOnTracesWidget @@ -94,6 +111,10 @@ plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_waveforms = UnitWaveformsWidget +plot_study_run_times = StudyRunTimesWidget +plot_study_unit_counts = StudyUnitCountsWidget +plot_study_performances = StudyPerformances +plot_study_performances_vs_metrics = StudyPerformancesVsMetrics def plot_timeseries(*args, **kwargs):