From 5370884331aa3b4ebdf25c10bf4d103fd502f28a Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 1 Sep 2023 12:05:41 +0200
Subject: [PATCH 01/29] Refactor compute_spike_locations using SpikeRetriever.

---
 .../core/tests/test_node_pipeline.py          |  3 -
 .../postprocessing/spike_locations.py         | 37 ++++++--
 .../tests/test_spike_locations.py             |  3 +-
 .../sortingcomponents/peak_localization.py    | 84 +++++++++++-------
 src/spikeinterface/sortingcomponents/tools.py |  3 +
 5 files changed, 84 insertions(+), 46 deletions(-)

diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index 339167f673..db5305c313 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -97,9 +97,6 @@ def test_run_node_pipeline():
 
     # test with 2 diffrents first node
     for peak_source in (peak_retriever, spike_retriever_T, spike_retriever_S):
-
-
-
         # one step only : squeeze output
         nodes = [
             peak_source,
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index c6f498f7e8..32443d44d0 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -5,6 +5,7 @@
 
 from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift
 from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
+from spikeinterface.core.node_pipeline import SpikeRetriever
 
 
 class SpikeLocationsCalculator(BaseWaveformExtractorExtension):
@@ -25,9 +26,12 @@ def __init__(self, waveform_extractor):
         extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
         self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
 
-    def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", method_kwargs={}):
-        params = dict(ms_before=ms_before, ms_after=ms_after, method=method)
+
+
+    def _set_params(self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={}):
+        params = dict(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method)
         params.update(**method_kwargs)
+        print(params)
         return params
 
     def _select_extension_data(self, unit_ids):
@@ -44,13 +48,28 @@ def _run(self, **job_kwargs):
         uses the`sortingcomponents.peak_localization.localize_peaks()` function to triangulate
         spike locations.
         """
-        from spikeinterface.sortingcomponents.peak_localization import localize_peaks
+        from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source
 
         job_kwargs = fix_job_kwargs(job_kwargs)
 
         we = self.waveform_extractor
 
-        spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs)
+        extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
+
+        params = self._params.copy()
+        channel_from_template = params.pop("channel_from_template")
+
+        # @alessio @pierre: where do we expose the parameters of radius for the retriever (this is not the same as the one for localization; it is smaller) ???
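+        # SpikeRetriever is a peak source like PeakRetriever: it feeds the same node
+        # pipeline with the spikes of the sorting instead of detected peaks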
+ spike_retriever = SpikeRetriever( + we.recording, + we.sorting, + channel_from_template=channel_from_template, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign=self._params.get("peaks_sign", "neg") + ) + spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) + self._extension_data["spike_locations"] = spike_locations def get_data(self, outputs="concatenated"): @@ -95,12 +114,15 @@ def get_extension_function(): WaveformExtractor.register_extension(SpikeLocationsCalculator) +# @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate +# what do we put by default ? def compute_spike_locations( waveform_extractor, load_if_exists=False, ms_before=0.5, ms_after=0.5, + channel_from_template=True, method="center_of_mass", method_kwargs={}, outputs="concatenated", @@ -119,6 +141,10 @@ def compute_spike_locations( The left window, before a peak, in milliseconds. ms_after : float The right window, after a peak, in milliseconds. + channel_from_template: bool, default True + For each spike is the maximum channel computed from template or re estimated at every spikes. + channel_from_template = True is old behavior but less acurate + channel_from_template = False is slower but more accurate method : str 'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution' method_kwargs : dict @@ -138,7 +164,8 @@ def compute_spike_locations( slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) else: slc = SpikeLocationsCalculator(waveform_extractor) - slc.set_params(ms_before=ms_before, ms_after=ms_after, method=method, method_kwargs=method_kwargs) + slc.set_params(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, + method=method, method_kwargs=method_kwargs) slc.run(**job_kwargs) locs = slc.get_data(outputs=outputs) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 521b49e6cd..ab2345b1f5 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -10,7 +10,8 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes extension_class = SpikeLocationsCalculator extension_data_names = ["spike_locations"] extension_function_kwargs_list = [ - dict(method="center_of_mass", chunk_size=10000, n_jobs=1), + dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=True), + dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=False), dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index fa6101f896..b638e8ed3a 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -7,6 +7,7 @@ run_node_pipeline, find_parent_of_type, PeakRetriever, + SpikeRetriever, PipelineNode, WaveformsNode, ExtractDenseWaveforms, @@ -27,72 +28,49 @@ from .tools import get_prototype_spike -def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): - """Localize peak 
(spike) in 2D or 3D depending the method. - - When a probe is 2D then: - * X is axis 0 of the probe - * Y is axis 1 of the probe - * Z is orthogonal to the plane of the probe - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object. - peaks: array - Peaks array, as returned by detect_peaks() in "compact_numpy" way. - - {method_doc} - - {job_doc} - - Returns - ------- - peak_locations: ndarray - Array with estimated location for each spike. - The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha'). - """ +def _run_localization_from_peak_source(recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): + # use by localize_peaks() and compute_spike_locations() assert ( method in possible_localization_methods ), f"Method {method} is not supported. Choose from {possible_localization_methods}" method_kwargs, job_kwargs = split_job_kwargs(kwargs) - peak_retriever = PeakRetriever(recording, peaks) if method == "center_of_mass": extract_dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False ) pipeline_nodes = [ - peak_retriever, + peak_source, extract_dense_waveforms, - LocalizeCenterOfMass(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs), + LocalizeCenterOfMass(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs), ] elif method == "monopolar_triangulation": extract_dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False ) pipeline_nodes = [ - peak_retriever, + peak_source, extract_dense_waveforms, LocalizeMonopolarTriangulation( - recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs + recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs ), ] elif method == "peak_channel": - pipeline_nodes = [peak_retriever, LocalizePeakChannel(recording, parents=[peak_retriever], **method_kwargs)] + pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)] elif method == "grid_convolution": if "prototype" not in method_kwargs: + assert isinstance(peak_source, (PeakRetriever, SpikeRetriever)) method_kwargs["prototype"] = get_prototype_spike( - recording, peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs + recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, job_kwargs=job_kwargs ) extract_dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False ) pipeline_nodes = [ - peak_retriever, + peak_source, extract_dense_waveforms, - LocalizeGridConvolution(recording, parents=[peak_retriever, extract_dense_waveforms], **method_kwargs), + LocalizeGridConvolution(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs), ] job_name = f"localize peaks using {method}" @@ -101,6 +79,38 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ return peak_locations + +def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): + """Localize peak (spike) in 2D 
or 3D depending on the method.
+
+    When a probe is 2D then:
+       * X is axis 0 of the probe
+       * Y is axis 1 of the probe
+       * Z is orthogonal to the plane of the probe
+
+    Parameters
+    ----------
+    recording: RecordingExtractor
+        The recording extractor object.
+    peaks: array
+        Peaks array, as returned by detect_peaks() in "compact_numpy" way.
+
+    {method_doc}
+
+    {job_doc}
+
+    Returns
+    -------
+    peak_locations: ndarray
+        Array with estimated location for each spike.
+        The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha').
+    """
+    peak_retriever = PeakRetriever(recording, peaks)
+    peak_locations = _run_localization_from_peak_source(recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs)
+    return peak_locations
+
+
+
 class LocalizeBase(PipelineNode):
     def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 45b9079ea9..576732baa2 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -19,6 +19,9 @@ def make_multi_method_doc(methods, ident="    "):
 
 
 def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5):
+    # TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few
+    # spikes, which means that traces are entirely computed!!!!!
+    # Please find a better way
    nb_peaks = min(len(peaks), nb_peaks)
    idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False))
    peak_retriever = PeakRetriever(recording, peaks[idx])

From 68df57384e5ca424c6b07aacf3e48933c1b5fa55 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 1 Sep 2023 10:06:45 +0000
Subject: [PATCH 02/29] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/node_pipeline.py      | 19 +++++++---------
 .../core/tests/test_node_pipeline.py          | 22 ++++++++++---------
 .../postprocessing/spike_locations.py         | 22 +++++++++++++------
 .../sortingcomponents/peak_localization.py    | 14 ++++++------
 4 files changed, 42 insertions(+), 35 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index ff747fe2a0..610ae42398 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -141,7 +141,7 @@ class SpikeRetriever(PeakSource):
     """
     This class is usefull to inject a sorting object in the node pipepline mechanisim.
     It allows to compute some post processing with the same machinery used for sorting components.
-    This is a first step to totaly refactor:
+    This is a first step to totaly refactor:
       * compute_spike_locations()
      * compute_amplitude_scalings()
      * compute_spike_amplitudes()
@@ -164,16 +164,14 @@ class SpikeRetriever(PeakSource):
        Peak sign to find the max channel.
Used only when channel_from_template=False """ - def __init__(self, recording, sorting, - channel_from_template=True, - extremum_channel_inds=None, - radius_um=50, - peak_sign="neg" - ): + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): PipelineNode.__init__(self, recording, return_output=False) self.channel_from_template = channel_from_template - + assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" self.peaks = sorting_to_peak(sorting, extremum_channel_inds) @@ -181,8 +179,7 @@ def __init__(self, recording, sorting, if not channel_from_template: channel_distance = get_channel_distances(recording) self.neighbours_mask = channel_distance < radius_um - self.peak_sign = peak_sign - + self.peak_sign = peak_sign # precompute segment slice self.segment_slices = [] @@ -219,7 +216,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): elif self.peak_sign == "pos": local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] elif self.peak_sign == "both": - local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] # TODO: "amplitude" ??? diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index db5305c313..f271e81869 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -81,22 +81,24 @@ def test_run_node_pipeline(): we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") peaks = sorting_to_peak(sorting, extremum_channel_inds) - + peak_retriever = PeakRetriever(recording, peaks) # channel index is from template - spike_retriever_T = SpikeRetriever(recording, sorting, - channel_from_template=True, - extremum_channel_inds=extremum_channel_inds) + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds + ) # channel index is per spike - spike_retriever_S = SpikeRetriever(recording, sorting, - channel_from_template=False, - extremum_channel_inds=extremum_channel_inds, - radius_um=50, - peak_sign="neg") + spike_retriever_S = SpikeRetriever( + recording, + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", + ) # test with 2 diffrents first node for peak_source in (peak_retriever, spike_retriever_T, spike_retriever_S): - # one step only : squeeze output nodes = [ peak_source, diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 32443d44d0..1da2858142 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -26,10 +26,12 @@ def __init__(self, waveform_extractor): extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - - - def _set_params(self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={}): - params = dict(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method) + def _set_params( + self, 
ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={} + ): + params = dict( + ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method + ) params.update(**method_kwargs) print(params) return params @@ -66,7 +68,7 @@ def _run(self, **job_kwargs): channel_from_template=channel_from_template, extremum_channel_inds=extremum_channel_inds, radius_um=50, - peak_sign=self._params.get("peaks_sign", "neg") + peak_sign=self._params.get("peaks_sign", "neg"), ) spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) @@ -117,6 +119,7 @@ def get_extension_function(): # @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate # what do we put by default ? + def compute_spike_locations( waveform_extractor, load_if_exists=False, @@ -164,8 +167,13 @@ def compute_spike_locations( slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) else: slc = SpikeLocationsCalculator(waveform_extractor) - slc.set_params(ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, - method=method, method_kwargs=method_kwargs) + slc.set_params( + ms_before=ms_before, + ms_after=ms_after, + channel_from_template=channel_from_template, + method=method, + method_kwargs=method_kwargs, + ) slc.run(**job_kwargs) locs = slc.get_data(outputs=outputs) diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index b638e8ed3a..6495503b43 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -28,7 +28,9 @@ from .tools import get_prototype_spike -def _run_localization_from_peak_source(recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): +def _run_localization_from_peak_source( + recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs +): # use by localize_peaks() and compute_spike_locations() assert ( method in possible_localization_methods @@ -52,9 +54,7 @@ def _run_localization_from_peak_source(recording, peak_source, method="center_of pipeline_nodes = [ peak_source, extract_dense_waveforms, - LocalizeMonopolarTriangulation( - recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs - ), + LocalizeMonopolarTriangulation(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs), ] elif method == "peak_channel": pipeline_nodes = [peak_source, LocalizePeakChannel(recording, parents=[peak_source], **method_kwargs)] @@ -79,7 +79,6 @@ def _run_localization_from_peak_source(recording, peak_source, method="center_of return peak_locations - def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): """Localize peak (spike) in 2D or 3D depending the method. @@ -106,11 +105,12 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ The dtype depends on the method. ('x', 'y') or ('x', 'y', 'z', 'alpha'). 
""" peak_retriever = PeakRetriever(recording, peaks) - peak_locations = _run_localization_from_peak_source(recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs) + peak_locations = _run_localization_from_peak_source( + recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs + ) return peak_locations - class LocalizeBase(PipelineNode): def __init__(self, recording, return_output=True, parents=None, radius_um=75.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) From bfc7ebe7f51e1b11f1556063e220d7eb1846a723 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:17:59 -0400 Subject: [PATCH 03/29] refactor isi calculation --- src/spikeinterface/postprocessing/isi.py | 85 ++++-------------------- 1 file changed, 13 insertions(+), 72 deletions(-) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index aec70141cf..9f6c649693 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -65,61 +65,6 @@ def get_extension_function(): WaveformExtractor.register_extension(ISIHistogramsCalculator) -def compute_isi_histograms_from_spiketrain(spike_train: np.ndarray, max_time: int, bin_size: int, sampling_f: float): - """ - Computes the Inter-Spike Intervals histogram from a given spike train. - - This implementation only works if you have numba installed, to accelerate the - computation time. - - Parameters - ---------- - spike_train: np.ndarray - The ordered spike train to compute the ISI. - max_time: int - Compute the ISI from 0 to +max_time (in sampling time). - bin_size: int - Size of a bin (in sampling time). - sampling_f: float - Sampling rate/frequency (in Hz). - - Returns - ------- - tuple (ISI, bins) - ISI: np.ndarray[int64] - The computed ISI histogram. - bins: np.ndarray[float64] - The bins for the ISI histogram. - """ - if not HAVE_NUMBA: - print("Error: numba is not installed.") - print("compute_ISI_from_spiketrain cannot run without numba.") - return 0 - - return _compute_isi_histograms_from_spiketrain(spike_train.astype(np.int64), max_time, bin_size, sampling_f) - - -if HAVE_NUMBA: - - @numba.jit((numba.int64[::1], numba.int32, numba.int32, numba.float32), nopython=True, nogil=True, cache=True) - def _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f): - n_bins = int(max_time / bin_size) - - bins = np.arange(0, max_time + bin_size, bin_size) * 1e3 / sampling_f - ISI = np.zeros(n_bins, dtype=np.int64) - - for i in range(1, len(spike_train)): - diff = spike_train[i] - spike_train[i - 1] - - if diff >= max_time: - continue - - bin = int(diff / bin_size) - ISI[bin] += 1 - - return ISI, bins - - def compute_isi_histograms( waveform_or_sorting_extractor, load_if_exists=False, @@ -140,7 +85,7 @@ def compute_isi_histograms( bin_ms : float, optional The bin size in ms, by default 1.0. method : str, optional - "auto" | "numpy" | "numba". If _auto" and numba is installed, numba is used, by default "auto" + "auto" | "numpy" | "numba". 
If "auto" and numba is installed, numba is used, by default "auto" Returns ------- @@ -191,21 +136,20 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float """ fs = sorting.get_sampling_frequency() num_units = len(sorting.unit_ids) - + assert bin_ms >= 1/fs, f"bin size must be larger than the sampling period {1/fs}" + assert bin_ms <= window_ms window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - num_bins = int(window_size / bin_size) - assert num_bins >= 1 - - ISIs = np.zeros((num_units, num_bins), dtype=np.int64) bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs + ISIs = np.zeros((num_units, len(bins)-1), dtype=np.int64) + # TODO: There might be a better way than a double for loop? for i, unit_id in enumerate(sorting.unit_ids): for seg_index in range(sorting.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id, segment_index=seg_index) - ISI = np.histogram(np.diff(spike_train), bins=num_bins, range=(0, window_size - 1))[0] + ISI = np.histogram(np.diff(spike_train), bins=bins)[0] ISIs[i] += ISI return ISIs, bins @@ -224,18 +168,18 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float assert HAVE_NUMBA fs = sorting.get_sampling_frequency() + assert bin_ms >= 1/fs, f"the bin_ms must be larger than the sampling period: {1/fs}" + assert bin_ms <= window_ms num_units = len(sorting.unit_ids) window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - num_bins = int(window_size / bin_size) - assert num_bins >= 1 bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs spikes = sorting.to_spike_vector(concatenated=False) - ISIs = np.zeros((num_units, num_bins), dtype=np.int64) + ISIs = np.zeros((num_units, len(bins)-1), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"].astype(np.int64) @@ -245,9 +189,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float ISIs, spike_times, spike_labels, - window_size, - bin_size, - fs, + bins, ) return ISIs, bins @@ -256,16 +198,15 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float if HAVE_NUMBA: @numba.jit( - (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.int32, numba.int32, numba.float32), + (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.float64[::1]), nopython=True, nogil=True, cache=True, parallel=True, ) - def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, max_time, bin_size, sampling_f): + def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins): n_units = ISIs.shape[0] for i in numba.prange(n_units): spike_train = spike_trains[spike_clusters == i] - - ISIs[i] += _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f)[0] + ISIs[i] += np.histogram(np.diff(spike_train), bins=bins)[0] From e94b9e5e81e994629d38b10986a0bc4dbcabef3a Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:42:10 -0400 Subject: [PATCH 04/29] fix assert check of bin_ms --- src/spikeinterface/postprocessing/isi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 9f6c649693..5fc18960d9 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ 
b/src/spikeinterface/postprocessing/isi.py @@ -136,7 +136,7 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float """ fs = sorting.get_sampling_frequency() num_units = len(sorting.unit_ids) - assert bin_ms >= 1/fs, f"bin size must be larger than the sampling period {1/fs}" + assert bin_ms * 1e-3 >= 1/fs, f"bin size must be larger than the sampling period {1/fs}" assert bin_ms <= window_ms window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) @@ -168,7 +168,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float assert HAVE_NUMBA fs = sorting.get_sampling_frequency() - assert bin_ms >= 1/fs, f"the bin_ms must be larger than the sampling period: {1/fs}" + assert bin_ms * 1e-3 >= 1/fs, f"the bin_ms must be larger than the sampling period: {1/fs}" assert bin_ms <= window_ms num_units = len(sorting.unit_ids) From 5a9b3b0cba29558b844bea55b8eafbedb5cbe3aa Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 1 Sep 2023 22:21:45 -0400 Subject: [PATCH 05/29] remove extra function --- src/spikeinterface/postprocessing/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 223bda5e30..114a53176b 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -39,7 +39,6 @@ from .isi import ( ISIHistogramsCalculator, - compute_isi_histograms_from_spiketrain, compute_isi_histograms, compute_isi_histograms_numpy, compute_isi_histograms_numba, From 03822ee886e81e199293d8b37b9b3df9d58bc051 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Sat, 2 Sep 2023 08:41:13 -0400 Subject: [PATCH 06/29] remove parallel -> speed up 25% --- src/spikeinterface/postprocessing/isi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 5fc18960d9..139df1aec6 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -202,11 +202,10 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float nopython=True, nogil=True, cache=True, - parallel=True, ) def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins): n_units = ISIs.shape[0] - for i in numba.prange(n_units): + for i in range(n_units): spike_train = spike_trains[spike_clusters == i] ISIs[i] += np.histogram(np.diff(spike_train), bins=bins)[0] From 6fe79d8bdecd68fb6ff29c10874506cb18206386 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Sun, 3 Sep 2023 15:35:01 -0400 Subject: [PATCH 07/29] fix assert to be in milliseconds --- src/spikeinterface/postprocessing/isi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 139df1aec6..782838d7f0 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -136,7 +136,7 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float """ fs = sorting.get_sampling_frequency() num_units = len(sorting.unit_ids) - assert bin_ms * 1e-3 >= 1/fs, f"bin size must be larger than the sampling period {1/fs}" + assert bin_ms * 1e-3 >= 1 / fs, f"bin size must be larger than the sampling period {1e3 / fs}" assert bin_ms <= 
window_ms window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) @@ -168,7 +168,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float assert HAVE_NUMBA fs = sorting.get_sampling_frequency() - assert bin_ms * 1e-3 >= 1/fs, f"the bin_ms must be larger than the sampling period: {1/fs}" + assert bin_ms * 1e-3 >= 1 / fs, f"the bin_ms must be larger than the sampling period: {1e3 / fs}" assert bin_ms <= window_ms num_units = len(sorting.unit_ids) From 2fc6dd502679b0d7903ceddd1d9731dc205b7f87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Sep 2023 12:31:55 +0000 Subject: [PATCH 08/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/postprocessing/isi.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 782838d7f0..57e4bdc5b2 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -142,8 +142,7 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs - ISIs = np.zeros((num_units, len(bins)-1), dtype=np.int64) - + ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) # TODO: There might be a better way than a double for loop? for i, unit_id in enumerate(sorting.unit_ids): @@ -179,7 +178,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs spikes = sorting.to_spike_vector(concatenated=False) - ISIs = np.zeros((num_units, len(bins)-1), dtype=np.int64) + ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"].astype(np.int64) From 8710551a8d9d762b4284c788b63125ef2719dad3 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 4 Sep 2023 09:02:33 -0400 Subject: [PATCH 09/29] test add seed --- src/spikeinterface/postprocessing/tests/test_correlograms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index d6648150de..071a027235 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -89,7 +89,7 @@ def test_equal_results_correlograms(): def test_flat_cross_correlogram(): - sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0]) + sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=371532) methods = ["numpy"] if HAVE_NUMBA: From 008078165c09433dc6bd923206066ecc5117e307 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 7 Sep 2023 06:37:14 -0400 Subject: [PATCH 10/29] resolve merge conflict --- src/spikeinterface/postprocessing/tests/test_correlograms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 071a027235..0df00dd037 100644 --- 
a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -89,7 +89,7 @@ def test_equal_results_correlograms(): def test_flat_cross_correlogram(): - sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=371532) + sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0) methods = ["numpy"] if HAVE_NUMBA: From 6c57b5ca2b11c9a0f75b01eb59a2c81e00095bb6 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Wed, 4 Oct 2023 09:46:41 -0400 Subject: [PATCH 11/29] adjust eps for whitening in case of very small magnitude data --- src/spikeinterface/preprocessing/whiten.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index cb2346ba68..c8eece2623 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -68,7 +68,7 @@ def __init__( M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=1e-8 + recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -122,7 +122,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8): +def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None): """ Compute whitening matrix @@ -167,6 +167,20 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r cov = data.T @ data cov = cov / data.shape[0] + # Here we determine eps used below to avoid division by zero. + # Typically we can assume that data is in units of + # microvolts, but this is not always the case. When data + # is float type and scaled down to very small values, then the + # default eps=1e-6 can be too large, resulting in incorrect + # whitening. We therefore check to see if the data is float + # type and we estimate a more reasonable eps in the case + # where the data is on a scale less than 1. 
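+    # For example, float traces on a ~1e-4 scale give median_data_sqr ~1e-8,
+    # so the estimate below yields eps = max(1e-16, 1e-8 * 1e-3) = 1e-11
+    # instead of the 1e-6 default.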
+ eps = 1e-6 # the default + if data.dtype.kind == "f": + median_data_sqr = np.median(data ** 2) # use the square because cov (and hence S) scales as the square + if median_data_sqr < 1 and median_data_sqr > 0: + eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data + if mode == "global": U, S, Ut = np.linalg.svd(cov, full_matrices=True) W = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut From 39ef079b0d8778e27a0500830464259e38aa528b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 13:52:43 +0000 Subject: [PATCH 12/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/whiten.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index c8eece2623..49ca3c7926 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -67,9 +67,7 @@ def __init__( if M is not None: M = np.asarray(M) else: - W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um - ) + W, M = compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -175,11 +173,11 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r # whitening. We therefore check to see if the data is float # type and we estimate a more reasonable eps in the case # where the data is on a scale less than 1. - eps = 1e-6 # the default + eps = 1e-6 # the default if data.dtype.kind == "f": - median_data_sqr = np.median(data ** 2) # use the square because cov (and hence S) scales as the square + median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square if median_data_sqr < 1 and median_data_sqr > 0: - eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data + eps = max(1e-16, median_data_sqr * 1e-3) # use a small fraction of the median of the squared data if mode == "global": U, S, Ut = np.linalg.svd(cov, full_matrices=True) From 2b7d2cea7b6a90b9f807b2322fee98ffa6464248 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Wed, 4 Oct 2023 11:22:11 -0400 Subject: [PATCH 13/29] adjust whiten tests to not use eps arg --- src/spikeinterface/preprocessing/tests/test_whiten.py | 6 +++--- src/spikeinterface/preprocessing/whiten.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 0848c1a176..40674a08f4 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -20,13 +20,13 @@ def test_whiten(): print(rec.get_channel_locations()) random_chunk_kwargs = {} - W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8) + W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) print(W) print(M) with pytest.raises(AssertionError): - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, eps=1e-8) - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25, eps=1e-8) + W, M = 
compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None) + W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=25) # W must be sparse np.sum(W == 0) == 6 diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index c8eece2623..5300b97de3 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -171,11 +171,11 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r # Typically we can assume that data is in units of # microvolts, but this is not always the case. When data # is float type and scaled down to very small values, then the - # default eps=1e-6 can be too large, resulting in incorrect + # default eps=1e-8 can be too large, resulting in incorrect # whitening. We therefore check to see if the data is float # type and we estimate a more reasonable eps in the case # where the data is on a scale less than 1. - eps = 1e-6 # the default + eps = 1e-8 # the default if data.dtype.kind == "f": median_data_sqr = np.median(data ** 2) # use the square because cov (and hence S) scales as the square if median_data_sqr < 1 and median_data_sqr > 0: From b2337d0d65497f1cb0cc6ef6c18f21c0b3fc5ac4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:24:27 +0000 Subject: [PATCH 14/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/whiten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 8e8bf3e9cb..afa3227e76 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -173,7 +173,7 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r # whitening. We therefore check to see if the data is float # type and we estimate a more reasonable eps in the case # where the data is on a scale less than 1. - eps = 1e-8 # the default + eps = 1e-8 # the default if data.dtype.kind == "f": median_data_sqr = np.median(data**2) # use the square because cov (and hence S) scales as the square if median_data_sqr < 1 and median_data_sqr > 0: From 2f7176469b2b06b04baf36f8f8d2ff704cdf2095 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Wed, 4 Oct 2023 11:32:10 -0400 Subject: [PATCH 15/29] use "mV" in comment --- src/spikeinterface/preprocessing/whiten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index afa3227e76..7cd4766f2b 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -167,7 +167,7 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r # Here we determine eps used below to avoid division by zero. # Typically we can assume that data is in units of - # microvolts, but this is not always the case. When data + # mV, but this is not always the case. When data # is float type and scaled down to very small values, then the # default eps=1e-8 can be too large, resulting in incorrect # whitening. 
We therefore check to see if the data is float From ba81bb111a992320427219fbdad3ff4c72ae9003 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 6 Oct 2023 17:10:07 -0400 Subject: [PATCH 16/29] Update src/spikeinterface/preprocessing/whiten.py Co-authored-by: Alessio Buccino --- src/spikeinterface/preprocessing/whiten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 7cd4766f2b..cae7fe6334 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -166,8 +166,8 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r cov = cov / data.shape[0] # Here we determine eps used below to avoid division by zero. - # Typically we can assume that data is in units of - # mV, but this is not always the case. When data + # Typically we can assume that data is either unscaled integers or in units of + # uV, but this is not always the case. When data # is float type and scaled down to very small values, then the # default eps=1e-8 can be too large, resulting in incorrect # whitening. We therefore check to see if the data is float From 8e316efd8bf280374598c33c490e8b3c6c90dc3a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 21:29:16 +0200 Subject: [PATCH 17/29] wip compute_spike_location with true channel --- src/spikeinterface/core/node_pipeline.py | 4 +- .../postprocessing/spike_locations.py | 46 +++++++++++-------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index a0ded216d1..e17c5d5caa 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -213,7 +213,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): elif self.peak_sign == "both": local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] - # TODO: "amplitude" ??? 
+        # handle amplitude
+        for i, peak in enumerate(local_peaks):
+            local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]]
 
         return (local_peaks,)
 
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index 32443d44d0..2807fb992c 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -27,13 +27,18 @@ def __init__(self, waveform_extractor):
         self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
 
     def _set_params(
-        self, ms_before=0.5, ms_after=0.5, channel_from_template=True, method="center_of_mass", method_kwargs={}
+        self, ms_before=0.5, ms_after=0.5,
+        spike_retriver_kwargs=dict(
+            channel_from_template=False,
+            radius_um=50,
+            peaks_sign="neg",
+        ),
+        method="center_of_mass", method_kwargs={}
     ):
         params = dict(
-            ms_before=ms_before, ms_after=ms_after, channel_from_template=channel_from_template, method=method
+            ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method
         )
         params.update(**method_kwargs)
-        print(params)
         return params
 
     def _select_extension_data(self, unit_ids):
@@ -62,16 +67,13 @@ def _run(self, **job_kwargs):
         extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")
 
         params = self._params.copy()
-        channel_from_template = params.pop("channel_from_template")
+        spike_retriver_kwargs = params.pop("spike_retriver_kwargs")
 
-        # @alessio @pierre: where do we expose the parameters of radius for the retriever (this is not the same as the one for localization; it is smaller) ???
         spike_retriever = SpikeRetriever(
             we.recording,
             we.sorting,
-            channel_from_template=channel_from_template,
             extremum_channel_inds=extremum_channel_inds,
-            radius_um=50,
-            peak_sign=self._params.get("peaks_sign", "neg"),
+            **spike_retriver_kwargs
         )
         spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs)
 
@@ -116,16 +118,17 @@ def get_extension_function():
 WaveformExtractor.register_extension(SpikeLocationsCalculator)
 
 
-# @alessio @pierre: channel_from_template=True is the old behavior but this is not accurate
-# what do we put by default ?
-
-
 def compute_spike_locations(
     waveform_extractor,
     load_if_exists=False,
     ms_before=0.5,
     ms_after=0.5,
-    channel_from_template=True,
+    spike_retriver_kwargs=dict(
+        channel_from_template=False,
+        radius_um=50,
+        peaks_sign="neg",
+    ),
+
     method="center_of_mass",
     method_kwargs={},
     outputs="concatenated",
@@ -144,10 +147,17 @@ def compute_spike_locations(
         The left window, before a peak, in milliseconds.
     ms_after : float
         The right window, after a peak, in milliseconds.
-    channel_from_template: bool, default True
-        For each spike is the maximum channel computed from template or re estimated at every spikes.
-        channel_from_template = True is old behavior but less acurate
-        channel_from_template = False is slower but more accurate
+    spike_retriver_kwargs: dict
+        A dict that contains the behavior for getting the maximum channel for each spike.
+        This dict contains:
+        * channel_from_template: bool, default True
+            Whether the maximum channel of each spike is computed from the template or re-estimated at every spike.
channel_from_template = True is the old behavior but less accurate
+            channel_from_template = False is slower but more accurate
+        * radius_um: float, default 50
+            In case channel_from_template=False, this is the radius to get the true peak.
+        * peaks_sign="neg"
+            In case channel_from_template=False, this is the peak sign.
     method : str
         'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'
     method_kwargs : dict
@@ -176,7 +186,7 @@ def compute_spike_locations(
         slc.set_params(
             ms_before=ms_before,
             ms_after=ms_after,
-            channel_from_template=channel_from_template,
+            spike_retriver_kwargs=spike_retriver_kwargs,
             method=method,
             method_kwargs=method_kwargs,
         )
         slc.run(**job_kwargs)

From 0c790f4687251803ab1fbad96712126ef1f49a2a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 23 Oct 2023 19:29:39 +0000
Subject: [PATCH 18/29] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/node_pipeline.py             |  2 +-
 .../postprocessing/spike_locations.py                | 14 +++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index e17c5d5caa..9b61ec0dab 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -215,7 +215,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
 
         # handle amplitude
         for i, peak in enumerate(local_peaks):
-            local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]]
+            local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]]
 
         return (local_peaks,)
 
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index 2807fb992c..6f8a8aabcb 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -27,13 +27,16 @@ def __init__(self, waveform_extractor):
         self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
 
     def _set_params(
-        self, ms_before=0.5, ms_after=0.5,
+        self,
+        ms_before=0.5,
+        ms_after=0.5,
         spike_retriver_kwargs=dict(
             channel_from_template=False,
             radius_um=50,
             peaks_sign="neg",
         ),
-        method="center_of_mass", method_kwargs={}
+        method="center_of_mass",
+        method_kwargs={},
     ):
         params = dict(
             ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method
@@ -67,13 +70,10 @@ def _run(self, **job_kwargs):
         params = self._params.copy()
         spike_retriver_kwargs = params.pop("spike_retriver_kwargs")
 
         spike_retriever = SpikeRetriever(
-            we.recording,
-            we.sorting,
-            extremum_channel_inds=extremum_channel_inds,
-            **spike_retriver_kwargs
+            we.recording, we.sorting, extremum_channel_inds=extremum_channel_inds, **spike_retriver_kwargs
         )
         spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs)
 
@@ -116,6 +116,7 @@ def get_extension_function():
 WaveformExtractor.register_extension(SpikeLocationsCalculator)
 
 
+
 def compute_spike_locations(
     waveform_extractor,
     load_if_exists=False,
@@ -129,7 +130,6 @@ def compute_spike_locations(
         radius_um=50,
         peaks_sign="neg",
     ),
-
     method="center_of_mass",
     method_kwargs={},
     outputs="concatenated",

From 6887c97bf81e502f8daee5eb60a049d2fb387c73 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 25 Oct 2023 09:50:01 +0200
Subject: [PATCH 19/29] oups

---
src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/postprocessing/spike_locations.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9b61ec0dab..a6dabf77b5 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -215,7 +215,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # handle amplitude for i, peak in enumerate(local_peaks): - local_peaks["amplitude"][i] = traces[local_peaks["sample_index"], local_peaks[i]["channel_index"]] + local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]] return (local_peaks,) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 6f8a8aabcb..0e471444d8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -31,7 +31,7 @@ def _set_params( ms_before=0.5, ms_after=0.5, spike_retriver_kwargs=dict( - channel_from_template=False, + channel_from_template=True, radius_um=50, peaks_sign="neg", ), @@ -125,7 +125,7 @@ def compute_spike_locations( ms_before=0.5, ms_after=0.5, spike_retriver_kwargs=dict( - channel_from_template=False, + channel_from_template=True, radius_um=50, peaks_sign="neg", ), From e238608afde4b3f06e6dac8276fb8c398b1beeab Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 10:39:39 +0200 Subject: [PATCH 20/29] less strict on amplitude for spikeretreiver tests --- src/spikeinterface/core/tests/test_node_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 4b86c538a9..9b65eba726 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -105,7 +105,8 @@ def test_run_node_pipeline(): AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), ] step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) + if loop ==0: + assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) # 3 nodes two have outputs ms_before = 0.5 @@ -133,7 +134,6 @@ def test_run_node_pipeline(): # gather memory mode output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) num_peaks = peaks.shape[0] num_channels = recording.get_num_channels() From 59f0473e71fb7b1ea029006021c55adf34b7b69f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Oct 2023 08:42:49 +0000 Subject: [PATCH 21/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_node_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 9b65eba726..e5a6dd055c 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -105,7 +105,7 @@ def test_run_node_pipeline(): AmplitudeExtractionNode(recording, 
parents=[peak_source], param0=6.6), ] step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - if loop ==0: + if loop == 0: assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) # 3 nodes two have outputs From 0699434f6573f744dddd4a4ffe3c6b0b2f7e67d8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Oct 2023 17:50:25 +0200 Subject: [PATCH 22/29] Expose eps at the function level and clarify options --- src/spikeinterface/preprocessing/whiten.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index cae7fe6334..54f6e0e903 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -29,6 +29,10 @@ class WhitenRecording(BasePreprocessor): Apply a scaling factor to fit the integer range. This is used when the dtype is an integer, so that the output is scaled. For example, a value of `int_scale=200` will scale the traces value to a standard deviation of 200. + eps : float, default 1e-8 + Small epsilon to regularize SVD. + If None, eps is estimated from the data. If the data is float type and scaled down to very small values, + then the eps is automatically set to a small fraction of the median of the squared data. W : 2d np.array Pre-computed whitening matrix, by default None M : 1d np.array or None @@ -52,6 +56,7 @@ def __init__( mode="global", radius_um=100.0, int_scale=None, + eps=1e-8, W=None, M=None, **random_chunk_kwargs, @@ -67,7 +72,9 @@ def __init__( if M is not None: M = np.asarray(M) else: - W, M = compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um) + W, M = compute_whitening_matrix( + recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps + ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -120,7 +127,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None): +def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8): """ Compute whitening matrix @@ -173,8 +180,7 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r # whitening. We therefore check to see if the data is float # type and we estimate a more reasonable eps in the case # where the data is on a scale less than 1. 
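+    # eps=None also triggers the data-driven estimate below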
-    eps = 1e-8  # the default
-    if data.dtype.kind == "f":
+    if data.dtype.kind == "f" or eps is None:
         median_data_sqr = np.median(data**2)  # use the square because cov (and hence S) scales as the square
         if median_data_sqr < 1 and median_data_sqr > 0:
             eps = max(1e-16, median_data_sqr * 1e-3)  # use a small fraction of the median of the squared data

From 3ed9e5f0a56cac3b8005c2d7fe5e78ead4078635 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 25 Oct 2023 21:11:55 +0200
Subject: [PATCH 23/29] oops

---
 src/spikeinterface/postprocessing/spike_locations.py | 6 +++---
 .../postprocessing/tests/test_spike_locations.py     | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index 0e471444d8..ccf321ba80 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -33,7 +33,7 @@ def _set_params(
         spike_retriver_kwargs=dict(
             channel_from_template=True,
             radius_um=50,
-            peaks_sign="neg",
+            peak_sign="neg",
         ),
         method="center_of_mass",
         method_kwargs={},
@@ -127,7 +127,7 @@ def compute_spike_locations(
     spike_retriver_kwargs=dict(
         channel_from_template=True,
         radius_um=50,
-        peaks_sign="neg",
+        peak_sign="neg",
    ),
     method="center_of_mass",
     method_kwargs={},
@@ -156,7 +156,7 @@ def compute_spike_locations(
         channel_from_template = False is slower but more accurate
         * radius_um: float, default 50
         In case channel_from_template=False, this is the radius to get the true peak.
-        * peaks_sign="neg"
+        * peak_sign="neg"
         In case channel_from_template=False, this is the peak sign.
     method : str
         'center_of_mass' / 'monopolar_triangulation' / 'grid_convolution'

diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py
index ab2345b1f5..89b015f1da 100644
--- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py
+++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py
@@ -10,8 +10,8 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes
     extension_class = SpikeLocationsCalculator
     extension_data_names = ["spike_locations"]
     extension_function_kwargs_list = [
-        dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=True),
-        dict(method="center_of_mass", chunk_size=10000, n_jobs=1, channel_from_template=False),
+        dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)),
+        dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)),
         dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"),
         dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
         dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),

From 8a987b87d9b5e6fad4d9e1e03036898b598c8939 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 25 Oct 2023 19:12:17 +0000
Subject: [PATCH 24/29] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../postprocessing/tests/test_spike_locations.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py
index 89b015f1da..d047a2f67e 100644
--- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py
+++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py
@@ -10,8 +10,12 @@ class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Tes
     extension_class = SpikeLocationsCalculator
     extension_data_names = ["spike_locations"]
     extension_function_kwargs_list = [
-        dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)),
-        dict(method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)),
+        dict(
+            method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True)
+        ),
+        dict(
+            method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False)
+        ),
         dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"),
         dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),
         dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"),

From 4a6a6b91d4f8cd6c3e48ee719e58d81b8c9735e8 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 26 Oct 2023 11:58:02 +0200
Subject: [PATCH 25/29] Use eps=None by default

---
 src/spikeinterface/preprocessing/whiten.py | 39 ++++++++++++----------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py
index 54f6e0e903..ec3a3e91a9 100644
--- a/src/spikeinterface/preprocessing/whiten.py
+++ b/src/spikeinterface/preprocessing/whiten.py
@@ -15,27 +15,27 @@ class WhitenRecording(BasePreprocessor):
     ----------
     recording: RecordingExtractor
         The recording extractor to be whitened.
-    dtype: None or dtype
+    dtype: None or dtype, default: None
         If None the the parent dtype is kept.
         For integer dtype a int_scale must be also given.
-    mode: 'global' / 'local'
+    mode: 'global' / 'local', default: 'global'
         'global' use the entire covariance matrix to compute the W matrix
         'local' use local covariance (by radius) to compute the W matrix
-    radius_um: None or float
+    radius_um: None or float, default: None
         Used for mode = 'local' to get the neighborhood
-    apply_mean: bool
+    apply_mean: bool, default: False
         Substract or not the mean matrix M before the dot product with W.
-    int_scale : None or float
+    int_scale : None or float, default: None
         Apply a scaling factor to fit the integer range.
         This is used when the dtype is an integer, so that the output is scaled.
         For example, a value of `int_scale=200` will scale the traces value to a standard deviation of 200.
-    eps : float, default 1e-8
+    eps : float or None, default: None
         Small epsilon to regularize SVD.
-        If None, eps is estimated from the data. If the data is float type and scaled down to very small values,
-        then the eps is automatically set to a small fraction of the median of the squared data.
-    W : 2d np.array
-        Pre-computed whitening matrix, by default None
-    M : 1d np.array or None
+        If None, eps is default to 1e-8. If the data is float type and scaled down to very small values,
+        then the eps is automatically set to a small fraction (1e-3) of the median of the squared data.
+    W : 2d np.array, default: None
+        Pre-computed whitening matrix
+    M : 1d np.array or None, default: None
         Pre-computed means.
        M can be None when previously computed with apply_mean=False
     **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function
@@ -56,7 +56,7 @@ def __init__(
         mode="global",
         radius_um=100.0,
         int_scale=None,
-        eps=1e-8,
+        eps=None,
         W=None,
         M=None,
         **random_chunk_kwargs,
@@ -127,7 +127,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):

 whiten = define_function_from_class(source_class=WhitenRecording, name="whiten")


-def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=1e-8):
+def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None):
     """
     Compute whitening matrix
@@ -145,10 +145,11 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
         Keyword arguments for get_random_data_chunks()
     apply_mean : bool
         If True, the mean is removed prior to computing the covariance
-    radius_um : float, optional
-        Used for mode = 'local' to get the neighborhood, by default None
-    eps : float, optional
-        Small epsilon to regularize SVD, by default 1e-8
+    radius_um : float, default: None
+        Used for mode = 'local' to get the neighborhood
+    eps : float, default: None
+        Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled
+        down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data.

     Returns
     -------
@@ -180,7 +181,9 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
     # whitening. We therefore check to see if the data is float
     # type and we estimate a more reasonable eps in the case
     # where the data is on a scale less than 1.
-    if data.dtype.kind == "f" or eps is None:
+    if eps is None:
+        eps = 1e-8
+    if data.dtype.kind == "f":
         median_data_sqr = np.median(data**2)  # use the square because cov (and hence S) scales as the square
         if median_data_sqr < 1 and median_data_sqr > 0:
             eps = max(1e-16, median_data_sqr * 1e-3)  # use a small fraction of the median of the squared data

From 4e843cb09ac9a9217af0f2a9476514609400d789 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 26 Oct 2023 12:01:19 +0200
Subject: [PATCH 26/29] Update src/spikeinterface/sortingcomponents/tools.py

---
 src/spikeinterface/sortingcomponents/tools.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 576732baa2..cd9226d5e8 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -19,9 +19,8 @@ def make_multi_method_doc(methods, ident=" "):

 def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5):
-    # TODO for Pierre: this function is really unefficient because it runa full pipeline only for a few
-    # spikes, which leans that traces are entirally computed!!!!!
-    # Please find a better way
+    # TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few
+    # spikes, which means that all traces need to be accessed! Please find a better way
     nb_peaks = min(len(peaks), nb_peaks)
     idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False))
     peak_retriever = PeakRetriever(recording, peaks[idx])

From 1af5722a356c733d9afabe69bab354362273a4f8 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 26 Oct 2023 12:03:12 +0200
Subject: [PATCH 27/29] Update src/spikeinterface/postprocessing/spike_locations.py

---
 src/spikeinterface/postprocessing/spike_locations.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index ccf321ba80..28eed131cd 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -148,8 +148,8 @@ def compute_spike_locations(
     ms_after : float
         The right window, after a peak, in milliseconds.
     spike_retriver_kwargs: dict
-        A dict that contains the behavior for getting the maximum channel for each spike.
-        This contain dict contains:
+        A dictionary to control the behavior for getting the maximum channel for each spike.
+        This dictionary contains:
         * channel_from_template: bool, default True
         For each spike is the maximum channel computed from template or re estimated at every spikes.
         channel_from_template = True is old behavior but less acurate

From d93ba0fe48e3395fc940d0a0f05e483631dc5651 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 26 Oct 2023 12:19:42 +0200
Subject: [PATCH 28/29] Update src/spikeinterface/preprocessing/whiten.py

---
 src/spikeinterface/preprocessing/whiten.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py
index ec3a3e91a9..ac80f58182 100644
--- a/src/spikeinterface/preprocessing/whiten.py
+++ b/src/spikeinterface/preprocessing/whiten.py
@@ -31,7 +31,7 @@ class WhitenRecording(BasePreprocessor):
         For example, a value of `int_scale=200` will scale the traces value to a standard deviation of 200.
     eps : float or None, default: None
         Small epsilon to regularize SVD.
-        If None, eps is default to 1e-8. If the data is float type and scaled down to very small values,
+        If None, eps is defaulted to 1e-8. If the data is float type and scaled down to very small values,
         then the eps is automatically set to a small fraction (1e-3) of the median of the squared data.
     W : 2d np.array, default: None
         Pre-computed whitening matrix

From 3f82c5968c471cc90ec05f089ebb43bd777a3738 Mon Sep 17 00:00:00 2001
From: Zach McKenzie <92116279+zm711@users.noreply.github.com>
Date: Thu, 26 Oct 2023 06:27:56 -0400
Subject: [PATCH 29/29] parallelization if > cutoff of units

Co-authored-by: Alessio Buccino
---
 src/spikeinterface/postprocessing/isi.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py
index 57e4bdc5b2..e98e64f753 100644
--- a/src/spikeinterface/postprocessing/isi.py
+++ b/src/spikeinterface/postprocessing/isi.py
@@ -205,6 +205,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float

 def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins):
     n_units = ISIs.shape[0]

-    for i in range(n_units):
+    units_loop = numba.prange(n_units) if n_units > 300 else range(n_units)
+    for i in units_loop:
         spike_train = spike_trains[spike_clusters == i]
         ISIs[i] += np.histogram(np.diff(spike_train), bins=bins)[0]
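A closing note on the node_pipeline.py amplitude fix at the top of this excerpt: the buggy line fancy-indexed `traces` with the full `local_peaks["sample_index"]` array inside a per-peak loop, yielding one value per peak where a single scalar slot was expected. A minimal reproduction of the corrected indexing (the dtype and values here are illustrative, not the library's exact peak dtype):

import numpy as np

# Illustrative structured dtype, close to (but not exactly) the library's peak dtype
peak_dtype = [("sample_index", "int64"), ("channel_index", "int64"), ("amplitude", "float64")]
local_peaks = np.zeros(2, dtype=peak_dtype)
local_peaks["sample_index"] = [3, 7]
local_peaks["channel_index"] = [0, 1]

traces = np.arange(20.0).reshape(10, 2)  # (num_samples, num_channels)

for i, peak in enumerate(local_peaks):
    # Scalar row/column indices pick exactly one amplitude per peak; the old form,
    # traces[local_peaks["sample_index"], ...], indexed with the whole array and
    # failed when numpy tried to fit an array into a single element.
    local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]]

print(local_peaks["amplitude"])  # [ 6. 15.]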
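The eps handling that patches 22, 25, and 28 converge on reads most clearly as a standalone rule: fall back to 1e-8 when the caller passes None, then, for float traces whose squared median sits below 1, derive eps from the data so a fixed regularizer does not swamp a tiny covariance. A sketch of that heuristic distilled from the diffs (the function name is ours, not the library's):

import numpy as np

def resolve_whitening_eps(data, eps=None):
    # Historical default when the caller provides nothing
    if eps is None:
        eps = 1e-8
    # Float data scaled to very small values: cov (and hence S) scales as the
    # square, so derive eps from the median of the squared data instead
    if data.dtype.kind == "f":
        median_data_sqr = np.median(data**2)
        if 0 < median_data_sqr < 1:
            eps = max(1e-16, median_data_sqr * 1e-3)
    return eps

Note that, as in the merged code, the float branch can override an explicitly passed eps; only integer dtypes are guaranteed to keep the caller's value.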
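With patch 23's peak_sign rename and patch 27's docstring cleanup in place, all retriever options travel in the single `spike_retriver_kwargs` dict (the parameter keeps the repository's `retriver` spelling). A usage sketch, assuming `we` is a WaveformExtractor computed beforehand:

from spikeinterface.postprocessing import compute_spike_locations

spike_locations = compute_spike_locations(
    we,  # an existing WaveformExtractor (assumed)
    ms_before=0.5,
    ms_after=0.5,
    spike_retriver_kwargs=dict(
        channel_from_template=False,  # re-estimate the max channel per spike: slower but more accurate
        radius_um=50,                 # radius around the template extremum channel to search
        peak_sign="neg",              # polarity used to find the true peak
    ),
    method="center_of_mass",
    outputs="concatenated",
    chunk_size=10000,  # forwarded as job_kwargs, as in the test suite
    n_jobs=1,
)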
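Finally, patch 29's trick of choosing `numba.prange` or plain `range` at run time, so that thread-scheduling overhead is only paid above a unit-count cutoff, generalizes beyond ISI histograms. A self-contained sketch of the same idiom (the function and the 300 cutoff mirror the patch but are otherwise arbitrary; whether the aliased prange actually parallelizes is left to Numba's loop analysis):

import numba
import numpy as np

@numba.jit(nopython=True, parallel=True)
def row_sums(matrix):
    n_rows = matrix.shape[0]
    out = np.zeros(n_rows, dtype=np.float64)
    # Pay the parallel-loop overhead only when there is enough work
    rows_loop = numba.prange(n_rows) if n_rows > 300 else range(n_rows)
    for i in rows_loop:
        out[i] = matrix[i, :].sum()
    return out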