From 1424060effa175fdf54b94551017fd0a6ed31be4 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 15 Apr 2024 08:51:16 -0400 Subject: [PATCH 1/2] port bug fixes from a couple PRs --- .../curation/remove_excess_spikes.py | 3 +- .../tests/test_remove_excess_spikes.py | 11 ++++- .../sorters/external/kilosort4.py | 4 +- .../sorters/external/mountainsort5.py | 46 ++++++++++--------- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 9c37a48ba1..5fb05a64a7 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -78,8 +78,9 @@ def get_unit_spike_train( ) -> np.ndarray: spike_train = self._parent_segment.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame) max_spike = np.searchsorted(spike_train, self._num_samples, side="left") + min_spike = np.searchsorted(spike_train, 0, side="left") - return spike_train[:max_spike] + return spike_train[min_spike:max_spike] def remove_excess_spikes(sorting, recording): diff --git a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py index f99c408c24..7175e0a614 100644 --- a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py +++ b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py @@ -14,6 +14,7 @@ def test_remove_excess_spikes(): num_spikes = 100 num_num_samples_spikes_per_segment = 5 num_excess_spikes_per_segment = 5 + num_neg_spike_times_per_segment = 2 times = [] labels = [] for segment_index in range(recording.get_num_segments()): @@ -21,12 +22,15 @@ def test_remove_excess_spikes(): times_segment = np.array([], dtype=int) labels_segment = np.array([], dtype=int) for unit in range(num_units): + neg_spike_times = np.random.randint(-50, 0, num_neg_spike_times_per_segment) spike_times = np.random.randint(0, num_samples, num_spikes) last_samples_spikes = (num_samples - 1) * np.ones(num_num_samples_spikes_per_segment, dtype=int) num_samples_spike_times = num_samples * np.ones(num_num_samples_spikes_per_segment, dtype=int) excess_spikes = np.random.randint(num_samples, num_samples + 100, num_excess_spikes_per_segment) spike_times = np.sort( - np.concatenate((spike_times, last_samples_spikes, num_samples_spike_times, excess_spikes)) + np.concatenate( + (neg_spike_times, spike_times, last_samples_spikes, num_samples_spike_times, excess_spikes) + ) ) spike_labels = unit * np.ones_like(spike_times) times_segment = np.concatenate((times_segment, spike_times)) @@ -47,7 +51,10 @@ def test_remove_excess_spikes(): assert ( len(spike_train_corrected) - == len(spike_train_excess) - num_num_samples_spikes_per_segment - num_excess_spikes_per_segment + == len(spike_train_excess) + - num_num_samples_spikes_per_segment + - num_excess_spikes_per_segment + - num_neg_spike_times_per_segment ) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 99de91a795..6ff836b753 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -32,7 +32,7 @@ class Kilosort4Sorter(BaseSorter): "sig_interp": 20, "nt0min": None, "dmin": None, - "dminx": None, + "dminx": 32, "min_template_size": 10, "template_sizes": 5, "nearest_chans": 10, @@ -68,7 +68,7 @@ class Kilosort4Sorter(BaseSorter): "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.", "nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.", "dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.", - "dminx": "Horizontal spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.", + "dminx": "Horizontal spacing of template centers used for spike detection, in microns. Default value: 32.", "min_template_size": "Standard deviation of the smallest, spatial envelope Gaussian used for universal templates. Default value: 10.", "template_sizes": "Number of sizes for universal spike templates (multiples of the min_template_size). Default value: 5.", "nearest_chans": "Number of nearest channels to consider when finding local maxima during spike detection. Default value: 10.", diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 1fcbe35c14..d516089d34 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -3,14 +3,15 @@ from pathlib import Path from packaging.version import parse -from tempfile import TemporaryDirectory +import shutil import numpy as np +from warnings import warn from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.core.baserecording import BaseRecording -from ..basesorter import BaseSorter +from ..basesorter import BaseSorter, get_job_kwargs from spikeinterface.extractors import NpzSortingExtractor @@ -43,8 +44,7 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": 6000, "filter": True, "whiten": True, # Important to do whitening - "temporary_base_dir": None, - "n_jobs_for_preprocessing": -1, + "delete_temporary_recording": True, } _params_description = { @@ -68,8 +68,7 @@ class Mountainsort5Sorter(BaseSorter): "freq_max": "Low-pass filter cutoff frequency", "filter": "Enable or disable filter", "whiten": "Enable or disable whitening", - "temporary_base_dir": "Temporary directory base directory for storing cached recording", - "n_jobs_for_preprocessing": "Number of parallel jobs for creating the cached recording", + "delete_temporary_recording": "If True, the temporary recording file is deleted after sorting (this may fail on Windows requiring the end-user to delete the file themselves later)", } sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030" @@ -186,21 +185,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"] ) - with TemporaryDirectory(dir=p["temporary_base_dir"]) as tmpdir: - # cache the recording to a temporary directory for efficient reading (so we don't have to re-filter) - recording_cached = create_cached_recording( - recording=recording, folder=tmpdir, n_jobs=p["n_jobs_for_preprocessing"] - ) - - scheme = p["scheme"] - if scheme == "1": - sorting = ms5.sorting_scheme1(recording=recording_cached, sorting_parameters=scheme1_sorting_parameters) - elif p["scheme"] == "2": - sorting = ms5.sorting_scheme2(recording=recording_cached, sorting_parameters=scheme2_sorting_parameters) - elif p["scheme"] == "3": - sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters) - else: - raise ValueError(f"Invalid scheme: {scheme} given. scheme must be one of '1', '2' or '3'") + if not recording.is_binary_compatible(): + recording_cached = recording.save(folder=sorter_output_folder / "recording", **get_job_kwargs(p, verbose)) + else: + recording_cached = recording + + if p["scheme"] == "1": + sorting = ms5.sorting_scheme1(recording=recording_cached, sorting_parameters=scheme1_sorting_parameters) + elif p["scheme"] == "2": + sorting = ms5.sorting_scheme2(recording=recording_cached, sorting_parameters=scheme2_sorting_parameters) + elif p["scheme"] == "3": + sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters) + else: + raise ValueError(f"Invalid scheme: {p['scheme']} given. scheme must be one of '1', '2' or '3'") + + if p["delete_temporary_recording"]: + if not recording.is_binary_compatible(): + del recording_cached + shutil.rmtree(sorter_output_folder / "recording", ignore_errors=True) + if Path(sorter_output_folder / "recording").is_dir(): + warn("cleanup failed, please remove file yourself if desired") NpzSortingExtractor.write_sorting(sorting, str(sorter_output_folder / "firings.npz")) From 7ad9bbb396eb47f824ada4c7653d1f0fba7783cf Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 17 Apr 2024 12:50:38 -0400 Subject: [PATCH 2/2] fix has exceeding for bug fix --- src/spikeinterface/core/waveform_tools.py | 2 ++ src/spikeinterface/curation/remove_excess_spikes.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 58243ceea2..f9e39382df 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -696,6 +696,8 @@ def has_exceeding_spikes(recording, sorting): if len(spike_vector_seg) > 0: if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: return True + if spike_vector_seg["sample_index"][0] < 0: + return True return False diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 5fb05a64a7..450b31e3d4 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -61,7 +61,8 @@ def _custom_cache_spike_vector(self) -> None: for segment_index in range(num_segments): spike_vector = parent_spike_vector[segments_bounds[segment_index] : segments_bounds[segment_index + 1]] end = np.searchsorted(spike_vector["sample_index"], self._num_samples[segment_index]) - list_spike_vectors.append(spike_vector[:end]) + start = np.searchsorted(spike_vector["sample_index"], 0) + list_spike_vectors.append(spike_vector[start:end]) spike_vector = np.concatenate(list_spike_vectors) self._cached_spike_vector = spike_vector