Skip to content

Commit

Permalink
Merge pull request #2277 from yger/sparsity_inconsistencies
Browse files Browse the repository at this point in the history
Strict inegality for sparsity with radius_um
  • Loading branch information
samuelgarcia authored Dec 1, 2023
2 parents 6a38d06 + ca932c8 commit 7f204c6
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(

if not channel_from_template:
channel_distance = get_channel_distances(recording)
self.neighbours_mask = channel_distance < radius_um
self.neighbours_mask = channel_distance <= radius_um
self.peak_sign = peak_sign

# precompute segment slice
Expand Down Expand Up @@ -367,7 +367,7 @@ def __init__(
self.radius_um = radius_um
self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um
self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))

def get_trace_margin(self):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/unit_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def get_grid_convolution_templates_and_weights(

# mask to get nearest template given a channel
dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions)
nearest_template_mask = dist < radius_um
nearest_template_mask = dist <= radius_um

weights = np.zeros((len(sigma_um), len(contact_locations), nb_templates), dtype=np.float32)
for count, sigma in enumerate(sigma_um):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r
distances = get_channel_distances(recording)
W = np.zeros((n, n), dtype="float64")
for c in range(n):
(inds,) = np.nonzero(distances[c, :] < radius_um)
(inds,) = np.nonzero(distances[c, :] <= radius_um)
cov_local = cov[inds, :][:, inds]
U, S, Ut = np.linalg.svd(cov_local, full_matrices=True)
W_local = (U @ np.diag(1 / np.sqrt(S + eps))) @ Ut
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def find_merge_pairs(
template_locs = channel_locs[max_chans, :]
template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean")

pair_mask = pair_mask & (template_dist < radius_um)
pair_mask = pair_mask & (template_dist <= radius_um)
indices0, indices1 = np.nonzero(pair_mask)

n_jobs = job_kwargs["n_jobs"]
Expand Down Expand Up @@ -337,7 +337,7 @@ def find_merge_pairs(
pair_shift[ind0, ind1] = shift
pair_values[ind0, ind1] = merge_value

pair_mask = pair_mask & (template_dist < radius_um)
pair_mask = pair_mask & (template_dist <= radius_um)
indices0, indices1 = np.nonzero(pair_mask)

return labels_set, pair_mask, pair_shift, pair_values
Expand Down
16 changes: 8 additions & 8 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(

self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um
self.all_channels = all_channels
self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
self._dtype = recording.get_dtype()
Expand Down Expand Up @@ -157,7 +157,7 @@ def __init__(

self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um

self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
self._dtype = recording.get_dtype()
Expand Down Expand Up @@ -202,7 +202,7 @@ def __init__(
self.sigmoid = sigmoid
self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um
self.radius_um = radius_um
self.sparse = sparse
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
Expand Down Expand Up @@ -253,7 +253,7 @@ def __init__(

self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um

self.projections = projections
self.min_values = min_values
Expand Down Expand Up @@ -288,7 +288,7 @@ def __init__(self, recording, name="std_ptp_feature", return_output=True, parent

self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um

self._kwargs.update(dict(radius_um=radius_um))

Expand All @@ -313,7 +313,7 @@ def __init__(self, recording, name="global_ptp_feature", return_output=True, par

self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um

self._kwargs.update(dict(radius_um=radius_um))

Expand All @@ -338,7 +338,7 @@ def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, p

self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um

self._kwargs.update(dict(radius_um=radius_um))

Expand All @@ -365,7 +365,7 @@ def __init__(self, recording, name="energy_feature", return_output=True, parents

self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um

self._kwargs.update(dict(radius_um=radius_um))

Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def check_params(
)

channel_distance = get_channel_distances(recording)
neighbours_mask = channel_distance < radius_um
neighbours_mask = channel_distance <= radius_um
return args + (neighbours_mask,)

@classmethod
Expand Down Expand Up @@ -624,7 +624,7 @@ def check_params(
neighbour_indices_by_chan = []
num_channels = recording.get_num_channels()
for chan in range(num_channels):
neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < radius_um)[0])
neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] <= radius_um)[0])
max_neighbs = np.max([len(neigh) for neigh in neighbour_indices_by_chan])
neighbours_idxs = num_channels * np.ones((num_channels, max_neighbs), dtype=int)
for i, neigh in enumerate(neighbour_indices_by_chan):
Expand Down Expand Up @@ -856,7 +856,7 @@ def check_params(
abs_threholds = noise_levels * detect_threshold
exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
channel_distance = get_channel_distances(recording)
neighbours_mask = channel_distance < radius_um
neighbours_mask = channel_distance <= radius_um

executor = OpenCLDetectPeakExecutor(abs_threholds, exclude_sweep_size, neighbours_mask, peak_sign)

Expand Down

0 comments on commit 7f204c6

Please sign in to comment.