From 2090c24e5fee3f6b51066e7820d3ba223c4b47fb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 17 Jan 2024 17:57:48 +0100 Subject: [PATCH 1/3] Add option to spikesort by group --- .../spikesorting/params.py | 1 + .../spikesorting/spikesorting.py | 47 ++++++++++++++----- tests/test_pipeline.py | 44 +++++++++++++---- 3 files changed, 70 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index 627687f..2a4070a 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -62,6 +62,7 @@ class MountainSort5Model(BaseModel): class SpikeSortingParams(BaseModel): sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.") + spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.") sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field( default=Kilosort25Model(), description="Sorter specific kwargs." ) diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index d9243c1..900748e 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -1,6 +1,8 @@ from __future__ import annotations -from pathlib import Path import shutil +import numpy as np +from pathlib import Path + import spikeinterface.full as si import spikeinterface.curation as sc @@ -38,21 +40,34 @@ def spikesort( try: logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter") + ## TEST ONLY - REMOVE LATER ## # si.get_default_sorter_params('kilosort2_5') # params_kilosort2_5 = {'do_correction': False} ## --------------------------## - sorting = si.run_sorter( - recording=recording, - sorter_name=spikesorting_params.sorter_name, - output_folder=str(output_folder), - verbose=True, - delete_output_folder=True, - remove_existing_folder=True, - **spikesorting_params.sorter_kwargs.model_dump(), - # **params_kilosort2_5 - ) + if spikesorting_params.spikesort_by_group and len(np.unique(recording.get_channel_groups())) > 1: + logger.info(f"[Spikesorting] \tSorting by channel groups") + sorting = si.run_sorter_by_property( + recording=recording, + sorter_name=spikesorting_params.sorter_name, + grouping_property="group", + working_folder=str(output_folder), + verbose=True, + delete_output_folder=True, + remove_existing_folder=True, + **spikesorting_params.sorter_kwargs.model_dump(), + ) + else: + sorting = si.run_sorter( + recording=recording, + sorter_name=spikesorting_params.sorter_name, + output_folder=str(output_folder), + verbose=True, + delete_output_folder=True, + remove_existing_folder=True, + **spikesorting_params.sorter_kwargs.model_dump(), + ) logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units") # remove spikes beyond num_Samples (if any) sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording) @@ -62,8 +77,14 @@ def spikesort( except Exception as e: # save log to results results_folder.mkdir(exist_ok=True, parents=True) - if (output_folder).is_dir(): - shutil.copy(output_folder / "spikeinterface_log.json", results_folder) + if not spikesorting_params.spikesort_by_group: + if (output_folder).is_dir(): + shutil.copy(output_folder / "spikeinterface_log.json", results_folder) + shutil.rmtree(output_folder) + else: + for group_folder in output_folder.iterdir(): + if group_folder.is_dir(): + shutil.copy(group_folder / "spikeinterface_log.json", results_folder / group_folder.name) shutil.rmtree(output_folder) logger.info(f"Spike sorting error:\n{e}") return None diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5cf5ce1..5edcc54 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -22,7 +22,7 @@ def _generate_gt_recording(): - recording, sorting = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0) + recording, sorting = si.generate_ground_truth_recording(durations=[15], num_channels=128, seed=0) # add inter sample shift (but fake) inter_sample_shifts = np.zeros(recording.get_num_channels()) recording.set_property("inter_sample_shift", inter_sample_shifts) @@ -62,15 +62,41 @@ def test_spikesorting(tmp_path, generate_recording): results_folder = Path(tmp_path) / "results_spikesorting" scratch_folder = Path(tmp_path) / "scratch_spikesorting" + ks25_params = Kilosort25Model(do_correction=False) + spikesorting_params = SpikeSortingParams( + sorter_name="kilosort2_5", + sorter_kwargs=ks25_params, + ) + sorting = spikesort( recording=recording, - spikesorting_params=SpikeSortingParams(), + spikesorting_params=spikesorting_params, results_folder=results_folder, scratch_folder=scratch_folder, ) assert isinstance(sorting, si.BaseSorting) + # by group + num_channels = recording.get_num_channels() + groups = [0] * (num_channels // 2) + [1] * (num_channels // 2) + recording.set_channel_groups(groups) + + spikesorting_params = SpikeSortingParams( + sorter_name="kilosort2_5", + sorter_kwargs=ks25_params, + spikesort_by_group=True, + ) + sorting_group = spikesort( + recording=recording, + spikesorting_params=spikesorting_params, + results_folder=results_folder, + scratch_folder=scratch_folder, + ) + + assert isinstance(sorting_group, si.BaseSorting) + assert "group" in sorting_group.get_property_keys() + def test_postprocessing(tmp_path, generate_recording): recording, sorting, _ = generate_recording @@ -160,13 +186,13 @@ def test_pipeline(tmp_path, generate_recording): recording, sorting, waveform_extractor = _generate_gt_recording() # print("TEST PREPROCESSING") - # test_preprocessing(tmp_folder, (recording, sorting)) - # print("TEST SPIKESORTING") - # test_spikesorting(tmp_folder, (recording, sorting)) + # test_preprocessing(tmp_folder, (recording, sorting, waveform_extractor)) + print("TEST SPIKESORTING") + test_spikesorting(tmp_folder, (recording, sorting, waveform_extractor)) # print("TEST POSTPROCESSING") - # test_postprocessing(tmp_folder, (recording, sorting)) - print("TEST CURATION") - test_curation(tmp_folder, (recording, sorting, waveform_extractor)) + # test_postprocessing(tmp_folder, (recording, sorting, waveform_extractor)) + # print("TEST CURATION") + # test_curation(tmp_folder, (recording, sorting, waveform_extractor)) # print("TEST PIPELINE") - # test_pipeline(tmp_folder, (recording, sorting)) + # test_pipeline(tmp_folder, (recording, sorting, waveform_extractor)) From 062cc695f8635babef0c4d5e87324384cc9bb891 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 Mar 2024 16:49:31 +0100 Subject: [PATCH 2/3] Fix is_extension --- src/spikeinterface_pipelines/curation/curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface_pipelines/curation/curation.py b/src/spikeinterface_pipelines/curation/curation.py index 0f941f4..c02c775 100644 --- a/src/spikeinterface_pipelines/curation/curation.py +++ b/src/spikeinterface_pipelines/curation/curation.py @@ -38,7 +38,7 @@ def curate( Curated sorting """ # get quality metrics - if not waveform_extractor.is_extension("quality_metrics"): + if not waveform_extractor.has_extension("quality_metrics"): logger.info(f"[Curation] \tQuality metrics not found in WaveformExtractor.") return From 81b13e8988342c52a02b2e88007b2b8b585310a8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 16 Mar 2024 09:36:51 +0100 Subject: [PATCH 3/3] Fix visualize tests --- .../visualization/visualization.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface_pipelines/visualization/visualization.py b/src/spikeinterface_pipelines/visualization/visualization.py index 71cd873..9e0fa33 100644 --- a/src/spikeinterface_pipelines/visualization/visualization.py +++ b/src/spikeinterface_pipelines/visualization/visualization.py @@ -74,14 +74,17 @@ def visualize( decimation_factor = recording_params["drift"]["decimation_factor"] alpha = recording_params["drift"]["alpha"] - # use spike locations - if not waveform_extractor.has_extension("quality_metrics"): - logger.info("[Visualization] \tVisualizing drift maps using pre-computed spike locations") - peaks = waveform_extractor.sorting.to_spike_vector() - peak_locations = waveform_extractor.load_extension("spike_locations").get_data() - peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data()) + # check if spike locations are available + spike_locations_available = False + if waveform_extractor is not None: + if waveform_extractor.has_extension("spike_locations"): + logger.info("[Visualization] \tVisualizing drift maps using pre-computed spike locations") + peaks = waveform_extractor.sorting.to_spike_vector() + peak_locations = waveform_extractor.load_extension("spike_locations").get_data() + peak_amps = np.concatenate(waveform_extractor.load_extension("spike_amplitudes").get_data()) + spike_locations_available = True # otherwise detect peaks - else: + if not spike_locations_available: from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass