diff --git a/.gitignore b/.gitignore index 3ee3cb8867..7838213bed 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,4 @@ test_folder/ # Mac OS .DS_Store +test_data.json diff --git a/README.md b/README.md index 55f33d04b1..883dcdb944 100644 --- a/README.md +++ b/README.md @@ -59,15 +59,17 @@ With SpikeInterface, users can: - post-process sorted datasets. - compare and benchmark spike sorting outputs. - compute quality metrics to validate and curate spike sorting outputs. -- visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, in jupyter) -- export report and export to phy -- offer a powerful Qt-based viewer in separate package [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui) -- have some powerful sorting components to build your own sorter. +- visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, jupyter, ephyviewer) +- export a report and/or export to phy +- offer a powerful Qt-based viewer in a separate package [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui) +- have powerful sorting components to build your own sorter. ## Documentation -Detailed documentation for spikeinterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). +Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.98.2). + +Detailed documentation of the development version of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). Several tutorials to get started can be found in [spiketutorials](https://github.com/SpikeInterface/spiketutorials). @@ -77,9 +79,9 @@ and sorting components. You can also have a look at the [spikeinterface-gui](https://github.com/SpikeInterface/spikeinterface-gui). -## How to install spikeinteface +## How to install spikeinterface -You can install the new `spikeinterface` version with pip: +You can install the latest version of `spikeinterface` version with pip: ```bash pip install spikeinterface[full] @@ -94,7 +96,7 @@ To install all interactive widget backends, you can use: ``` -To get the latest updates, you can install `spikeinterface` from sources: +To get the latest updates, you can install `spikeinterface` from source: ```bash git clone https://github.com/SpikeInterface/spikeinterface.git diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index e12d83810a..54a66c0890 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -54,7 +54,7 @@ Use the following Python script to load the binary data into SpikeInterface: dtype = "float64" # MATLAB's double corresponds to Python's float64 # Load data using SpikeInterface - recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, + recording = si.read_binary(file_paths=file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype) # Confirm that the data was loaded correctly by comparing the data shapes and see they match the MATLAB data @@ -86,7 +86,7 @@ If your data in MATLAB is stored as :code:`int16`, and you know the gain and off gain_to_uV = 0.195 # Adjust according to your MATLAB dataset offset_to_uV = 0 # Adjust according to your MATLAB dataset - recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, + recording = si.read_binary(file_paths=file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype_int, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 6101b81552..23e9e20d96 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -24,21 +24,21 @@ The merging and splitting operations are handled by the :py:class:`~spikeinterfa from spikeinterface.curation import CurationSorting - sorting = run_sorter('kilosort2', recording) + sorting = run_sorter(sorter_name='kilosort2', recording=recording) - cs = CurationSorting(sorting) + cs = CurationSorting(parent_sorting=sorting) # make a first merge - cs.merge(['#1', '#5', '#15']) + cs.merge(units_to_merge=['#1', '#5', '#15']) # make a second merge - cs.merge(['#11', '#21']) + cs.merge(units_to_merge=['#11', '#21']) # make a split split_index = ... # some criteria on spikes - cs.split('#20', split_index) + cs.split(split_unit_id='#20', indices_list=split_index) - # here the final clean sorting + # here is the final clean sorting clean_sorting = cs.sorting @@ -60,12 +60,12 @@ merges. Therefore, it has many parameters and options. from spikeinterface.curation import MergeUnitsSorting, get_potential_auto_merge - sorting = run_sorter('kilosort', recording) + sorting = run_sorter(sorter_name='kilosort', recording=recording) - we = extract_waveforms(recording, sorting, folder='wf_folder') + we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder') # merges is a list of lists, with unit_ids to be merged. - merges = get_potential_auto_merge(we, minimum_spikes=1000, maximum_distance_um=150., + merges = get_potential_auto_merge(waveform_extractor=we, minimum_spikes=1000, maximum_distance_um=150., peak_sign="neg", bin_ms=0.25, window_ms=100., corr_diff_thresh=0.16, template_diff_thresh=0.25, censored_period_ms=0., refractory_period_ms=1.0, @@ -73,7 +73,7 @@ merges. Therefore, it has many parameters and options. firing_contamination_balance=1.5) # here we apply the merges - clean_sorting = MergeUnitsSorting(sorting, merges) + clean_sorting = MergeUnitsSorting(parent_sorting=sorting, units_to_merge=merges) Manual curation with sorting view @@ -98,24 +98,24 @@ The manual curation (including merges and labels) can be applied to a SpikeInter from spikeinterface.widgets import plot_sorting_summary # run a sorter and export waveforms - sorting = run_sorter('kilosort2', recording) - we = extract_waveforms(recording, sorting, folder='wf_folder') + sorting = run_sorter(sorter_name'kilosort2', recording=recording) + we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder') # some postprocessing is required - _ = compute_spike_amplitudes(we) - _ = compute_unit_locations(we) - _ = compute_template_similarity(we) - _ = compute_correlograms(we) + _ = compute_spike_amplitudes(waveform_extractor=we) + _ = compute_unit_locations(waveform_extractor=we) + _ = compute_template_similarity(waveform_extractor=we) + _ = compute_correlograms(waveform_extractor=we) # This loads the data to the cloud for web-based plotting and sharing - plot_sorting_summary(we, curation=True, backend='sortingview') + plot_sorting_summary(waveform_extractor=we, curation=True, backend='sortingview') # we open the printed link URL in a browswe # - make manual merges and labeling # - from the curation box, click on "Save as snapshot (sha1://)" # copy the uri sha_uri = "sha1://59feb326204cf61356f1a2eb31f04d8e0177c4f1" - clean_sorting = apply_sortingview_curation(sorting, uri_or_json=sha_uri) + clean_sorting = apply_sortingview_curation(sorting=sorting, uri_or_json=sha_uri) Note that you can also "Export as JSON" and pass the json file as :code:`uri_or_json` parameter. diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index fa637f898b..155050ddb0 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -28,15 +28,14 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:` from spikeinterface.exporters import export_to_phy # the waveforms are sparse so it is faster to export to phy - folder = 'waveforms' - we = extract_waveforms(recording, sorting, folder, sparse=True) + we = extract_waveforms(recording=recording, sorting=sorting, folder='waveforms', sparse=True) # some computations are done before to control all options - compute_spike_amplitudes(we) - compute_principal_components(we, n_components=3, mode='by_channel_global') + compute_spike_amplitudes(waveform_extractor=we) + compute_principal_components(waveform_extractor=we, n_components=3, mode='by_channel_global') # the export process is fast because everything is pre-computed - export_to_phy(we, output_folder='path/to/phy_folder') + export_to_phy(wavefor_extractor=we, output_folder='path/to/phy_folder') @@ -72,12 +71,12 @@ with many units! # the waveforms are sparse for more interpretable figures - we = extract_waveforms(recording, sorting, folder='path/to/wf', sparse=True) + we = extract_waveforms(recording=recording, sorting=sorting, folder='path/to/wf', sparse=True) # some computations are done before to control all options - compute_spike_amplitudes(we) - compute_correlograms(we) - compute_quality_metrics(we, metric_names=['snr', 'isi_violation', 'presence_ratio']) + compute_spike_amplitudes(waveform_extractor=we) + compute_correlograms(waveform_extractor=we) + compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'presence_ratio']) # the export process - export_report(we, output_folder='path/to/spikeinterface-report-folder') + export_report(waveform_extractor=we, output_folder='path/to/spikeinterface-report-folder') diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index 5aed24ca41..2d0e047672 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -13,11 +13,12 @@ Most of the :code:`Recording` classes are implemented by wrapping the Most of the :code:`Sorting` classes are instead directly implemented in SpikeInterface. - Although SpikeInterface is object-oriented (class-based), each object can also be loaded with a convenient :code:`read_XXXXX()` function. +.. code-block:: python + import spikeinterface.extractors as se Read one Recording @@ -27,32 +28,34 @@ Every format can be read with a simple function: .. code-block:: python - recording_oe = read_openephys("open-ephys-folder") + recording_oe = read_openephys(folder_path="open-ephys-folder") - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") - recording_blackrock = read_blackrock("blackrock-folder") + recording_blackrock = read_blackrock(folder_path="blackrock-folder") - recording_mearec = read_mearec("mearec_file.h5") + recording_mearec = read_mearec(file_path="mearec_file.h5") Importantly, some formats directly handle the probe information: .. code-block:: python - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") print(recording_spikeglx.get_probe()) - recording_mearec = read_mearec("mearec_file.h5") + recording_mearec = read_mearec(file_path="mearec_file.h5") print(recording_mearec.get_probe()) + + Read one Sorting ---------------- .. code-block:: python - sorting_KS = read_kilosort("kilosort-folder") + sorting_KS = read_kilosort(folder_path="kilosort-folder") Read one Event @@ -60,7 +63,7 @@ Read one Event .. code-block:: python - events_OE = read_openephys_event("open-ephys-folder") + events_OE = read_openephys_event(folder_path="open-ephys-folder") For a comprehensive list of compatible technologies, see :ref:`compatible_formats`. @@ -77,7 +80,7 @@ The actual reading will be done on demand using the :py:meth:`~spikeinterface.co .. code-block:: python # opening a 40GB SpikeGLX dataset is fast - recording_spikeglx = read_spikeglx("spikeglx-folder") + recording_spikeglx = read_spikeglx(folder_path="spikeglx-folder") # this really does load the full 40GB into memory : not recommended!!!!! traces = recording_spikeglx.get_traces(start_frame=None, end_frame=None, return_scaled=False) diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index afedc4f982..8934ae1ff6 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -77,12 +77,12 @@ We currently have 3 presets: .. code-block:: python # read and preprocess - rec = read_spikeglx('/my/Neuropixel/recording') - rec = bandpass_filter(rec) - rec = common_reference(rec) + rec = read_spikeglx(folder_path='/my/Neuropixel/recording') + rec = bandpass_filter(recording=rec) + rec = common_reference(recording=rec) # then correction is one line of code - rec_corrected = correct_motion(rec, preset="nonrigid_accurate") + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate") The process is quite long due the two first steps (activity profile + motion inference) But the return :code:`rec_corrected` is a lazy recording object that will interpolate traces on the @@ -94,20 +94,20 @@ If you want to user other presets, this is as easy as: .. code-block:: python # mimic kilosort motion - rec_corrected = correct_motion(rec, preset="kilosort_like") + rec_corrected = correct_motion(recording=rec, preset="kilosort_like") # super but less accurate and rigid - rec_corrected = correct_motion(rec, preset="rigid_fast") + rec_corrected = correct_motion(recording=rec, preset="rigid_fast") Optionally any parameter from the preset can be overwritten: .. code-block:: python - rec_corrected = correct_motion(rec, preset="nonrigid_accurate", + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", detect_kwargs=dict( detect_threshold=10.), - estimate_motion_kwargs=dic( + estimate_motion_kwargs=dict( histogram_depth_smooth_um=8., time_horizon_s=120., ), @@ -123,7 +123,7 @@ and checking. The folder will contain the motion vector itself of course but als .. code-block:: python motion_folder = '/somewhere/to/save/the/motion' - rec_corrected = correct_motion(rec, preset="nonrigid_accurate", folder=motion_folder) + rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", folder=motion_folder) # and then motion_info = load_motion_info(motion_folder) @@ -156,14 +156,16 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte job_kwargs = dict(chunk_duration="1s", n_jobs=20, progress_bar=True) # Step 1 : activity profile - peaks = detect_peaks(rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) + peaks = detect_peaks(recording=rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) # (optional) sub-select some peaks to speed up the localization - peaks = select_peaks(peaks, ...) - peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation",radius_um=75.0, + peaks = select_peaks(peaks=peaks, ...) + peak_locations = localize_peaks(recording=rec, peaks=peaks, method="monopolar_triangulation",radius_um=75.0, max_distance_um=150.0, **job_kwargs) # Step 2: motion inference - motion, temporal_bins, spatial_bins = estimate_motion(rec, peaks, peak_locations, + motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, + peaks=peaks, + peak_locations=peak_locations, method="decentralized", direction="y", bin_duration_s=2.0, @@ -173,7 +175,9 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte # Step 3: motion interpolation # this step is lazy - rec_corrected = interpolate_motion(rec, motion, temporal_bins, spatial_bins, + rec_corrected = interpolate_motion(recording=rec, motion=motion, + temporal_bins=temporal_bins, + spatial_bins=spatial_bins, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=30.) @@ -196,20 +200,20 @@ different preprocessing chains: one for motion correction and one for spike sort .. code-block:: python - raw_rec = read_spikeglx(...) + raw_rec = read_spikeglx(folder_path='/spikeglx_folder') # preprocessing 1 : bandpass (this is smoother) + cmr - rec1 = si.bandpass_filter(raw_rec, freq_min=300., freq_max=5000.) - rec1 = si.common_reference(rec1, reference='global', operator='median') + rec1 = si.bandpass_filter(recording=raw_rec, freq_min=300., freq_max=5000.) + rec1 = si.common_reference(recording=rec1, reference='global', operator='median') # here the corrected recording is done on the preprocessing 1 # rec_corrected1 will not be used for sorting! motion_folder = '/my/folder' - rec_corrected1 = correct_motion(rec1, preset="nonrigid_accurate", folder=motion_folder) + rec_corrected1 = correct_motion(recording=rec1, preset="nonrigid_accurate", folder=motion_folder) # preprocessing 2 : highpass + cmr - rec2 = si.highpass_filter(raw_rec, freq_min=300.) - rec2 = si.common_reference(rec2, reference='global', operator='median') + rec2 = si.highpass_filter(recording=raw_rec, freq_min=300.) + rec2 = si.common_reference(recording=rec2, reference='global', operator='median') # we use another preprocessing for the final interpolation motion_info = load_motion_info(motion_folder) @@ -220,7 +224,7 @@ different preprocessing chains: one for motion correction and one for spike sort spatial_bins=motion_info['spatial_bins'], **motion_info['parameters']['interpolate_motion_kwargs']) - sorting = run_sorter("montainsort5", rec_corrected2) + sorting = run_sorter(sorter_name="montainsort5", recording=rec_corrected2) References diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index a560f4d5c9..112c6e367d 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -14,9 +14,9 @@ WaveformExtractor extensions There are several postprocessing tools available, and all of them are implemented as a :py:class:`~spikeinterface.core.BaseWaveformExtractorExtension`. All computations on top -of a WaveformExtractor will be saved along side the WaveformExtractor itself (sub folder, zarr path or sub dict). +of a :code:`WaveformExtractor` will be saved along side the :code:`WaveformExtractor` itself (sub folder, zarr path or sub dict). This workflow is convenient for retrieval of time-consuming computations (such as pca or spike amplitudes) when reloading a -WaveformExtractor. +:code:`WaveformExtractor`. :py:class:`~spikeinterface.core.BaseWaveformExtractorExtension` objects are tightly connected to the parent :code:`WaveformExtractor` object, so that operations done on the :code:`WaveformExtractor`, such as saving, @@ -80,9 +80,9 @@ This extension computes the principal components of the waveforms. There are sev * "by_channel_local" (default): fits one PCA model for each by_channel * "by_channel_global": fits the same PCA model to all channels (also termed temporal PCA) -* "concatenated": contatenates all channels and fits a PCA model on the concatenated data +* "concatenated": concatenates all channels and fits a PCA model on the concatenated data -If the input :code:`WaveformExtractor` is sparse, the sparsity is used when computing PCA. +If the input :code:`WaveformExtractor` is sparse, the sparsity is used when computing the PCA. For dense waveforms, sparsity can also be passed as an argument. For more information, see :py:func:`~spikeinterface.postprocessing.compute_principal_components` @@ -127,7 +127,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_locations` -unit locations +unit_locations ^^^^^^^^^^^^^^ diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 7c1f33f298..67f1e52011 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -22,8 +22,8 @@ In this code example, we build a preprocessing chain with two steps: import spikeinterface.preprocessing import bandpass_filter, common_reference # recording is a RecordingExtractor object - recording_f = bandpass_filter(recording, freq_min=300, freq_max=6000) - recording_cmr = common_reference(recording_f, operator="median") + recording_f = bandpass_filter(recording=recording, freq_min=300, freq_max=6000) + recording_cmr = common_reference(recording=recording_f, operator="median") These two preprocessors will not compute anything at instantiation, but the computation will be "on-demand" ("on-the-fly") when getting traces. @@ -38,7 +38,7 @@ save the object: .. code-block:: python # here the spykingcircus2 sorter engine directly uses the lazy "recording_cmr" object - sorting = run_sorter(recording_cmr, 'spykingcircus2') + sorting = run_sorter(recording=recording_cmr, sorter_name='spykingcircus2') Most of the external sorters, however, will need a binary file as input, so we can optionally save the processed recording with the efficient SpikeInterface :code:`save()` function: @@ -64,12 +64,13 @@ dtype (unless specified otherwise): .. code-block:: python + import spikeinterface.extractors as se # spikeGLX is int16 - rec_int16 = read_spikeglx("my_folder") + rec_int16 = se.read_spikeglx(folder_path"my_folder") # by default the int16 is kept - rec_f = bandpass_filter(rec_int16, freq_min=300, freq_max=6000) + rec_f = bandpass_filter(recording=rec_int16, freq_min=300, freq_max=6000) # we can force a float32 casting - rec_f2 = bandpass_filter(rec_int16, freq_min=300, freq_max=6000, dtype='float32') + rec_f2 = bandpass_filter(recording=rec_int16, freq_min=300, freq_max=6000, dtype='float32') Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`. @@ -83,6 +84,8 @@ The full list of preprocessing functions can be found here: :ref:`api_preprocess Here is a full list of possible preprocessing steps, grouped by type of processing: +For all examples :code:`rec` is a :code:`RecordingExtractor`. + filter() / bandpass_filter() / notch_filter() / highpass_filter() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -98,7 +101,7 @@ Important aspects of filtering functions: .. code-block:: python - rec_f = bandpass_filter(rec, freq_min=300, freq_max=6000) + rec_f = bandpass_filter(recording=rec, freq_min=300, freq_max=6000) * :py:func:`~spikeinterface.preprocessing.filter()` @@ -119,7 +122,7 @@ There are various options when combining :code:`operator` and :code:`reference` .. code-block:: python - rec_cmr = common_reference(rec, operator="median", reference="global") + rec_cmr = common_reference(recording=rec, operator="median", reference="global") * :py:func:`~spikeinterface.preprocessing.common_reference()` @@ -144,8 +147,8 @@ difference on artifact removal. .. code-block:: python - rec_shift = phase_shift(rec) - rec_cmr = common_reference(rec_shift, operator="median", reference="global") + rec_shift = phase_shift(recording=rec) + rec_cmr = common_reference(recording=rec_shift, operator="median", reference="global") @@ -168,7 +171,7 @@ centered with unitary variance on each channel. .. code-block:: python - rec_normed = zscore(rec) + rec_normed = zscore(recording=rec) * :py:func:`~spikeinterface.preprocessing.normalize_by_quantile()` * :py:func:`~spikeinterface.preprocessing.scale()` @@ -186,7 +189,7 @@ The whitened traces are then the dot product between the traces and the :code:`W .. code-block:: python - rec_w = whiten(rec) + rec_w = whiten(recording=rec) * :py:func:`~spikeinterface.preprocessing.whiten()` @@ -199,7 +202,7 @@ The :code:`blank_staturation()` function is similar, but it automatically estima .. code-block:: python - rec_w = clip(rec, a_min=-250., a_max=260) + rec_w = clip(recording=rec, a_min=-250., a_max=260) * :py:func:`~spikeinterface.preprocessing.clip()` * :py:func:`~spikeinterface.preprocessing.blank_staturation()` @@ -234,11 +237,11 @@ interpolated with the :code:`interpolate_bad_channels()` function (channels labe .. code-block:: python # detect - bad_channel_ids, channel_labels = detect_bad_channels(rec) + bad_channel_ids, channel_labels = detect_bad_channels(recording=rec) # Case 1 : remove then - rec_clean = recording.remove_channels(bad_channel_ids) + rec_clean = recording.remove_channels(remove_channel_ids=bad_channel_ids) # Case 2 : interpolate then - rec_clean = interpolate_bad_channels(rec, bad_channel_ids) + rec_clean = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids) * :py:func:`~spikeinterface.preprocessing.detect_bad_channels()` @@ -257,13 +260,13 @@ remove_artifacts() Given an external list of trigger times, :code:`remove_artifacts()` function can remove artifacts with several strategies: -* replace with zeros (blank) -* make a linear or cubic interpolation -* remove the median or average template (with optional time jitter and amplitude scaling correction) +* replace with zeros (blank) :code:`'zeros'` +* make a linear (:code:`'linear'`) or cubic (:code:`'cubic'`) interpolation +* remove the median (:code:`'median'`) or average (:code:`'avereage'`) template (with optional time jitter and amplitude scaling correction) .. code-block:: python - rec_clean = remove_artifacts(rec, list_triggers) + rec_clean = remove_artifacts(recording=rec, list_triggers=[100, 200, 300], mode='zeros') * :py:func:`~spikeinterface.preprocessing.remove_artifacts()` @@ -276,7 +279,7 @@ Similarly to :code:`numpy.astype()`, the :code:`astype()` casts the traces to th .. code-block:: python - rec_int16 = astype(rec_float, "int16") + rec_int16 = astype(recording=rec_float, dtype="int16") For recordings whose traces are unsigned (e.g. Maxwell Biosystems), the :code:`unsigned_to_signed()` function makes them @@ -286,7 +289,7 @@ is subtracted, and the traces are finally cast to :code:`int16`: .. code-block:: python - rec_int16 = unsigned_to_signed(rec_uint16) + rec_int16 = unsigned_to_signed(recording=rec_uint16) * :py:func:`~spikeinterface.preprocessing.astype()` * :py:func:`~spikeinterface.preprocessing.unsigned_to_signed()` @@ -300,7 +303,7 @@ required. .. code-block:: python - rec_with_more_channels = zero_channel_pad(rec, 128) + rec_with_more_channels = zero_channel_pad(parent_recording=rec, num_channels=128) * :py:func:`~spikeinterface.preprocessing.zero_channel_pad()` @@ -331,7 +334,7 @@ How to implement "IBL destriping" or "SpikeGLX CatGT" in SpikeInterface SpikeGLX has a built-in function called `CatGT `_ to apply some preprocessing on the traces to remove noise and artifacts. IBL also has a standardized pipeline for preprocessed traces a bit similar to CatGT which is called "destriping" [IBL_spikesorting]_. -In these both cases, the traces are entiely read, processed and written back to a file. +In both these cases, the traces are entirely read, processed and written back to a file. SpikeInterface can reproduce similar results without the need to write back to a file by building a *lazy* preprocessing chain. Optionally, the result can still be written to a binary (or a zarr) file. @@ -341,12 +344,12 @@ Here is a recipe to mimic the **IBL destriping**: .. code-block:: python - rec = read_spikeglx('my_spikeglx_folder') - rec = highpass_filter(rec, n_channel_pad=60) - rec = phase_shift(rec) - bad_channel_ids = detect_bad_channels(rec) - rec = interpolate_bad_channels(rec, bad_channel_ids) - rec = highpass_spatial_filter(rec) + rec = read_spikeglx(folder_path='my_spikeglx_folder') + rec = highpass_filter(recording=rec, n_channel_pad=60) + rec = phase_shift(recording=rec) + bad_channel_ids = detect_bad_channels(recording=rec) + rec = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids) + rec = highpass_spatial_filter(recording=rec) # optional rec.save(folder='clean_traces', n_jobs=10, chunk_duration='1s', progres_bar=True) @@ -356,9 +359,9 @@ Here is a recipe to mimic the **SpikeGLX CatGT**: .. code-block:: python - rec = read_spikeglx('my_spikeglx_folder') - rec = phase_shift(rec) - rec = common_reference(rec, operator="median", reference="global") + rec = read_spikeglx(folder_path='my_spikeglx_folder') + rec = phase_shift(recording=rec) + rec = common_reference(recording=rec, operator="median", reference="global") # optional rec.save(folder='clean_traces', n_jobs=10, chunk_duration='1s', progres_bar=True) @@ -369,7 +372,6 @@ Of course, these pipelines can be enhanced and customized using other available - Preprocessing on Snippets ------------------------- diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 447d83db52..ec1788350f 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -47,16 +47,16 @@ This code snippet shows how to compute quality metrics (with or without principa .. code-block:: python - we = si.load_waveforms(...) # start from a waveform extractor + we = si.load_waveforms(folder='waveforms') # start from a waveform extractor # without PC - metrics = compute_quality_metrics(we, metric_names=['snr']) + metrics = compute_quality_metrics(waveform_extractor=we, metric_names=['snr']) assert 'snr' in metrics.columns # with PCs from spikeinterface.postprocessing import compute_principal_components - pca = compute_principal_components(we, n_components=5, mode='by_channel_local') - metrics = compute_quality_metrics(we) + pca = compute_principal_components(waveform_extractor=we, n_components=5, mode='by_channel_local') + metrics = compute_quality_metrics(waveform_extractor=we) assert 'isolation_distance' in metrics.columns For more information about quality metrics, check out this excellent diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/qualitymetrics/amplitude_cv.rst index 13117b607c..81d3b4f12d 100644 --- a/doc/modules/qualitymetrics/amplitude_cv.rst +++ b/doc/modules/qualitymetrics/amplitude_cv.rst @@ -37,7 +37,7 @@ Example code # Make recording, sorting and wvf_extractor object for your data. # It is required to run `compute_spike_amplitudes(wvf_extractor)` or # `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN) - amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(wvf_extractor) + amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(waveform_extractor=wvf_extractor) # amplitude_cv_median and amplitude_cv_range are dicts containing the unit ids as keys, # and their amplitude_cv metrics as values. diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/qualitymetrics/amplitude_median.rst index 3ac52560e8..c77a57b033 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/qualitymetrics/amplitude_median.rst @@ -24,7 +24,7 @@ Example code # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitude values from all spikes. - amplitude_medians = sqm.compute_amplitude_medians(wvf_extractor) + amplitude_medians = sqm.compute_amplitude_medians(waveform_extractor=wvf_extractor) # amplitude_medians is a dict containing the unit IDs as keys, # and their estimated amplitude medians as values. diff --git a/doc/modules/qualitymetrics/d_prime.rst b/doc/modules/qualitymetrics/d_prime.rst index e3bd61c580..9b540be743 100644 --- a/doc/modules/qualitymetrics/d_prime.rst +++ b/doc/modules/qualitymetrics/d_prime.rst @@ -34,7 +34,7 @@ Example code import spikeinterface.qualitymetrics as sqm - d_prime = sqm.lda_metrics(all_pcs, all_labels, 0) + d_prime = sqm.lda_metrics(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) Reference diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/qualitymetrics/drift.rst index ae52f7f883..dad2aafe7c 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/qualitymetrics/drift.rst @@ -43,10 +43,10 @@ Example code import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - # It is required to run `compute_spike_locations(wvf_extractor)` + # It is required to run `compute_spike_locations(wvf_extractor) first` # (if missing, values will be NaN) - drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(wvf_extractor, peak_sign="neg") - # drift_ptps, drift_stds, and drift_mads are dict containing the units' ID as keys, + drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(waveform_extractor=wvf_extractor, peak_sign="neg") + # drift_ptps, drift_stds, and drift_mads are each a dict containing the unit IDs as keys, # and their metrics as values. diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/qualitymetrics/firing_range.rst index 925539e9c6..1cbd903c7a 100644 --- a/doc/modules/qualitymetrics/firing_range.rst +++ b/doc/modules/qualitymetrics/firing_range.rst @@ -24,7 +24,7 @@ Example code import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - firing_range = sqm.compute_firing_ranges(wvf_extractor) + firing_range = sqm.compute_firing_ranges(waveform_extractor=wvf_extractor) # firing_range is a dict containing the unit IDs as keys, # and their firing firing_range as values (in Hz). diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/qualitymetrics/firing_rate.rst index c0e15d7c2e..ef8cb3d8f4 100644 --- a/doc/modules/qualitymetrics/firing_rate.rst +++ b/doc/modules/qualitymetrics/firing_rate.rst @@ -40,7 +40,7 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - firing_rate = sqm.compute_firing_rates(wvf_extractor) + firing_rate = sqm.compute_firing_rates(waveform_extractor=wvf_extractor) # firing_rate is a dict containing the unit IDs as keys, # and their firing rates across segments as values (in Hz). diff --git a/doc/modules/qualitymetrics/isolation_distance.rst b/doc/modules/qualitymetrics/isolation_distance.rst index 640a5a8b5a..6ba0d0b1ec 100644 --- a/doc/modules/qualitymetrics/isolation_distance.rst +++ b/doc/modules/qualitymetrics/isolation_distance.rst @@ -23,6 +23,16 @@ Expectation and use Isolation distance can be interpreted as a measure of distance from the cluster to the nearest other cluster. A well isolated unit should have a large isolation distance. +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + iso_distance, _ = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/l_ratio.rst b/doc/modules/qualitymetrics/l_ratio.rst index b37913ba58..ae31ab40a4 100644 --- a/doc/modules/qualitymetrics/l_ratio.rst +++ b/doc/modules/qualitymetrics/l_ratio.rst @@ -37,6 +37,17 @@ Since this metric identifies unit separation, a high value indicates a highly co A well separated unit should have a low L-ratio ([Schmitzer-Torbert]_ et al.). + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + _, l_ratio = sqm.isolation_distance(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/qualitymetrics/presence_ratio.rst index 5a420c8ccf..ad0766d37c 100644 --- a/doc/modules/qualitymetrics/presence_ratio.rst +++ b/doc/modules/qualitymetrics/presence_ratio.rst @@ -27,7 +27,7 @@ Example code # Make recording, sorting and wvf_extractor object for your data. - presence_ratio = sqm.compute_presence_ratios(wvf_extractor) + presence_ratio = sqm.compute_presence_ratios(waveform_extractor=wvf_extractor) # presence_ratio is a dict containing the unit IDs as keys # and their presence ratio (between 0 and 1) as values. diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/qualitymetrics/silhouette_score.rst index b924cdbf73..7da01e0476 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/qualitymetrics/silhouette_score.rst @@ -50,6 +50,16 @@ To reduce complexity the default implementation in SpikeInterface is to use the This can be changes by switching the silhouette method to either 'full' (the Rousseeuw implementation) or ('simplified', 'full') for both methods when entering the qm_params parameter. +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + simple_sil_score = sqm.simplified_silhouette_score(all_pcs=all_pcs, all_labels=all_labels, this_unit_id=0) + + References ---------- diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/qualitymetrics/sliding_rp_violations.rst index de68c3a92f..fd53d7da3b 100644 --- a/doc/modules/qualitymetrics/sliding_rp_violations.rst +++ b/doc/modules/qualitymetrics/sliding_rp_violations.rst @@ -31,7 +31,7 @@ With SpikeInterface: # Make recording, sorting and wvf_extractor object for your data. - contamination = sqm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25) + contamination = sqm.compute_sliding_rp_violations(waveform_extractor=wvf_extractor, bin_size_ms=0.25) References ---------- diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/qualitymetrics/snr.rst index b88d3291be..7f27a5078a 100644 --- a/doc/modules/qualitymetrics/snr.rst +++ b/doc/modules/qualitymetrics/snr.rst @@ -44,8 +44,7 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - - SNRs = sqm.compute_snrs(wvf_extractor) + SNRs = sqm.compute_snrs(waveform_extractor=wvf_extractor) # SNRs is a dict containing the unit IDs as keys and their SNRs as values. Links to original implementations diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 0750940199..d1a3c70a97 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -29,7 +29,7 @@ Example code import spikeinterface.qualitymetrics as sqm # Make recording, sorting and wvf_extractor object for your data. - synchrony = sqm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(waveform_extractor=wvf_extractor, synchrony_sizes=(2, 4, 8)) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index f3c8e7b733..5040b01ec2 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -49,15 +49,15 @@ to easily run spike sorters: from spikeinterface.sorters import run_sorter # run Tridesclous - sorting_TDC = run_sorter("tridesclous", recording, output_folder="/folder_TDC") + sorting_TDC = run_sorter(sorter_name="tridesclous", recording=recording, output_folder="/folder_TDC") # run Kilosort2.5 - sorting_KS2_5 = run_sorter("kilosort2_5", recording, output_folder="/folder_KS2.5") + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2.5") # run IronClust - sorting_IC = run_sorter("ironclust", recording, output_folder="/folder_IC") + sorting_IC = run_sorter(sorter_name="ironclust", recording=recording, output_folder="/folder_IC") # run pyKilosort - sorting_pyKS = run_sorter("pykilosort", recording, output_folder="/folder_pyKS") + sorting_pyKS = run_sorter(sorter_name="pykilosort", recording=recording, output_folder="/folder_pyKS") # run SpykingCircus - sorting_SC = run_sorter("spykingcircus", recording, output_folder="/folder_SC") + sorting_SC = run_sorter(sorter_name="spykingcircus", recording=recording, output_folder="/folder_SC") Then the output, which is a :py:class:`~spikeinterface.core.BaseSorting` object, can be easily @@ -81,10 +81,10 @@ Spike-sorter-specific parameters can be controlled directly from the .. code-block:: python - sorting_TDC = run_sorter('tridesclous', recording, output_folder="/folder_TDC", + sorting_TDC = run_sorter(sorter_name='tridesclous', recording=recording, output_folder="/folder_TDC", detect_threshold=8.) - sorting_KS2_5 = run_sorter("kilosort2_5", recording, output_folder="/folder_KS2.5" + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2.5" do_correction=False, preclust_threshold=6, freq_min=200.) @@ -185,7 +185,7 @@ The following code creates a test recording and runs a containerized spike sorte ) test_recording = test_recording.save(folder="test-docker-folder") - sorting = ss.run_sorter('kilosort3', + sorting = ss.run_sorter(sorter_name='kilosort3', recording=test_recording, output_folder="kilosort3", singularity_image=True) @@ -201,7 +201,7 @@ To run in Docker instead of Singularity, use ``docker_image=True``. .. code-block:: python - sorting = run_sorter('kilosort3', recording=test_recording, + sorting = run_sorter(sorter_name='kilosort3', recording=test_recording, output_folder="/tmp/kilosort3", docker_image=True) To use a specific image, set either ``docker_image`` or ``singularity_image`` to a string, @@ -209,7 +209,7 @@ e.g. ``singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0"``. .. code-block:: python - sorting = run_sorter("kilosort3", + sorting = run_sorter(sorter_name="kilosort3", recording=test_recording, output_folder="kilosort3", singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0") @@ -271,7 +271,7 @@ And use the custom image whith the :code:`run_sorter` function: .. code-block:: python - sorting = run_sorter("kilosort3", + sorting = run_sorter(sorter_name="kilosort3", recording=recording, docker_image="my-user/ks3-with-spikeinterface-test:0.1.0") @@ -302,7 +302,7 @@ an :code:`engine` that supports parallel processing (such as :code:`joblib` or : ] # run in loop - sortings = run_sorter_jobs(job_list, engine='loop') + sortings = run_sorter_jobs(job_list=job_list, engine='loop') @@ -314,11 +314,11 @@ an :code:`engine` that supports parallel processing (such as :code:`joblib` or : .. code-block:: python - run_sorter_jobs(job_list, engine='loop') + run_sorter_jobs(job_list=job_list, engine='loop') - run_sorter_jobs(job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) + run_sorter_jobs(job_list=job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) - run_sorter_jobs(job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) + run_sorter_jobs(job_list=job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem': '5G'}) Spike sorting by group @@ -374,7 +374,7 @@ In this example, we create a 16-channel recording with 4 tetrodes: # here the result is a dict of a sorting object sortings = {} for group, sub_recording in recordings.items(): - sorting = run_sorter('kilosort2', recording, output_folder=f"folder_KS2_group{group}") + sorting = run_sorter(sorter_name='kilosort2', recording=recording, output_folder=f"folder_KS2_group{group}") sortings[group] = sorting **Option 2 : Automatic splitting** @@ -382,7 +382,7 @@ In this example, we create a 16-channel recording with 4 tetrodes: .. code-block:: python # here the result is one sorting that aggregates all sub sorting objects - aggregate_sorting = run_sorter_by_property('kilosort2', recording_4_tetrodes, + aggregate_sorting = run_sorter_by_property(sorter_name='kilosort2', recording=recording_4_tetrodes, grouping_property='group', working_folder='working_path') @@ -421,7 +421,7 @@ do not handle multi-segment, and in that case we will use the # multirecording has 4 segments of 10s each # run tridesclous in multi-segment mode - multisorting = si.run_sorter('tridesclous', multirecording) + multisorting = si.run_sorter(sorter_name='tridesclous', recording=multirecording) print(multisorting) # Case 2: the sorter DOES NOT handle multi-segment objects @@ -433,7 +433,7 @@ do not handle multi-segment, and in that case we will use the # multirecording has 1 segment of 40s each # run mountainsort4 in mono-segment mode - multisorting = si.run_sorter('mountainsort4', multirecording) + multisorting = si.run_sorter(sorter_name='mountainsort4', recording=multirecording) See also the :ref:`multi_seg` section. @@ -507,7 +507,7 @@ message will appear indicating how to install the given sorter, .. code:: python - recording = run_sorter('ironclust', recording) + recording = run_sorter(sorter_name='ironclust', recording=recording) throws the error, @@ -540,7 +540,7 @@ From the user's perspective, they behave exactly like the external sorters: .. code-block:: python - sorting = run_sorter("spykingcircus2", recording, "/tmp/folder") + sorting = run_sorter(sorter_name="spykingcircus2", recording=recording, output_folder="/tmp/folder") Contributing diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index 422eaea890..1e58972497 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -47,7 +47,8 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) peaks = detect_peaks( - recording, method='by_channel', + recording=recording, + method='by_channel', peak_sign='neg', detect_threshold=5, exclude_sweep_ms=0.2, @@ -94,7 +95,7 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) - peak_locations = localize_peaks(recording, peaks, method='center_of_mass', + peak_locations = localize_peaks(recording=recording, peaks=peaks, method='center_of_mass', radius_um=70., ms_before=0.3, ms_after=0.6, **job_kwargs) @@ -122,7 +123,7 @@ For instance, the 'monopolar_triangulation' method will have: .. note:: - By convention in SpikeInterface, when a probe is described in 2d + By convention in SpikeInterface, when a probe is described in 3d * **'x'** is the width of the probe * **'y'** is the depth * **'z'** is orthogonal to the probe plane @@ -144,11 +145,11 @@ can be *hidden* by this process. from spikeinterface.sortingcomponents.peak_detection import detect_peaks - many_peaks = detect_peaks(...) + many_peaks = detect_peaks(...) # as in above example from spikeinterface.sortingcomponents.peak_selection import select_peaks - some_peaks = select_peaks(many_peaks, method='uniform', n_peaks=10000) + some_peaks = select_peaks(peaks=many_peaks, method='uniform', n_peaks=10000) Implemented methods are the following: @@ -183,15 +184,15 @@ Here is an example with non-rigid motion estimation: .. code-block:: python from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(recording, ...) + peaks = detect_peaks(recording=recording, ...) # as in above example from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(recording, peaks, ...) + peak_locations = localize_peaks(recording=recording, peaks=peaks, ...) # as above from spikeinterface.sortingcomponents.motion_estimation import estimate_motion motion, temporal_bins, spatial_bins, - extra_check = estimate_motion(recording, peaks, peak_locations=peak_locations, + extra_check = estimate_motion(recording=recording, peaks=peaks, peak_locations=peak_locations, direction='y', bin_duration_s=10., bin_um=10., margin_um=0., method='decentralized_registration', rigid=False, win_shape='gaussian', win_step_um=50., win_sigma_um=150., @@ -217,7 +218,7 @@ Here is a short example that depends on the output of "Motion interpolation": from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording - recording_corrected = InterpolateMotionRecording(recording_with_drift, motion, temporal_bins, spatial_bins + recording_corrected = InterpolateMotionRecording(recording=recording_with_drift, motion=motion, temporal_bins=temporal_bins, spatial_bins=spatial_bins spatial_interpolation_method='kriging, border_mode='remove_channels') @@ -255,10 +256,10 @@ Different methods may need different inputs (for instance some of them require p .. code-block:: python from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(recording, ...) + peaks = detect_peaks(recording, ...) # as in above example from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks - labels, peak_labels = find_cluster_from_peaks(recording, peaks, method="sliding_hdbscan") + labels, peak_labels = find_cluster_from_peaks(recording=recording, peaks=peaks, method="sliding_hdbscan") * **labels** : contains all possible labels diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 8565e94fce..f37b2a5a6f 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -148,7 +148,7 @@ The :code:`plot_*(..., backend="matplotlib")` functions come with the following .. code-block:: python # matplotlib backend - w = plot_traces(recording, backend="matplotlib") + w = plot_traces(recording=recording, backend="matplotlib") **Output:** @@ -173,7 +173,7 @@ Each function has the following additional arguments: # ipywidgets backend also supports multiple "layers" for plot_traces rec_dict = dict(filt=recording, cmr=common_reference(recording)) - w = sw.plot_traces(rec_dict, backend="ipywidgets") + w = sw.plot_traces(recording=rec_dict, backend="ipywidgets") **Output:** @@ -196,8 +196,8 @@ The functions have the following additional arguments: .. code-block:: python # sortingview backend - w_ts = sw.plot_traces(recording, backend="ipywidgets") - w_ss = sw.plot_sorting_summary(recording, backend="sortingview") + w_ts = sw.plot_traces(recording=recording, backend="ipywidgets") + w_ss = sw.plot_sorting_summary(recording=recording, backend="sortingview") **Output:** @@ -249,7 +249,7 @@ The :code:`ephyviewer` backend is currently only available for the :py:func:`~sp .. code-block:: python - plot_traces(recording, backend="ephyviewer", mode="line", show_channel_ids=True) + plot_traces(recording=recording, backend="ephyviewer", mode="line", show_channel_ids=True) .. image:: ../images/plot_traces_ephyviewer.png diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index d43727cb44..df0b5296c0 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -180,7 +180,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True if sorting_exists: # delete older sorting + log before running sorters - shutil.rmtree(sorting_exists) + shutil.rmtree(sorting_folder) log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" if log_file.exists(): log_file.unlink() diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 1430e8fb45..ad31b97d8e 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -45,8 +45,12 @@ def __init__(self, main_ids: Sequence) -> None: self._kwargs = {} # 'main_ids' will either be channel_ids or units_ids - # They is used for properties + # They are used for properties self._main_ids = np.array(main_ids) + if len(self._main_ids) > 0: + assert ( + self._main_ids.dtype.kind in "uiSU" + ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}" # dict at object level self._annotations = {} @@ -616,7 +620,7 @@ def dump_to_pickle( Parameters ---------- file_path: str - Path of the json file + Path of the pickle file include_properties: bool If True, all properties are dumped folder_metadata: str, Path, or None @@ -984,7 +988,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: class_name = None if "kwargs" not in dic: - raise Exception(f"This dict cannot be load into extractor {dic}") + raise Exception(f"This dict cannot be loaded into extractor {dic}") # Create new kwargs to avoid modifying the original dict["kwargs"] new_kwargs = dict() @@ -1005,7 +1009,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class" if not _check_same_version(class_name, dic["version"]): warnings.warn( - f"Versions are not the same. This might lead compatibility errors. " + f"Versions are not the same. This might lead to compatibility errors. " f"Using {class_name.split('.')[0]}=={dic['version']} is recommended" ) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 08f187895b..2977211c25 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -305,7 +305,8 @@ def get_traces( if not self.has_scaled(): raise ValueError( - "This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" + "This recording does not support return_scaled=True (need gain_to_uV and offset_" + "to_uV properties)" ) else: gains = self.get_property("gain_to_uV") @@ -416,8 +417,8 @@ def set_times(self, times, segment_index=None, with_warning=True): if with_warning: warn( "Setting times with Recording.set_times() is not recommended because " - "times are not always propagated to across preprocessing" - "Use use this carefully!" + "times are not always propagated across preprocessing" + "Use this carefully!" ) def sample_index_to_time(self, sample_ind, segment_index=None): diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index affde8a75e..d411f38d2a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations from pathlib import Path import numpy as np @@ -19,7 +19,7 @@ class BaseRecordingSnippets(BaseExtractor): has_default_locations = False - def __init__(self, sampling_frequency: float, channel_ids: List, dtype): + def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype): BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = sampling_frequency self._dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index f35bc2b266..b4e3c11f55 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -1,10 +1,8 @@ from typing import List, Union -from pathlib import Path from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets import numpy as np from warnings import warn -from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes # snippets segments? diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index e6d08d38f7..2a06a699cb 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True): if check_spike_frames: if has_exceeding_spikes(recording, self): warnings.warn( - "Some spikes are exceeding the recording's duration! " + "Some spikes exceed the recording's duration! " "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` " "Might be necessary for further postprocessing." ) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 72a95637f6..b45290caa5 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -91,7 +91,7 @@ def __init__( file_path_list = [Path(file_paths)] if t_starts is not None: - assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths" + assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths" t_starts = [float(t_start) for t_start in t_starts] dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index d36e168f8d..8714580821 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments): times_kargs0 = parent_segment0.get_times_kwargs() if times_kargs0["time_vector"] is None: for ps in parent_segments: - assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set" + assert ps.get_times_kwargs()["time_vector"] is None, "All segments should not have times set" else: for ps in parent_segments: assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], ( - "All segment should have the same " "t_start" + "All segments should have the same " "t_start" ) BaseRecordingSegment.__init__(self, **times_kargs0) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index ebd1b7db03..3a21e356a6 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) ), "ChannelSliceRecording: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceRecording : channel_ids not unique" + ), "ChannelSliceRecording : channel_ids are not unique" sampling_frequency = parent_recording.get_sampling_frequency() @@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): ), "ChannelSliceSnippets: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceSnippets : channel_ids not unique" + ), "ChannelSliceSnippets : channel_ids are not unique" sampling_frequency = parent_snippets.get_sampling_frequency() diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 968f27c6ad..b8574c506f 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording): def __init__(self, parent_recording, start_frame=None, end_frame=None): channel_ids = parent_recording.get_channel_ids() - assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment" + assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment" parent_size = parent_recording.get_num_samples(0) if start_frame is None: diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 5da5350f06..ed1391b0e2 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting): def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True): unit_ids = parent_sorting.get_unit_ids() - assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment" + assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment" if start_frame is None: start_frame = 0 @@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = parent_n_samples assert ( end_frame <= parent_n_samples - ), "`end_frame` should be smaller than the sortings total number of samples." + ), "`end_frame` should be smaller than the sortings' total number of samples." assert ( start_frame <= parent_n_samples - ), "`start_frame` should be smaller than the sortings total number of samples." + ), "`start_frame` should be smaller than the sortings' total number of samples." if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting): raise ValueError( "The sorting object has spikes exceeding the recording duration. You have to remove those spikes " @@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = max_spike_time + 1 assert start_frame < end_frame, ( - "`start_frame` should be greater than `end_frame`. " + "`start_frame` should be less than `end_frame`. " "This may be due to start_frame >= max_spike_time, if the end frame " "was not specified explicitly." ) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 06a5ec96ec..0c67404069 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1101,11 +1101,11 @@ def __init__( # handle also upsampling and jitter upsample_factor = templates.shape[3] elif templates.ndim == 5: - # handle also dirft + # handle also drift raise NotImplementedError("Drift will be implented soon...") # upsample_factor = templates.shape[3] else: - raise ValueError("templates have wring dim should 3 or 4") + raise ValueError("templates have wrong dim should 3 or 4") if upsample_factor is not None: assert upsample_vector is not None diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index e5581c7a67..d039206296 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -96,7 +96,7 @@ def is_set_global_dataset_folder(): ######################################## global global_job_kwargs -global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) +global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global global_job_kwargs_set global_job_kwargs_set = False diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 84ee502c14..cf7a67489c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -380,10 +380,6 @@ def run(self): self.gather_func(res) else: n_jobs = min(self.n_jobs, len(all_chunks)) - ######## Do you want to limit the number of threads per process? - ######## It has to be done to speed up numpy a lot if multicores - ######## Otherwise, np.dot will be slow. How to do that, up to you - ######## This is just a suggestion, but here it adds a dependency # parallel with ProcessPoolExecutor( @@ -436,3 +432,59 @@ def function_wrapper(args): else: with threadpool_limits(limits=max_threads_per_process): return _func(segment_index, start_frame, end_frame, _worker_ctx) + + +# Here some utils copy/paste from DART (Charlie Windolf) + + +class MockFuture: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__(self, f, *args): + self.f = f + self.args = args + + def result(self): + return self.f(*self.args) + + +class MockPoolExecutor: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__( + self, + max_workers=None, + mp_context=None, + initializer=None, + initargs=None, + context=None, + ): + if initializer is not None: + initializer(*initargs) + self.map = map + self.imap = map + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + def submit(self, f, *args): + return MockFuture(f, *args) + + +class MockQueue: + """Another helper class for turning off concurrency when debugging.""" + + def __init__(self): + self.q = [] + self.put = self.q.append + self.get = lambda: self.q.pop(0) + + +def get_poolexecutor(n_jobs): + if n_jobs == 1: + return MockPoolExecutor + else: + return ProcessPoolExecutor diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 651804c995..a0ded216d1 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -432,6 +432,7 @@ def run_node_pipeline( job_name="pipeline", mp_context=None, gather_mode="memory", + gather_kwargs={}, squeeze_output=True, folder=None, names=None, @@ -448,7 +449,7 @@ def run_node_pipeline( if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": - gather_func = GatherToNpy(folder, names) + gather_func = GatherToNpy(folder, names, **gather_kwargs) else: raise ValueError(f"wrong gather_mode : {gather_mode}") @@ -593,9 +594,9 @@ class GatherToNpy: * create the npy v1.0 header at the end with the correct shape and dtype """ - def __init__(self, folder, names, npy_header_size=1024): + def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.folder = Path(folder) - self.folder.mkdir(parents=True, exist_ok=False) + self.folder.mkdir(parents=True, exist_ok=exist_ok) assert names is not None self.names = names self.npy_header_size = npy_header_size diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 80979ce6c9..69c48356e5 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -27,6 +27,9 @@ def __init__( num_segments = len(file_paths) data = np.load(file_paths[0], mmap_mode="r") + if channel_ids is None: + channel_ids = np.arange(data["snippet"].shape[2]) + BaseSnippets.__init__( self, sampling_frequency, @@ -84,7 +87,7 @@ def write_snippets(snippets, file_paths, dtype=None): arr = np.empty(n, dtype=snippets_t, order="F") arr["frame"] = snippets.get_frames(segment_index=i) arr["snippet"] = snippets.get_snippets(segment_index=i).astype(dtype, copy=False) - + file_paths[i].parent.mkdir(parents=True, exist_ok=True) np.save(file_paths[i], arr) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 8c5c62d568..896e3800d7 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -102,7 +102,11 @@ def __init__(self, mask, unit_ids, channel_ids): self.num_channels = self.channel_ids.size self.num_units = self.unit_ids.size - self.max_num_active_channels = self.mask.sum(axis=1).max() + if self.mask.shape[0]: + self.max_num_active_channels = self.mask.sum(axis=1).max() + else: + # empty sorting without units + self.max_num_active_channels = 0 def __repr__(self): density = np.mean(self.mask) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 95278b76da..b6022e27c0 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np import warnings @@ -5,7 +6,9 @@ from .recording_tools import get_channel_distances, get_noise_levels -def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"): +def get_template_amplitudes( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum" +): """ Get amplitude per channel for each unit. @@ -13,9 +16,9 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index @@ -24,8 +27,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st peak_values: dict Dictionary with unit ids as keys and template amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore @@ -57,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st def get_template_extremum_channel( - waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id" + waveform_extractor, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "extremum", + outputs: "id" | "index" = "id", ): """ Compute the channel with the extremum peak for each unit. @@ -66,12 +72,12 @@ def get_template_extremum_channel( ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index - outputs: str + outputs: "id" | "index", default: "id" * 'id': channel id * 'index': channel index @@ -159,7 +165,7 @@ def get_template_channel_sparsity( get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc) -def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"): +def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. @@ -169,8 +175,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels Returns ------- @@ -203,7 +209,9 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str return shifts -def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"): +def get_template_extremum_amplitude( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index" +): """ Computes amplitudes on the best channel. @@ -211,9 +219,9 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "at_index" Where the amplitude is computed 'extremum': max or min 'at_index': take value at spike index @@ -223,8 +231,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", amplitudes: dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 8216a4aae6..d0672405d6 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -37,16 +37,20 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True) + job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict( + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict( + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 2bbf5e9b0f..204f796c0e 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -346,6 +346,8 @@ def test_recordingless(): # delete original recording and rely on rec_attributes if platform.system() != "Windows": + # this avoid reference on the folder + del we, recording shutil.rmtree(cache_folder / "recording1") we_loaded = WaveformExtractor.load(wf_folder, with_recording=False) assert not we_loaded.has_recording() @@ -554,6 +556,7 @@ def test_non_json_object(): # test_WaveformExtractor() # test_extract_waveforms() # test_portability() - # test_recordingless() + test_recordingless() # test_compute_sparsity() - test_non_json_object() + # test_non_json_object() + test_empty_sorting() diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 32158f00df..4e98864ba9 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None): try: property_dict[prop_name] = np.concatenate((property_dict[prop_name], values)) except Exception as e: - print(f"Skipping property '{prop_name}' for shape inconsistency") + print(f"Skipping property '{prop_name}' due to shape inconsistency") del property_dict[prop_name] break for prop_name, prop_values in property_dict.items(): diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index ab36672979..d4ae140b90 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -175,7 +175,12 @@ def load_from_folder( rec_attributes = None if sorting is None: - sorting = load_extractor(folder / "sorting.json", base_folder=folder) + if (folder / "sorting.json").exists(): + sorting = load_extractor(folder / "sorting.json", base_folder=folder) + elif (folder / "sorting.pickle").exists(): + sorting = load_extractor(folder / "sorting.pickle") + else: + raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)") # the sparsity is the sparsity of the saved/cached waveforms arrays sparsity_file = folder / "sparsity.json" @@ -1468,13 +1473,13 @@ def extract_waveforms( folder=None, mode="folder", precompute_template=("average",), - ms_before=3.0, - ms_after=4.0, + ms_before=1.0, + ms_after=2.0, max_spikes_per_unit=500, overwrite=False, return_scaled=True, dtype=None, - sparse=False, + sparse=True, sparsity=None, num_spikes_for_sparsity=100, allow_unfiltered=False, @@ -1518,7 +1523,7 @@ def extract_waveforms( If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. dtype: dtype or None Dtype of the output waveforms. If None, the recording dtype is maintained. - sparse: bool (default False) + sparse: bool, default: True If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. @@ -1737,6 +1742,7 @@ def precompute_sparsity( max_spikes_per_unit=num_spikes_for_sparsity, return_scaled=False, allow_unfiltered=allow_unfiltered, + sparse=False, **job_kwargs, ) local_sparsity = compute_sparsity(local_we, **sparse_kwargs) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 6adf9effd4..626ea79eb9 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -57,37 +57,47 @@ def apply_sortingview_curation( unit_ids_dtype = sorting.unit_ids.dtype # STEP 1: merge groups + labels_dict = sortingview_curation_dict["labelsByUnit"] if "mergeGroups" in sortingview_curation_dict and not skip_merge: merge_groups = sortingview_curation_dict["mergeGroups"] - for mg in merge_groups: + for merge_group in merge_groups: + # Store labels of units that are about to be merged + labels_to_inherit = [] + for unit in merge_group: + labels_to_inherit.extend(labels_dict.get(str(unit), [])) + labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates + if verbose: - print(f"Merging {mg}") + print(f"Merging {merge_group}") if unit_ids_dtype.kind in ("U", "S"): # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(mg) + new_unit_id = "-".join(merge_group) + curation_sorting.merge(merge_group, new_unit_id=new_unit_id) else: # in this case, the CurationSorting takes care of finding a new unused int - new_unit_id = None - curation_sorting.merge(mg, new_unit_id=new_unit_id) + curation_sorting.merge(merge_group, new_unit_id=None) + new_unit_id = curation_sorting.max_used_id # merged unit id + labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. # For example, the first 3 units could be labeled as "accept". # In this case, the first 3 values of the property "accept" will be True, the rest False - labels_dict = sortingview_curation_dict["labelsByUnit"] - properties = {} - for _, labels in labels_dict.items(): - for label in labels: - if label not in properties: - properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - labels_unit = [] - for unit_label, labels in labels_dict.items(): - if unit_label in str(unit_id): - labels_unit.extend(labels) - for label in labels_unit: - properties[label][u_i] = True + + # Initialize the properties dictionary + properties = { + label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + for labels in labels_dict.values() + for label in labels + } + + # Populate the properties dictionary + for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): + unit_id_str = str(unit_id) + if unit_id_str in labels_dict: + for label in labels_dict[unit_id_str]: + properties[label][unit_index] = True + for prop_name, prop_values in properties.items(): curation_sorting.current_sorting.set_property(prop_name, prop_values) @@ -103,5 +113,4 @@ def apply_sortingview_curation( units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) units_to_remove = np.unique(units_to_remove) curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json new file mode 100644 index 0000000000..48881388bb --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json @@ -0,0 +1,19 @@ +{ + "labelsByUnit": { + "1": [ + "accept" + ], + "2": [ + "artifact" + ], + "12": [ + "artifact" + ] + }, + "mergeGroups": [ + [ + 2, + 12 + ] + ] +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json new file mode 100644 index 0000000000..2047c514ce --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "1": [ + "mua" + ], + "2": [ + "mua" + ], + "3": [ + "reject" + ], + "4": [ + "noise" + ], + "5": [ + "accept" + ], + "6": [ + "accept" + ], + "7": [ + "accept" + ] + }, + "mergeGroups": [ + [ + 1, + 2 + ], + [ + 3, + 4 + ], + [ + 5, + 6 + ] + ] +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json new file mode 100644 index 0000000000..2585b5cc50 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "a": [ + "mua" + ], + "b": [ + "mua" + ], + "c": [ + "reject" + ], + "d": [ + "noise" + ], + "e": [ + "accept" + ], + "f": [ + "accept" + ], + "g": [ + "accept" + ] + }, + "mergeGroups": [ + [ + "a", + "b" + ], + [ + "c", + "d" + ], + [ + "e", + "f" + ] + ] +} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 9177cb5536..ce6c7dd5a6 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -1,8 +1,11 @@ import pytest from pathlib import Path import os +import json +import numpy as np import spikeinterface as si +import spikeinterface.extractors as se from spikeinterface.extractors import read_mearec from spikeinterface import set_global_tmp_folder from spikeinterface.postprocessing import ( @@ -19,7 +22,6 @@ cache_folder = Path("cache_folder") / "curation" parent_folder = Path(__file__).parent - ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_gh_curation(): + """ + Test curation using GitHub URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) - - # from GH # curated link: # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22} gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True) - print(f"From GH: {sorting_curated_gh}") assert len(sorting_curated_gh.unit_ids) == 9 assert "#8-#9" in sorting_curated_gh.unit_ids @@ -78,6 +80,9 @@ def test_gh_curation(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): + """ + Test curation using SHA1 URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) @@ -86,14 +91,14 @@ def test_sha1_curation(): # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22} sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22" sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True) - print(f"From SHA: {sorting_curated_sha1}") + # print(f"From SHA: {sorting_curated_sha1}") assert len(sorting_curated_sha1.unit_ids) == 9 assert "#8-#9" in sorting_curated_sha1.unit_ids assert "accept" in sorting_curated_sha1.get_property_keys() assert "mua" in sorting_curated_sha1.get_property_keys() assert "artifact" in sorting_curated_sha1.get_property_keys() - + unit_ids = sorting_curated_sha1.unit_ids sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"]) sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"]) sorting_curated_sha1_art_mua = apply_sortingview_curation( @@ -105,13 +110,16 @@ def test_sha1_curation(): def test_json_curation(): + """ + Test curation using a JSON file. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" + # print(f"Sorting: {sorting.get_unit_ids()}") sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) - print(f"From JSON: {sorting_curated_json}") assert len(sorting_curated_json.unit_ids) == 9 assert "#8-#9" in sorting_curated_json.unit_ids @@ -131,8 +139,133 @@ def test_json_curation(): assert len(sorting_curated_json_mua1.unit_ids) == 5 +def test_false_positive_curation(): + """ + Test curation for false positives. + """ + # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_units = 20 + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, num_units + 1, size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + # print("Sorting: {}".format(sorting.get_unit_ids())) + + json_file = parent_folder / "sv-sorting-curation-false-positive.json" + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + # print("Curated:", sorting_curated_json.get_unit_ids()) + + # Assertions + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") + assert 21 in sorting_curated_json.unit_ids + + +def test_label_inheritance_int(): + """ + Test curation for label inheritance for integer unit IDs. + """ + # Setup + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + num_units = 7 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7 + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + + json_file = parent_folder / "sv-sorting-curation-int.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) + + # Assertions for merged units + # print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2 + assert not sorting_merge.get_unit_property(unit_id=8, key="reject") + assert not sorting_merge.get_unit_property(unit_id=8, key="noise") + assert not sorting_merge.get_unit_property(unit_id=8, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4 + assert sorting_merge.get_unit_property(unit_id=9, key="reject") + assert sorting_merge.get_unit_property(unit_id=9, key="noise") + assert not sorting_merge.get_unit_property(unit_id=9, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6 + assert not sorting_merge.get_unit_property(unit_id=10, key="reject") + assert not sorting_merge.get_unit_property(unit_id=10, key="noise") + assert sorting_merge.get_unit_property(unit_id=10, key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert 9 not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert 8 not in sorting_include_accept.get_unit_ids() + assert 9 not in sorting_include_accept.get_unit_ids() + assert 10 in sorting_include_accept.get_unit_ids() + + +def test_label_inheritance_str(): + """ + Test curation for label inheritance for string unit IDs. + """ + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + # print(f"Sorting: {sorting.get_unit_ids()}") + + # Apply curation + json_file = parent_folder / "sv-sorting-curation-str.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + + # Assertions for merged units + # print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id="a-b", key="mua") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua") + assert sorting_merge.get_unit_property(unit_id="c-d", key="reject") + assert sorting_merge.get_unit_property(unit_id="c-d", key="noise") + assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise") + assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert "c-d" not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert "a-b" not in sorting_include_accept.get_unit_ids() + assert "c-d" not in sorting_include_accept.get_unit_ids() + assert "e-f" in sorting_include_accept.get_unit_ids() + + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() test_gh_curation() test_json_curation() + test_false_positive_curation() + test_label_inheritance_int() + test_label_inheritance_str() diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 7528f0ebf9..39bb875ea8 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -78,7 +78,7 @@ def test_export_to_phy_by_property(): recording = recording.save(folder=rec_folder) sorting = sorting.save(folder=sort_folder) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_group = compute_sparsity(waveform_extractor, method="by_property", by_property="group") export_to_phy( waveform_extractor, @@ -96,7 +96,7 @@ def test_export_to_phy_by_property(): # Remove one channel recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm) + waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) sparsity_group = compute_sparsity(waveform_extractor_rm, method="by_property", by_property="group") export_to_phy( @@ -130,7 +130,7 @@ def test_export_to_phy_by_sparsity(): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_radius = compute_sparsity(waveform_extractor, method="radius", radius_um=50.0) export_to_phy( waveform_extractor, diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ebc810b953..31a452f389 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -94,6 +94,7 @@ def export_to_phy( if waveform_extractor.is_sparse(): used_sparsity = waveform_extractor.sparsity + assert sparsity is None elif sparsity is not None: used_sparsity = sparsity else: diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 31241a4147..0980e89f1c 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -118,7 +118,7 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].tolist() + unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 815c617677..1eb0182318 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -216,10 +216,14 @@ def write_sorting(sorting, save_path, write_primary_channels=False): times_list = [] labels_list = [] primary_channels_list = [] - for unit_id in unit_ids: + for unit_index, unit_id in enumerate(unit_ids): times = sorting.get_unit_spike_train(unit_id=unit_id) times_list.append(times) - labels_list.append(np.ones(times.shape) * unit_id) + # unit id may not be numeric + if unit_id.dtype.kind in "iu": + labels_list.append(np.ones(times.shape, dtype=unit_id.dtype) * unit_id) + else: + labels_list.append(np.ones(times.shape, dtype=int) * unit_index) if write_primary_channels: if "max_channel" in sorting.get_unit_property_names(unit_id): primary_channels_list.append([sorting.get_unit_property(unit_id, "max_channel")] * times.shape[0]) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 6cd5238abd..6e693635eb 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -137,8 +137,8 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ def compute_correlograms( waveform_or_sorting_extractor, load_if_exists=False, - window_ms: float = 100.0, - bin_ms: float = 5.0, + window_ms: float = 50.0, + bin_ms: float = 1.0, method: str = "auto", ): """Compute auto and cross correlograms. diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 8f864e9b84..50e2ecdb57 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -57,6 +57,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -92,6 +93,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -112,6 +114,7 @@ def setUp(self): recording, sorting, mode="memory", + sparse=False, ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index d2739f69dd..48ceb34a4e 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -96,7 +96,7 @@ def get_extension_function(): def compute_unit_locations( - waveform_extractor, load_if_exists=False, method="center_of_mass", outputs="numpy", **method_kwargs + waveform_extractor, load_if_exists=False, method="monopolar_triangulation", outputs="numpy", **method_kwargs ): """ Localize units in 2D or 3D with several methods given the template. diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index a2349c1ee9..cc18d51d2e 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -97,7 +97,7 @@ def __init__( chunk_size=500, seed=0, ): - assert direction in ("upper", "lower", "both") + assert direction in ("upper", "lower", "both"), "'direction' must be 'upper', 'lower', or 'both'" if fill_value is None or quantile_threshold is not None: random_data = get_random_data_chunks( diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d2ac227217..6d6ce256de 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -83,7 +83,7 @@ def __init__( ref_channel_ids = np.asarray(ref_channel_ids) assert np.all( [ch in recording.get_channel_ids() for ch in ref_channel_ids] - ), "Some wrong 'ref_channel_ids'!" + ), "Some 'ref_channel_ids' are wrong!" elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index cc4e8601e2..e6e2836a35 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -211,9 +211,9 @@ def detect_bad_channels( if bad_channel_ids.size > recording.get_num_channels() / 3: warnings.warn( - "Over 1/3 of channels are detected as bad. In the precense of a high" + "Over 1/3 of channels are detected as bad. In the presence of a high" "number of dead / noisy channels, bad channel detection may fail " - "(erroneously label good channels as dead)." + "(good channels may be erroneously labeled as dead)." ) elif method == "neighborhood_r2": diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 51c1fb4ad6..b31088edf7 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -71,10 +71,10 @@ def __init__( ): import scipy.signal - assert filter_mode in ("sos", "ba") + assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" fs = recording.get_sampling_frequency() if coeff is None: - assert btype in ("bandpass", "highpass") + assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" # coefficient # self.coeff is 'sos' or 'ab' style filter_coeff = scipy.signal.iirfilter( @@ -258,7 +258,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): if dtype.kind == "u": raise TypeError( "The notch filter only supports signed types. Use the 'dtype' argument" - "to specify a signed type (e.g. 'int16', 'float32'" + "to specify a signed type (e.g. 'int16', 'float32')" ) BasePreprocessor.__init__(self, recording, dtype=dtype) diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 790279d647..d3a08297c6 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -50,9 +50,9 @@ def __init__( margin_ms=5.0, ): assert HAVE_PYOPENCL, "You need to install pyopencl (and GPU driver!!)" - - assert btype in ("bandpass", "lowpass", "highpass", "bandstop") - assert filter_mode in ("sos",) + btype_modes = ("bandpass", "lowpass", "highpass", "bandstop") + assert btype in btype_modes, f"'btype' must be in {btype_modes}" + assert filter_mode in ("sos",), "'filter_mode' must be 'sos'" # coefficient sf = recording.get_sampling_frequency() @@ -96,8 +96,8 @@ def __init__(self, parent_recording_segment, executor, margin): self.margin = margin def get_traces(self, start_frame, end_frame, channel_indices): - assert start_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" - assert end_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" + assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" + assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" chunk_size = end_frame - start_frame if chunk_size != self.executor.chunk_size: @@ -157,7 +157,7 @@ def process(self, traces): if traces.shape[0] != self.full_size: if self.full_size is not None: - print(f"Warning : chunk_size have change {self.chunk_size} {traces.shape[0]}, need recompile CL!!!") + print(f"Warning : chunk_size has changed {self.chunk_size} {traces.shape[0]}, need to recompile CL!!!") self.create_buffers_and_compile() event = pyopencl.enqueue_copy(self.queue, self.input_cl, traces) diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index aa98410568..4df4a409bc 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -212,7 +212,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces * self.taper[np.newaxis, :] # apply actual HP filter - import scipy + import scipy.signal traces = scipy.signal.sosfiltfilt(self.sos_filter, traces, axis=1) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 7d43982853..bd53866b6a 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -68,7 +68,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("pool_channel", "by_channel") + assert mode in ("pool_channel", "by_channel"), "'mode' must be 'pool_channel' or 'by_channel'" random_data = get_random_data_chunks(recording, **random_chunk_kwargs) @@ -260,7 +260,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("median+mad", "mean+std") + assert mode in ("median+mad", "mean+std"), "'mode' must be 'median+mad' or 'mean+std'" # fix dtype dtype_ = fix_dtype(recording, dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 9c8b2589a0..bdba55038d 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -42,7 +42,9 @@ def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=Non assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" sample_shifts = recording.get_property("inter_sample_shift") else: - assert len(inter_sample_shift) == recording.get_num_channels(), "sample " + assert ( + len(inter_sample_shift) == recording.get_num_channels() + ), "the 'inter_sample_shift' must be same size at the num_channels " sample_shifts = np.asarray(inter_sample_shift) margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 7e84822c61..1eafa48a0b 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -107,8 +107,6 @@ def __init__( time_jitter=0, waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"}, ): - import scipy.interpolate - available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -236,8 +234,6 @@ def __init__( time_pad, sparsity, ): - import scipy.interpolate - BasePreprocessorSegment.__init__(self, parent_recording_segment) self.triggers = np.asarray(triggers, dtype="int64") @@ -285,6 +281,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): elif trig + pad[1] >= end_frame - start_frame: traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: + import scipy.interpolate + for trig in triggers: if pad is None: pre_data_end_idx = trig - 1 diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 710c4f76f4..0c3b9f95d1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,7 +6,7 @@ from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter try: import hdbscan @@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, - "filtering": {"dtype": "float32"}, + "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, @@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - recording_f = bandpass_filter(recording, **filtering_params) + recording_f = highpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: recording_f = recording + # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method @@ -98,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params["waveforms_kwargs"] = params["waveforms"] + clustering_params["waveforms"] = params["waveforms"].copy() for k in ["ms_before", "ms_after"]: - clustering_params["waveforms_kwargs"][k] = params["general"][k] + clustering_params["waveforms"][k] = params["general"][k] clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs @@ -118,8 +119,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if clustering_folder.exists(): shutil.rmtree(clustering_folder) - sorting = sorting.save(folder=clustering_folder) - ## We get the templates our of such a clustering waveforms_params = params["waveforms"].copy() waveforms_params.update(job_kwargs) @@ -131,6 +130,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): mode = "memory" waveforms_folder = None else: + sorting = sorting.save(folder=clustering_folder) mode = "folder" waveforms_folder = sorter_output_folder / "waveforms" diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index bd5667b15f..a49a605a75 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -91,7 +91,7 @@ def run_sorter( sorter_name: str, recording: BaseRecording, output_folder: Optional[str] = None, - remove_existing_folder: bool = True, + remove_existing_folder: bool = False, delete_output_folder: bool = False, verbose: bool = False, raise_error: bool = True, @@ -514,19 +514,19 @@ def run_sorter_container( res_output = container_client.run_command(cmd) cmd = f"cp -r {si_dev_path_unix} {si_source_folder}" res_output = container_client.run_command(cmd) - cmd = f"pip install {si_source_folder}/spikeinterface[full]" + cmd = f"pip install --user {si_source_folder}/spikeinterface[full]" else: si_source = "remote repository" - cmd = "pip install --upgrade --no-input git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full]" + cmd = "pip install --user --upgrade --no-input git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full]" if verbose: print(f"Installing dev spikeinterface from {si_source}") res_output = container_client.run_command(cmd) - cmd = "pip install --upgrade --no-input https://github.com/NeuralEnsemble/python-neo/archive/master.zip" + cmd = "pip install --user --upgrade --no-input https://github.com/NeuralEnsemble/python-neo/archive/master.zip" res_output = container_client.run_command(cmd) else: if verbose: print(f"Installing spikeinterface=={si_version} in {container_image}") - cmd = f"pip install --upgrade --no-input spikeinterface[full]=={si_version}" + cmd = f"pip install --user --upgrade --no-input spikeinterface[full]=={si_version}" res_output = container_client.run_command(cmd) else: # TODO version checking @@ -540,7 +540,7 @@ def run_sorter_container( if extra_requirements: if verbose: print(f"Installing extra requirements: {extra_requirements}") - cmd = f"pip install --upgrade --no-input {' '.join(extra_requirements)}" + cmd = f"pip install --user --upgrade --no-input {' '.join(extra_requirements)}" res_output = container_client.run_command(cmd) # run sorter on folder diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py new file mode 100644 index 0000000000..cbded0c49f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -0,0 +1,45 @@ +import numpy as np + +from .tools import FeaturesLoader, compute_template_from_sparse + +# This is work in progress ... + + +def clean_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + peak_sign="neg", +): + total_channels = recording.get_num_channels() + + if isinstance(features_dict_or_folder, dict): + features = features_dict_or_folder + else: + features = FeaturesLoader(features_dict_or_folder) + + clean_labels = peak_labels.copy() + + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + + count = np.zeros(n, dtype="int64") + for i, label in enumerate(labels_set): + count[i] = np.sum(peak_labels == label) + print(count) + + templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) + + if peak_sign == "both": + max_values = np.max(np.abs(templates), axis=(1, 2)) + elif peak_sign == "neg": + max_values = -np.min(templates, axis=(1, 2)) + elif peak_sign == "pos": + max_values = np.max(templates, axis=(1, 2)) + print(max_values) + + return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 28a1a63065..b4938717f8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -574,6 +574,8 @@ def remove_duplicates_via_matching( if tmp_folder is None: tmp_folder = get_global_tmp_folder() + tmp_folder.mkdir(parents=True, exist_ok=True) + tmp_filename = tmp_folder / "tmp.raw" f = open(tmp_filename, "wb") @@ -583,38 +585,42 @@ def remove_duplicates_via_matching( f.close() recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") - recording.annotate(is_filtered=True) recording = recording.set_probe(waveform_extractor.recording.get_probe()) + recording.annotate(is_filtered=True) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) half_marging = margin // 2 chunk_size = duration + 3 * margin - method_kwargs.update( + local_params = method_kwargs.copy() + + local_params.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "omp_min_sps": 0.1, + "omp_min_sps": 0.05, } ) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) + indices = np.argsort(counts) + ignore_ids = [] similar_templates = [[], []] - for i in range(nb_templates): + for i in indices: t_start = padding + i * duration t_stop = padding + (i + 1) * duration sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - - method_kwargs.update({"ignored_ids": ignore_ids + [i]}) + local_params.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs ) if method == "circus-omp-svd": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -628,7 +634,7 @@ def remove_duplicates_via_matching( } ) elif method == "circus-omp": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -660,7 +666,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording + del recording, sub_recording, local_params, waveform_extractor os.remove(tmp_filename) return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py new file mode 100644 index 0000000000..d892d0723a --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -0,0 +1,581 @@ +from pathlib import Path +from multiprocessing import get_context +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +import scipy.spatial +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from hdbscan import HDBSCAN + +import numpy as np +import networkx as nx + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + + +from .isocut5 import isocut5 + +from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse + + +def merge_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs, +): + """ + Merge cluster using differents methods. + + Parameters + ---------- + peaks: numpy.ndarray 1d + detected peaks (or a subset) + peak_labels: numpy.ndarray 1d + original label before merge + peak_labels.size == peaks.size + recording: Recording object + A recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method used + method_kwargs: dict + Option for the method. + Returns + ------- + merge_peak_labels: numpy.ndarray 1d + New vectors label after merges. + peak_shifts: numpy.ndarray 1d + A vector of sample shift to be reverse applied on original sample_index on peak detection + Negative shift means too early. + Posituve shift means too late. + So the correction must be applied like this externaly: + final_peaks = peaks.copy() + final_peaks['sample_index'] -= peak_shifts + + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + + features = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set, pair_mask, pair_shift, pair_values = find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=radius_um, + method=method, + method_kwargs=method_kwargs, + **job_kwargs, + ) + + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.matshow(pair_values) + + pair_values[~pair_mask] = 20 + + import hdbscan + + fig, ax = plt.subplots() + clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) + clusterer.fit(pair_values) + print(clusterer.labels_) + clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True) + # ~ fig, ax = plt.subplots() + # ~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + # ~ edge_alpha=0.6, + # ~ node_size=80, + # ~ edge_linewidth=2) + + graph = clusterer.single_linkage_tree_.to_networkx() + + import scipy.cluster + + fig, ax = plt.subplots() + scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) + + import networkx as nx + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + plt.show() + + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + + group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) + + # apply final label and shift + merge_peak_labels = peak_labels.copy() + peak_shifts = np.zeros(peak_labels.size, dtype="int64") + for merge, shifts in zip(merges, group_shifts): + label0 = merge[0] + mask = np.in1d(peak_labels, merge) + merge_peak_labels[mask] = label0 + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + peak_shifts[peak_labels == label1] = shifts[l] + + return merge_peak_labels, peak_shifts + + +def resolve_final_shifts(labels_set, merges, pair_mask, pair_shift): + labels_set = list(labels_set) + + group_shifts = [] + for merge in merges: + shifts = np.zeros(len(merge), dtype="int64") + + label_inds = [labels_set.index(label) for label in merge] + + label0 = merge[0] + ind0 = label_inds[0] + + # First find relative shift to label0 (l=0) in the subgraph + local_pair_mask = pair_mask[label_inds, :][:, label_inds] + local_pair_shift = None + G = None + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + ind1 = label_inds[l] + if local_pair_mask[0, l]: + # easy case the pair label0<>label1 was existing + shift = pair_shift[ind0, ind1] + else: + # more complicated case need to find intermediate label and propagate the shift!! + if G is None: + # the the graph only once and only if needed + G = nx.from_numpy_array(local_pair_mask | local_pair_mask.T) + local_pair_shift = pair_shift[label_inds, :][:, label_inds] + local_pair_shift += local_pair_shift.T + + shift_chain = nx.shortest_path(G, source=l, target=0) + shift = 0 + for i in range(len(shift_chain) - 1): + shift += local_pair_shift[shift_chain[i + 1], shift_chain[i]] + shifts[l] = shift + + group_shifts.append(shifts) + + return group_shifts + + +def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"): + """ + Agglomerate merge pairs into final merge groups. + + The merges are ordered by label. + + """ + + labels_set = np.array(labels_set) + + merges = [] + + graph = nx.from_numpy_array(pair_mask | pair_mask.T) + # put real nodes names for debugging + maps = dict(zip(np.arange(labels_set.size), labels_set)) + graph = nx.relabel_nodes(graph, maps) + + groups = list(nx.connected_components(graph)) + for group in groups: + if len(group) == 1: + continue + sub_graph = graph.subgraph(group) + # print(group, sub_graph) + cliques = list(nx.find_cliques(sub_graph)) + if len(cliques) == 1 and len(cliques[0]) == len(group): + # the sub graph is full connected: no ambiguity + # merges.append(labels_set[cliques[0]]) + merges.append(cliques[0]) + elif len(cliques) > 1: + # the subgraph is not fully connected + if connection_mode == "full": + # node merge + pass + elif connection_mode == "partial": + group = list(group) + # merges.append(labels_set[group]) + merges.append(group) + elif connection_mode == "clique": + raise NotImplementedError + else: + raise ValueError + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(sub_graph) + plt.show() + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + # ensure ordered label + merges = [np.sort(merge) for merge in merges] + + return merges + + +def find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs + # n_jobs=1, + # mp_context="fork", + # max_threads_per_process=1, + # progress_bar=True, +): + """ + Searh some possible merge 2 by 2. + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + # features_dict_or_folder = Path(features_dict_or_folder) + + # peaks = features_dict_or_folder['peaks'] + total_channels = recording.get_num_channels() + + # sparse_wfs = features['sparse_wfs'] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + pair_mask = np.triu(np.ones((n, n), dtype="bool")) & ~np.eye(n, dtype="bool") + pair_shift = np.zeros((n, n), dtype="int64") + pair_values = np.zeros((n, n), dtype="float64") + + # compute template (no shift at this step) + + templates = compute_template_from_sparse( + peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels, peak_shifts=None + ) + + max_chans = np.argmax(np.max(np.abs(templates), axis=1), axis=1) + + channel_locs = recording.get_channel_locations() + template_locs = channel_locs[max_chans, :] + template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean") + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + progress_bar = job_kwargs["progress_bar"] + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=find_pair_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + jobs = [] + for ind0, ind1 in zip(indices0, indices1): + label0 = labels_set[ind0] + label1 = labels_set[ind1] + jobs.append(pool.submit(find_pair_function_wrapper, label0, label1)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"find_merge_pairs with {method}", total=len(jobs)) + else: + iterator = jobs + + for res in iterator: + is_merge, label0, label1, shift, merge_value = res.result() + ind0 = labels_set.index(label0) + ind1 = labels_set.index(label1) + + pair_mask[ind0, ind1] = is_merge + if is_merge: + pair_shift[ind0, ind1] = shift + pair_values[ind0, ind1] = merge_value + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + return labels_set, pair_mask, pair_shift, pair_values + + +def find_pair_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = find_pair_method_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + + # if isinstance(features_dict_or_folder, dict): + # _ctx["features"] = features_dict_or_folder + # else: + # _ctx["features"] = FeaturesLoader(features_dict_or_folder) + + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def find_pair_function_wrapper(label0, label1): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( + label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_merge, label0, label1, shift, merge_value + + +class ProjectDistribution: + """ + This method is a refactorized mix between: + * old tridesclous code + * some ideas by Charlie Windolf in spikespvae + + The idea is : + * project the waveform (or features) samples on a 1d axis (using LDA for instance). + * check that it is the same or not distribution (diptest, distrib_overlap, ...) + + + """ + + name = "project_distribution" + + @staticmethod + def merge( + label0, + label1, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + feature_name="sparse_tsvd", + projection="centroid", + criteria="diptest", + threshold_diptest=0.5, + threshold_percentile=80.0, + threshold_overlap=0.4, + num_shift=2, + ): + if num_shift > 0: + assert feature_name == "sparse_wfs" + sparse_wfs = features[feature_name] + + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + if inds0.size < 40 or inds1.size < 40: + is_merge = False + merge_value = 0 + final_shift = 0 + return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) + + inds = np.concatenate([inds0, inds1]) + labels = np.zeros(inds.size, dtype="int") + labels[inds0.size :] = 1 + wfs, out = aggregate_sparse_features(peaks, inds, sparse_wfs, waveforms_sparse_mask, target_chans) + wfs = wfs[~out] + labels = labels[~out] + + cut = np.searchsorted(labels, 1) + wfs0_ = wfs[:cut, :, :] + wfs1_ = wfs[cut:, :, :] + + template0_ = np.mean(wfs0_, axis=0) + template1_ = np.mean(wfs1_, axis=0) + num_samples = template0_.shape[0] + + template0 = template0_[num_shift : num_samples - num_shift, :] + + wfs0 = wfs0_[:, num_shift : num_samples - num_shift, :] + + # best shift strategy 1 = max cosine + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # norm = np.linalg.norm(template0.flatten()) * np.linalg.norm(template1.flatten()) + # value = np.sum(template0.flatten() * template1.flatten()) / norm + # values.append(value) + # best_shift = np.argmax(values) + + # best shift strategy 2 = min dist**2 + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # value = np.sum((template1 - template0)**2) + # values.append(value) + # best_shift = np.argmin(values) + + # best shift strategy 3 : average delta argmin between channels + channel_shift = np.argmax(np.abs(template1_), axis=0) - np.argmax(np.abs(template0_), axis=0) + mask = np.abs(channel_shift) <= num_shift + channel_shift = channel_shift[mask] + if channel_shift.size > 0: + best_shift = int(np.round(np.mean(channel_shift))) + num_shift + else: + best_shift = num_shift + + wfs1 = wfs1_[:, best_shift : best_shift + template0.shape[0], :] + template1 = template1_[best_shift : best_shift + template0.shape[0], :] + + if projection == "lda": + wfs_0_1 = np.concatenate([wfs0, wfs1], axis=0) + flat_wfs = wfs_0_1.reshape(wfs_0_1.shape[0], -1) + feat = LinearDiscriminantAnalysis(n_components=1).fit_transform(flat_wfs, labels) + feat = feat[:, 0] + feat0 = feat[:cut] + feat1 = feat[cut:] + + elif projection == "centroid": + vector_0_1 = template1 - template0 + vector_0_1 /= np.sum(vector_0_1**2) + feat0 = np.sum((wfs0 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat1 = np.sum((wfs1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + # feat = np.sum((wfs_0_1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat = np.concatenate([feat0, feat1], axis=0) + + else: + raise ValueError(f"bad projection {projection}") + + if criteria == "diptest": + dipscore, cutpoint = isocut5(feat) + is_merge = dipscore < threshold_diptest + merge_value = dipscore + elif criteria == "percentile": + l0 = np.percentile(feat0, threshold_percentile) + l1 = np.percentile(feat1, 100.0 - threshold_percentile) + is_merge = l0 >= l1 + merge_value = l0 - l1 + elif criteria == "distrib_overlap": + lim0 = min(np.min(feat0), np.min(feat1)) + lim1 = max(np.max(feat0), np.max(feat1)) + bin_size = (lim1 - lim0) / 200.0 + bins = np.arange(lim0, lim1, bin_size) + + pdf0, _ = np.histogram(feat0, bins=bins, density=True) + pdf1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 *= bin_size + pdf1 *= bin_size + overlap = np.sum(np.minimum(pdf0, pdf1)) + + is_merge = overlap >= threshold_overlap + + merge_value = 1 - overlap + + else: + raise ValueError(f"bad criteria {criteria}") + + if is_merge: + final_shift = best_shift - num_shift + else: + final_shift = 0 + + # DEBUG = True + DEBUG = False + + if DEBUG and is_merge: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: + import matplotlib.pyplot as plt + + flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) + flatten_wfs1 = wfs1.swapaxes(1, 2).reshape(wfs1.shape[0], -1) + + fig, axs = plt.subplots(ncols=2) + ax = axs[0] + ax.plot(flatten_wfs0.T, color="C0", alpha=0.01) + ax.plot(flatten_wfs1.T, color="C1", alpha=0.01) + m0 = np.mean(flatten_wfs0, axis=0) + m1 = np.mean(flatten_wfs1, axis=0) + ax.plot(m0, color="C0", alpha=1, lw=4, label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", alpha=1, lw=4, label=f"{label1} {inds1.size}") + + ax.legend() + + bins = np.linspace(np.percentile(feat, 1), np.percentile(feat, 99), 100) + bin_size = bins[1] - bins[0] + count0, _ = np.histogram(feat0, bins=bins, density=True) + count1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 = count0 * bin_size + pdf1 = count1 * bin_size + + ax = axs[1] + ax.plot(bins[:-1], pdf0, color="C0") + ax.plot(bins[:-1], pdf1, color="C1") + + if criteria == "diptest": + ax.set_title(f"{dipscore:.4f} {is_merge}") + elif criteria == "percentile": + ax.set_title(f"{l0:.4f} {l1:.4f} {is_merge}") + ax.axvline(l0, color="C0") + ax.axvline(l1, color="C1") + elif criteria == "distrib_overlap": + print( + lim0, + lim1, + ) + ax.set_title(f"{overlap:.4f} {is_merge}") + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls="--", color="k") + + plt.show() + + return is_merge, label0, label1, final_shift, merge_value + + +find_pair_method_list = [ + ProjectDistribution, +] +find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 864548e7d4..a81458d7a8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -181,24 +181,26 @@ def sigmoid(x, L, x0, k, b): else: tmp_folder = Path(params["tmp_folder"]) + tmp_folder.mkdir(parents=True, exist_ok=True) + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + if params["shared_memory"]: waveform_folder = None mode = "memory" else: waveform_folder = tmp_folder / "waveforms" mode = "folder" + sorting = sorting.save(folder=sorting_folder) - sorting_folder = tmp_folder / "sorting" - unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) - sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) - sorting = sorting.save(folder=sorting_folder) we = extract_waveforms( recording, sorting, waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], **params["job_kwargs"], + **params["waveforms"], return_scaled=False, mode=mode, ) @@ -219,12 +221,14 @@ def sigmoid(x, L, x0, k, b): we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) + del we, sorting + if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) else: if not params["shared_memory"]: shutil.rmtree(tmp_folder / "waveforms") - shutil.rmtree(tmp_folder / "sorting") + shutil.rmtree(tmp_folder / "sorting") if verbose: print("We kept %d non-duplicated clusters..." % len(labels)) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py new file mode 100644 index 0000000000..9836e9110f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -0,0 +1,278 @@ +from multiprocessing import get_context +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +from sklearn.decomposition import TruncatedSVD +from hdbscan import HDBSCAN + +import numpy as np + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + +from .tools import aggregate_sparse_features, FeaturesLoader +from .isocut5 import isocut5 + + +# important all DEBUG and matplotlib are left in the code intentionally + + +def split_clusters( + peak_labels, + recording, + features_dict_or_folder, + method="hdbscan_on_local_pca", + method_kwargs={}, + recursive=False, + recursive_depth=None, + returns_split_count=False, + **job_kwargs, +): + """ + Run recusrsively (or not) in a multi process pool a local split method. + + Parameters + ---------- + peak_labels: numpy.array + Peak label before split + recording: Recording + Recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method name + method_kwargs: dict + The method option + recursive: bool Default False + Reccursive or not. + recursive_depth: None or int + If recursive=True, then this is the max split per spikes. + returns_split_count: bool + Optionally return the split count vector. Same size as labels. + + Returns + ------- + new_labels: numpy.ndarray + The labels of peaks after split. + split_count: numpy.ndarray + Optionally returned + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + progress_bar = job_kwargs["progress_bar"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + + original_labels = peak_labels + peak_labels = peak_labels.copy() + split_count = np.zeros(peak_labels.size, dtype=int) + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=split_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + labels_set = np.setdiff1d(peak_labels, [-1]) + current_max_label = np.max(labels_set) + 1 + + jobs = [] + for label in labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) + else: + iterator = jobs + + for res in iterator: + is_split, local_labels, peak_indices = res.result() + if not is_split: + continue + + mask = local_labels >= 0 + peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label + peak_labels[peak_indices[~mask]] = local_labels[~mask] + + split_count[peak_indices] += 1 + + current_max_label += np.max(local_labels[mask]) + 1 + + if recursive: + if recursive_depth is not None: + # stop reccursivity when recursive_depth is reach + extra_ball = np.max(split_count[peak_indices]) < recursive_depth + else: + # reccurssive always + extra_ball = True + + if extra_ball: + new_labels_set = np.setdiff1d(peak_labels[peak_indices], [-1]) + for label in new_labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + if progress_bar: + iterator.total += 1 + + if returns_split_count: + return peak_labels, split_count + else: + return peak_labels + + +global _ctx + + +def split_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + features_dict_or_folder + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = split_methods_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def split_function_wrapper(peak_indices): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_split, local_labels = _ctx["method_class"].split( + peak_indices, _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_split, local_labels, peak_indices + + +class LocalFeatureClustering: + """ + This method is a refactorized mix between: + * old tridesclous code + * "herding_split()" in DART/spikepsvae by Charlie Windolf + + The idea simple : + * agregate features (svd or even waveforms) with sparse channel. + * run a local feature reduction (pca or svd) + * try a new split (hdscan or isocut5) + """ + + name = "local_feature_clustering" + + @staticmethod + def split( + peak_indices, + peaks, + features, + clusterer="hdbscan", + feature_name="sparse_tsvd", + neighbours_mask=None, + waveforms_sparse_mask=None, + min_size_split=25, + min_cluster_size=25, + min_samples=25, + n_pca_features=2, + minimum_common_channels=2, + ): + local_labels = np.zeros(peak_indices.size, dtype=np.int64) + + # can be sparse_tsvd or sparse_wfs + sparse_features = features[feature_name] + + assert waveforms_sparse_mask is not None + + # target channel subset is done intersect local channels + neighbours + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + + # TODO fix this a better way, this when cluster have too few overlapping channels + if target_channels.size < minimum_common_channels: + return False, None + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels + ) + + local_labels[dont_have_channels] = -2 + kept = np.flatnonzero(~dont_have_channels) + if kept.size < min_size_split: + return False, None + + aligned_wfs = aligned_wfs[kept, :, :] + + flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) + + # final_features = PCA(n_pca_features, whiten=True).fit_transform(flatten_features) + # final_features = PCA(n_pca_features, whiten=False).fit_transform(flatten_features) + final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) + + if clusterer == "hdbscan": + clust = HDBSCAN( + min_cluster_size=min_cluster_size, + min_samples=min_samples, + allow_single_cluster=True, + cluster_selection_method="leaf", + ) + clust.fit(final_features) + possible_labels = clust.labels_ + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + elif clusterer == "isocut5": + dipscore, cutpoint = isocut5(final_features[:, 0]) + possible_labels = np.zeros(final_features.shape[0]) + if dipscore > 1.5: + mask = final_features[:, 0] > cutpoint + if np.sum(mask) > min_cluster_size and np.sum(~mask): + possible_labels[mask] = 1 + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + else: + is_split = False + else: + raise ValueError(f"wrong clusterer {clusterer}") + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + labels_set = np.setdiff1d(possible_labels, [-1]) + colors = plt.get_cmap("tab10", len(labels_set)) + colors = {k: colors(i) for i, k in enumerate(labels_set)} + colors[-1] = "k" + fix, axs = plt.subplots(nrows=2) + + flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1) + + sl = slice(None, None, 10) + for k in np.unique(possible_labels): + mask = possible_labels == k + ax = axs[0] + ax.scatter(final_features[:, 0][mask][sl], final_features[:, 1][mask][sl], s=5, color=colors[k]) + + ax = axs[1] + ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) + + axs[0].set_title(f"{clusterer} {is_split}") + + plt.show() + + if not is_split: + return is_split, None + + local_labels[kept] = possible_labels + + return is_split, local_labels + + +split_methods_list = [ + LocalFeatureClustering, +] +split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py new file mode 100644 index 0000000000..8e25c9cb7f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -0,0 +1,196 @@ +from pathlib import Path +from typing import Any +import numpy as np + + +# TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....) + + +class FeaturesLoader: + """ + Feature can be computed in memory or in a folder contaning npy files. + + This class read the folder and behave like a dict of array lazily. + + Parameters + ---------- + feature_folder + + preload + + """ + + def __init__(self, feature_folder, preload=["peaks"]): + self.feature_folder = Path(feature_folder) + + self.file_feature = {} + self.loaded_features = {} + for file in self.feature_folder.glob("*.npy"): + name = file.stem + if name in preload: + self.loaded_features[name] = np.load(file) + else: + self.file_feature[name] = file + + def __getitem__(self, name): + if name in self.loaded_features: + return self.loaded_features[name] + else: + return np.load(self.file_feature[name], mmap_mode="r") + + @staticmethod + def from_dict_or_folder(features_dict_or_folder): + if isinstance(features_dict_or_folder, dict): + return features_dict_or_folder + else: + return FeaturesLoader(features_dict_or_folder) + + +def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, target_channels): + """ + Aggregate sparse features that have unaligned channels and realigned then on target_channels. + + This is usefull to aligned back peaks waveform or pca or tsvd when detected a differents channels. + + + Parameters + ---------- + peaks + + peak_indices + + sparse_feature + + sparse_mask + + target_channels + + Returns + ------- + aligned_features: numpy.array + Aligned features. shape is (local_peaks.size, sparse_feature.shape[1], target_channels.size) + dont_have_channels: numpy.array + Boolean vector to indicate spikes that do not have all target channels to be taken in account + shape is peak_indices.size + """ + local_peaks = peaks[peak_indices] + + aligned_features = np.zeros( + (local_peaks.size, sparse_feature.shape[1], target_channels.size), dtype=sparse_feature.dtype + ) + dont_have_channels = np.zeros(peak_indices.size, dtype=bool) + + for chan in np.unique(local_peaks["channel_index"]): + sparse_chans = np.flatnonzero(sparse_mask[chan, :]) + peak_inds = np.flatnonzero(local_peaks["channel_index"] == chan) + if np.all(np.isin(target_channels, sparse_chans)): + # peaks feature channel have all target_channels + source_chans = np.flatnonzero(np.in1d(sparse_chans, target_channels)) + aligned_features[peak_inds, :, :] = sparse_feature[peak_indices[peak_inds], :, :][:, :, source_chans] + else: + # some channel are missing, peak are not removde + dont_have_channels[peak_inds] = True + + return aligned_features, dont_have_channels + + +def compute_template_from_sparse( + peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None +): + """ + Compute template average from single sparse waveforms buffer. + + Parameters + ---------- + peaks + + labels + + labels_set + + sparse_waveforms + + sparse_mask + + total_channels + + peak_shifts + + Returns + ------- + templates: numpy.array + Templates shape : (len(labels_set), num_samples, total_channels) + """ + n = len(labels_set) + + templates = np.zeros((n, sparse_waveforms.shape[1], total_channels), dtype=sparse_waveforms.dtype) + + for i, label in enumerate(labels_set): + peak_indices = np.flatnonzero(labels == label) + + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(sparse_mask[local_chans, :], axis=0)) + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_waveforms, sparse_mask, target_channels + ) + + if peak_shifts is not None: + apply_waveforms_shift(aligned_wfs, peak_shifts[peak_indices], inplace=True) + + templates[i, :, :][:, target_channels] = np.mean(aligned_wfs[~dont_have_channels], axis=0) + + return templates + + +def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): + """ + Apply a shift a spike level to realign waveforms buffers. + + This is usefull to compute template after merge when to cluster are shifted. + + A negative shift need the waveforms to be moved toward the right because the trough was too early. + A positive shift need the waveforms to be moved toward the left because the trough was too late. + + Note the border sample are left as before without move. + + Parameters + ---------- + + waveforms + + peak_shifts + + inplace + + Returns + ------- + aligned_waveforms + + + """ + + print("apply_waveforms_shift") + + if inplace: + aligned_waveforms = waveforms + else: + aligned_waveforms = waveforms.copy() + + shift_set = np.unique(peak_shifts) + assert max(np.abs(shift_set)) < aligned_waveforms.shape[1] + + for shift in shift_set: + if shift == 0: + continue + mask = peak_shifts == shift + wfs = waveforms[mask] + + if shift > 0: + aligned_waveforms[mask, :-shift, :] = wfs[:, shift:, :] + else: + aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] + + print("apply_waveforms_shift DONE") + + return aligned_waveforms diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 358691cd25..ea36b75847 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -592,6 +592,7 @@ def _prepare_templates(cls, d): d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) d["singular"] = d["singular"].T[:, :, np.newaxis] + return d @classmethod diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py new file mode 100644 index 0000000000..6b3ea2a901 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -0,0 +1,14 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + + +def test_merge(): + pass + + +if __name__ == "__main__": + test_merge() diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py new file mode 100644 index 0000000000..5953f74e24 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -0,0 +1,14 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + + +def test_split(): + pass + + +if __name__ == "__main__": + test_split() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index d25f1ea97b..468b96ff3b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -43,6 +43,8 @@ def plot(self): self._do_plot() def _do_plot(self): + from matplotlib import pyplot as plt + fig = self.figure for ax in fig.axes: @@ -177,6 +179,8 @@ def plot(self): def _do_plot(self): import sklearn + import matplotlib.pyplot as plt + import matplotlib # compute similarity # take index of template (respect unit_ids order) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 39eb80e2e5..8814e0131a 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -32,10 +32,10 @@ def setUp(self): self.num_units = len(self._sorting.get_unit_ids()) #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True) - if (cache_folder / "mearec_test").is_dir(): - self._we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_old_api").is_dir(): + self._we = load_waveforms(cache_folder / "mearec_test_old_api") else: - self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test") + self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False) self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit") self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index f44878927d..1a2fdf38d9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -48,29 +48,30 @@ def setUpClass(cls): cls.sorting = se.MEArecSortingExtractor(local_path) cls.num_units = len(cls.sorting.get_unit_ids()) - if (cache_folder / "mearec_test").is_dir(): - cls.we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_dense").is_dir(): + cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense") else: - cls.we = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test") + cls.we_dense = extract_waveforms( + cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False + ) + metric_names = ["snr", "isi_violation", "num_spikes"] + _ = compute_spike_amplitudes(cls.we_dense) + _ = compute_unit_locations(cls.we_dense) + _ = compute_spike_locations(cls.we_dense) + _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) + _ = compute_template_metrics(cls.we_dense) + _ = compute_correlograms(cls.we_dense) + _ = compute_template_similarity(cls.we_dense) sw.set_default_plotter_backend("matplotlib") - metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we) - _ = compute_unit_locations(cls.we) - _ = compute_spike_locations(cls.we) - _ = compute_quality_metrics(cls.we, metric_names=metric_names) - _ = compute_template_metrics(cls.we) - _ = compute_correlograms(cls.we) - _ = compute_template_similarity(cls.we) - # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) - cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) + cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) + cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) if (cache_folder / "mearec_test_sparse").is_dir(): cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") else: - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + cls.we_sparse = cls.we_dense.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) cls.skip_backends = ["ipywidgets", "ephyviewer"] @@ -124,17 +125,17 @@ def test_plot_unit_waveforms(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_waveforms(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, @@ -148,10 +149,10 @@ def test_plot_unit_templates(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_templates( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, @@ -171,7 +172,7 @@ def test_plot_unit_waveforms_density_map(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): @@ -180,7 +181,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, + self.we_dense, sparsity=self.sparsity_radius, same_axis=False, unit_ids=unit_ids, @@ -234,11 +235,15 @@ def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.we.unit_ids[:4] - sw.plot_amplitudes(self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) + sw.plot_amplitudes(self.we_dense, backend=backend, **self.backend_kwargs[backend]) + unit_ids = self.we_dense.unit_ids[:4] + sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] + self.we_dense, + unit_ids=unit_ids, + plot_histograms=True, + backend=backend, + **self.backend_kwargs[backend], ) sw.plot_amplitudes( self.we_sparse, @@ -252,9 +257,9 @@ def test_plot_all_amplitudes_distributions(self): possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - unit_ids = self.we.unit_ids[:4] + unit_ids = self.we_dense.unit_ids[:4] sw.plot_all_amplitudes_distributions( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_all_amplitudes_distributions( self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] @@ -264,7 +269,9 @@ def test_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_unit_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -273,7 +280,9 @@ def test_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_spike_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_spike_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -282,28 +291,28 @@ def test_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): @@ -311,17 +320,17 @@ def test_plot_unit_summary(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.we, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_dense, self.we_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_summary( - self.we_sparse, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) def test_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_agreement_matrix(self): @@ -369,10 +378,10 @@ def test_plot_rasters(self): # mytest.test_quality_metrics() # mytest.test_template_metrics() # mytest.test_amplitudes() - # mytest.test_plot_agreement_matrix() + mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() - mytest.test_plot_rasters() + # mytest.test_plot_rasters() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9b6716e8f3..fc8b30eb05 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,26 +88,32 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - layer_keys = list(recordings.keys()) + if rec0.has_channel_location(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None - if segment_index is None: - if rec0.get_num_segments() != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the good order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) if channel_ids is None: channel_ids = rec0.channel_ids - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None + layer_keys = list(recordings.keys()) - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None + if segment_index is None: + if rec0.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +130,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -138,9 +144,10 @@ def __init__( # colors is a nested dict by layer and channels # lets first create black for all channels and layer + # all color are generated for ipywidgets colors = {} for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} + colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids} if color_groups: channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) @@ -149,7 +156,7 @@ def __init__( group_colors = get_some_colors(groups, color_engine="auto") channel_colors = {} - for i, chan_id in enumerate(channel_ids): + for i, chan_id in enumerate(rec0.channel_ids): group = channel_groups[i] channel_colors[chan_id] = group_colors[group] @@ -159,12 +166,12 @@ def __init__( elif color is not None: # old behavior one color for all channel # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids} elif color is None and len(recordings) > 1: # several layer layer_colors = get_some_colors(layer_keys) for k in layer_keys: - colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids} else: # color is None unique layer : all channels black pass @@ -201,7 +208,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -336,6 +342,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) + self.channel_selector.value = list(data_plot["channel_ids"]) left_sidebar = W.VBox( children=[ @@ -398,17 +405,17 @@ def _mode_changed(self, change=None): def _retrieve_traces(self, change=None): channel_ids = np.array(self.channel_selector.value) - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + # order = None start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled ) self._channel_ids = channel_ids @@ -523,7 +530,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -550,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] - return times, list_traces, frame_range, channel_ids diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6e872eca55..58dd5c7f32 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - # TODO external value change - # self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -260,6 +259,18 @@ def on_selector_changed(self, change=None): self.value = channel_ids + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + channel_ids = self.selector.value + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + class ScaleWidget(W.VBox): value = traitlets.Float()