Skip to content

Commit

Permalink
Merge branch 'main' into tdc_2
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia authored Oct 10, 2023
2 parents 64d507c + d9fd3d2 commit 7dafef8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

# recording_f = whiten(recording_f, dtype="float32")
recording_f = zscore(recording_f, dtype="float32")
noise_levels = np.ones(num_channels, dtype=np.float32)

## Then, we are detecting peaks with a locally_exclusive method
detection_params = params["detection"].copy()
Expand All @@ -87,7 +88,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

noise_levels = np.ones(num_channels, dtype=np.float32)
selection_params.update({"noise_levels": noise_levels})
selected_peaks = select_peaks(
peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params
Expand All @@ -107,6 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params.update(dict(shared_memory=params["shared_memory"]))
clustering_params["job_kwargs"] = job_kwargs
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
clustering_params.update({"noise_levels": noise_levels})

labels, peak_labels = find_cluster_from_peaks(
recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from spikeinterface.core import extract_waveforms
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractDenseWaveforms,
ExtractSparseWaveforms,
PeakRetriever,
)


class RandomProjectionClustering:
Expand All @@ -43,7 +48,8 @@ class RandomProjectionClustering:
"ms_before": 1,
"ms_after": 1,
"random_seed": 42,
"smoothing_kwargs": {"window_length_ms": 1},
"noise_levels": None,
"smoothing_kwargs": {"window_length_ms": 0.25},
"shared_memory": True,
"tmp_folder": None,
"job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
Expand Down Expand Up @@ -72,7 +78,10 @@ def main_function(cls, recording, peaks, params):
num_samples = nbefore + nafter
num_chans = recording.get_num_channels()

noise_levels = get_noise_levels(recording, return_scaled=False)
if d["noise_levels"] is None:
noise_levels = get_noise_levels(recording, return_scaled=False)
else:
noise_levels = d["noise_levels"]

np.random.seed(d["random_seed"])

Expand All @@ -82,10 +91,16 @@ def main_function(cls, recording, peaks, params):
else:
tmp_folder = Path(params["tmp_folder"]).absolute()

### Then we extract the SVD features
tmp_folder.mkdir(parents=True, exist_ok=True)

node0 = PeakRetriever(recording, peaks)
node1 = ExtractDenseWaveforms(
recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
node1 = ExtractSparseWaveforms(
recording,
parents=[node0],
return_output=False,
ms_before=params["ms_before"],
ms_after=params["ms_after"],
radius_um=params["radius_um"],
)

node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
Expand Down Expand Up @@ -123,6 +138,8 @@ def sigmoid(x, L, x0, k, b):
return_output=True,
projections=projections,
radius_um=params["radius_um"],
sigmoid=None,
sparse=True,
)

pipeline_nodes = [node0, node1, node2, node3]
Expand All @@ -136,6 +153,18 @@ def sigmoid(x, L, x0, k, b):
clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
peak_labels = clustering[0]

# peak_labels = -1 * np.ones(len(peaks), dtype=int)
# nb_clusters = 0
# for c in np.unique(peaks['channel_index']):
# mask = peaks['channel_index'] == c
# clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs'])
# local_labels = clustering[0]
# valid_clusters = local_labels > -1
# if np.sum(valid_clusters) > 0:
# local_labels[valid_clusters] += nb_clusters
# peak_labels[mask] = local_labels
# nb_clusters += len(np.unique(local_labels[valid_clusters]))

labels = np.unique(peak_labels)
labels = labels[labels >= 0]

Expand Down Expand Up @@ -174,15 +203,6 @@ def sigmoid(x, L, x0, k, b):
if verbose:
print("We found %d raw clusters, starting to clean with matching..." % (len(labels)))

# create a tmp folder
if params["tmp_folder"] is None:
name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
tmp_folder = get_global_tmp_folder() / name
else:
tmp_folder = Path(params["tmp_folder"])

tmp_folder.mkdir(parents=True, exist_ok=True)

sorting_folder = tmp_folder / "sorting"
unit_ids = np.arange(len(np.unique(spikes["unit_index"])))
sorting = NumpySorting(spikes, fs, unit_ids=unit_ids)
Expand Down
9 changes: 7 additions & 2 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
projections=None,
sigmoid=None,
radius_um=None,
sparse=True,
):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

Expand All @@ -195,7 +196,8 @@ def __init__(
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.radius_um = radius_um
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
self.sparse = sparse
self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
self._dtype = recording.get_dtype()

def get_dtype(self):
Expand All @@ -213,7 +215,10 @@ def compute(self, traces, peaks, waveforms):
(idx,) = np.nonzero(peaks["channel_index"] == main_chan)
(chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
local_projections = self.projections[chan_inds, :]
wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)
if self.sparse:
wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1)
else:
wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)

if self.sigmoid is not None:
wf_ptp *= self._sigmoid(wf_ptp)
Expand Down

0 comments on commit 7dafef8

Please sign in to comment.