From 688afa7c07396a6ad57203bb89649dd294f4c511 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 30 Nov 2023 10:54:05 +0100
Subject: [PATCH 1/2] Strict inegality for radius_um

---
 src/spikeinterface/core/sparsity.py                    | 2 +-
 .../sortingcomponents/clustering/clustering_tools.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 3b8b6025ca..893da59d74 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -292,7 +292,7 @@ def from_radius(cls, we, radius_um, peak_sign="neg"):
         best_chan = get_template_extremum_channel(we, peak_sign=peak_sign, outputs="index")
         for unit_ind, unit_id in enumerate(we.unit_ids):
             chan_ind = best_chan[unit_id]
-            (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
+            (chan_inds,) = np.nonzero(distances[chan_ind, :] < radius_um)
             mask[unit_ind, chan_inds] = True
         return cls(mask, we.unit_ids, we.channel_ids)

diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
index 629b0b13ac..050ba10efb 100644
--- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -368,7 +368,7 @@ def auto_clean_clustering(
         # we use
         (radius_chans,) = np.nonzero(
-            (channel_distances[main_chan0, :] <= radius_um) | (channel_distances[main_chan1, :] <= radius_um)
+            (channel_distances[main_chan0, :] < radius_um) | (channel_distances[main_chan1, :] < radius_um)
         )
         if radius_chans.size < (intersect_chans.size * ratio_num_channel_intersect):
             # ~ print('WARNING INTERSECT')

From c4994617b2b2e88a0897f38124905a0b32a88b8a Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 1 Dec 2023 09:52:43 +0100
Subject: [PATCH 2/2] Radius_um now <= everywhere

---
 src/spikeinterface/core/node_pipeline.py               | 4 ++--
 src/spikeinterface/core/sparsity.py                    | 2 +-
 .../postprocessing/unit_localization.py                | 2 +-
 src/spikeinterface/preprocessing/whiten.py             | 2 +-
 .../clustering/clustering_tools.py                     | 2 +-
 .../sortingcomponents/clustering/merge.py              | 4 ++--
 .../sortingcomponents/features_from_peaks.py           | 16 ++++++++--------
 .../sortingcomponents/peak_detection.py                | 6 +++---
 8 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index a00df98e05..fd8dbd35b6 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -175,7 +175,7 @@ def __init__(
         if not channel_from_template:
             channel_distance = get_channel_distances(recording)
-            self.neighbours_mask = channel_distance < radius_um
+            self.neighbours_mask = channel_distance <= radius_um
             self.peak_sign = peak_sign

         # precompute segment slice
@@ -367,7 +367,7 @@ def __init__(
         self.radius_um = radius_um
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))

     def get_trace_margin(self):

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 893da59d74..3b8b6025ca 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -292,7 +292,7 @@ def from_radius(cls, we, radius_um, peak_sign="neg"):
         best_chan = get_template_extremum_channel(we, peak_sign=peak_sign, outputs="index")
         for unit_ind, unit_id in enumerate(we.unit_ids):
             chan_ind = best_chan[unit_id]
-            (chan_inds,) = np.nonzero(distances[chan_ind, :] < radius_um)
+            (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
             mask[unit_ind, chan_inds] = True
         return cls(mask, we.unit_ids, we.channel_ids)

diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py
index f665bac8d6..2ac841c148 100644
--- a/src/spikeinterface/postprocessing/unit_localization.py
+++ b/src/spikeinterface/postprocessing/unit_localization.py
@@ -597,7 +597,7 @@ def get_grid_convolution_templates_and_weights(
     # mask to get nearest template given a channel
     dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions)
-    nearest_template_mask = dist < radius_um
+    nearest_template_mask = dist <= radius_um

     weights = np.zeros((len(sigma_um), len(contact_locations), nb_templates), dtype=np.float32)
     for count, sigma in enumerate(sigma_um):

diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py
index 3bea9b91bb..766229b62a 100644
--- a/src/spikeinterface/preprocessing/whiten.py
+++ b/src/spikeinterface/preprocessing/whiten.py
@@ -197,7 +197,7 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
         distances = get_channel_distances(recording)
         W = np.zeros((n, n), dtype="float64")
         for c in range(n):
-            (inds,) = np.nonzero(distances[c, :] < radius_um)
+            (inds,) = np.nonzero(distances[c, :] <= radius_um)
             cov_local = cov[inds, :][:, inds]
             U, S, Ut = np.linalg.svd(cov_local, full_matrices=True)
             W_local = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut

diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
index 050ba10efb..629b0b13ac 100644
--- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -368,7 +368,7 @@ def auto_clean_clustering(
         # we use
         (radius_chans,) = np.nonzero(
-            (channel_distances[main_chan0, :] < radius_um) | (channel_distances[main_chan1, :] < radius_um)
+            (channel_distances[main_chan0, :] <= radius_um) | (channel_distances[main_chan1, :] <= radius_um)
        )
         if radius_chans.size < (intersect_chans.size * ratio_num_channel_intersect):
             # ~ print('WARNING INTERSECT')

diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index 24ec923f06..285a9ff2f2 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -291,7 +291,7 @@ def find_merge_pairs(
     template_locs = channel_locs[max_chans, :]
     template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean")

-    pair_mask = pair_mask & (template_dist < radius_um)
+    pair_mask = pair_mask & (template_dist <= radius_um)
     indices0, indices1 = np.nonzero(pair_mask)

     n_jobs = job_kwargs["n_jobs"]
@@ -337,7 +337,7 @@ def find_merge_pairs(
             pair_shift[ind0, ind1] = shift
             pair_values[ind0, ind1] = merge_value

-    pair_mask = pair_mask & (template_dist < radius_um)
+    pair_mask = pair_mask & (template_dist <= radius_um)
     indices0, indices1 = np.nonzero(pair_mask)

     return labels_set, pair_mask, pair_shift, pair_values

diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py
index f7f020d153..4006939b22 100644
--- a/src/spikeinterface/sortingcomponents/features_from_peaks.py
+++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -119,7 +119,7 @@ def __init__(

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.all_channels = all_channels
         self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()
@@ -157,7 +157,7 @@ def __init__(

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()
@@ -202,7 +202,7 @@ def __init__(
         self.sigmoid = sigmoid
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self.radius_um = radius_um
         self.sparse = sparse
         self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
@@ -253,7 +253,7 @@ def __init__(

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self.projections = projections
         self.min_values = min_values
@@ -288,7 +288,7 @@ def __init__(self, recording, name="std_ptp_feature", return_output=True, parent

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))
@@ -313,7 +313,7 @@ def __init__(self, recording, name="global_ptp_feature", return_output=True, par

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))
@@ -338,7 +338,7 @@ def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, p

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))
@@ -365,7 +365,7 @@ def __init__(self, recording, name="energy_feature", return_output=True, parents

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um

         self._kwargs.update(dict(radius_um=radius_um))

diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index e66c8be874..22438c0934 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -542,7 +542,7 @@ def check_params(
         )

         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < radius_um
+        neighbours_mask = channel_distance <= radius_um
         return args + (neighbours_mask,)

     @classmethod
@@ -624,7 +624,7 @@ def check_params(
         neighbour_indices_by_chan = []
         num_channels = recording.get_num_channels()
         for chan in range(num_channels):
-            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < radius_um)[0])
+            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] <= radius_um)[0])
         max_neighbs = np.max([len(neigh) for neigh in neighbour_indices_by_chan])
         neighbours_idxs = num_channels * np.ones((num_channels, max_neighbs), dtype=int)
         for i, neigh in enumerate(neighbour_indices_by_chan):
@@ -856,7 +856,7 @@ def check_params(
         abs_threholds = noise_levels * detect_threshold
         exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < radius_um
+        neighbours_mask = channel_distance <= radius_um

         executor = OpenCLDetectPeakExecutor(abs_threholds, exclude_sweep_size, neighbours_mask, peak_sign)
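
Not part of either patch: a minimal standalone sketch, using a hypothetical three-channel layout and radius_um = 50.0, illustrating the boundary behaviour the two commits toggle. With the strict < of patch 1, a channel lying exactly at radius_um from the reference channel is dropped from the neighbourhood; the inclusive <= restored by patch 2 keeps it, which is the convention applied everywhere above.

import numpy as np

# Hypothetical probe: three channels spaced 50 um apart along the x axis.
channel_locations = np.array([[0.0, 0.0], [50.0, 0.0], [100.0, 0.0]])
# Pairwise channel distances (what get_channel_distances returns conceptually).
distances = np.linalg.norm(channel_locations[:, None] - channel_locations[None, :], axis=2)

radius_um = 50.0
# Strict inequality (patch 1): channel 1, exactly 50 um away from channel 0, is excluded.
print(np.nonzero(distances[0, :] < radius_um)[0])   # [0]
# Inclusive inequality (patch 2): channel 1 is kept in the neighbours mask.
print(np.nonzero(distances[0, :] <= radius_um)[0])  # [0 1]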