Skip to content

Commit

Permalink
Extracting sparse waveforms
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Oct 9, 2023
1 parent f68da6a commit ed44aaf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from spikeinterface.core import extract_waveforms
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, ExtractSparseWaveforms, PeakRetriever


class RandomProjectionClustering:
Expand Down Expand Up @@ -90,8 +90,9 @@ def main_function(cls, recording, peaks, params):
tmp_folder.mkdir(parents=True, exist_ok=True)

node0 = PeakRetriever(recording, peaks)
node1 = ExtractDenseWaveforms(
recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
node1 = ExtractSparseWaveforms(
recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"],
radius_um=params['radius_um']
)

node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
Expand Down Expand Up @@ -129,6 +130,8 @@ def sigmoid(x, L, x0, k, b):
return_output=True,
projections=projections,
radius_um=params["radius_um"],
sigmoid=None,
sparse=True
)

pipeline_nodes = [node0, node1, node2, node3]
Expand All @@ -142,6 +145,18 @@ def sigmoid(x, L, x0, k, b):
clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
peak_labels = clustering[0]

# peak_labels = -1 * np.ones(len(peaks), dtype=int)
# nb_clusters = 0
# for c in np.unique(peaks['channel_index']):
# mask = peaks['channel_index'] == c
# clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs'])
# local_labels = clustering[0]
# valid_clusters = local_labels > -1
# if np.sum(valid_clusters) > 0:
# local_labels[valid_clusters] += nb_clusters
# peak_labels[mask] = local_labels
# nb_clusters += len(np.unique(local_labels[valid_clusters]))

labels = np.unique(peak_labels)
labels = labels[labels >= 0]

Expand Down
9 changes: 7 additions & 2 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
projections=None,
sigmoid=None,
radius_um=None,
sparse=True
):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

Expand All @@ -195,7 +196,8 @@ def __init__(
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.radius_um = radius_um
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
self.sparse = sparse
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
self._dtype = recording.get_dtype()

def get_dtype(self):
Expand All @@ -213,7 +215,10 @@ def compute(self, traces, peaks, waveforms):
(idx,) = np.nonzero(peaks["channel_index"] == main_chan)
(chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
local_projections = self.projections[chan_inds, :]
wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)
if self.sparse:
wf_ptp = np.ptp(waveforms[idx][:, :, :len(chan_inds)], axis=1)
else:
wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)

if self.sigmoid is not None:
wf_ptp *= self._sigmoid(wf_ptp)
Expand Down

0 comments on commit ed44aaf

Please sign in to comment.