Commit

Merge branch 'main' into fix-unit-id-matching
alejoe91 authored Sep 28, 2023
2 parents f1b7bfe + 427d7b5 commit 59955c3
Showing 6 changed files with 467 additions and 86 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -20,7 +20,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
    sorter_name = "spykingcircus2"

    _default_params = {
-        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
+        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
        "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
        "filtering": {"dtype": "float32"},
        "detection": {"peak_sign": "neg", "detect_threshold": 5},
@@ -151,7 +151,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
            matching_job_params["chunk_duration"] = "100ms"

        spikes = find_spikes_from_templates(
-            recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params
+            recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params
        )

        if verbose:
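For readers who want to try the new engine outside the sorter, here is a minimal sketch (not part of this commit) of calling "circus-omp-svd" directly on a toy recording. The toy_example helper, the waveform folder name and the numeric values are illustrative assumptions; the find_spikes_from_templates call and its method_kwargs keys mirror the code above.

from spikeinterface import extract_waveforms, get_noise_levels
from spikeinterface.extractors import toy_example
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

# Toy data and templates (any WaveformExtractor works here)
recording, sorting = toy_example(duration=10, num_units=5, seed=42)
we = extract_waveforms(recording, sorting, folder="toy_waveforms", ms_before=2, ms_after=2, overwrite=True)

method_kwargs = {
    "waveform_extractor": we,                                    # templates to match against
    "noise_levels": get_noise_levels(recording, return_scaled=False),
    "amplitudes": [0.95, 1.05],                                  # same bounds as in the cleaning step below
}

# job kwargs (e.g. chunk_duration) are forwarded to the chunked executor, as above
spikes = find_spikes_from_templates(
    recording, method="circus-omp-svd", method_kwargs=method_kwargs, chunk_duration="100ms", n_jobs=1
)
print(spikes.size, "spikes found; first sample indices:", spikes["sample_index"][:5])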
42 changes: 25 additions & 17 deletions src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -539,14 +539,14 @@ def remove_duplicates_via_matching(
    method_kwargs={},
    job_kwargs={},
    tmp_folder=None,
+    method="circus-omp-svd",
):
    from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
    from spikeinterface import get_noise_levels
    from spikeinterface.core import BinaryRecordingExtractor
    from spikeinterface.core import NumpySorting
    from spikeinterface.core import extract_waveforms
    from spikeinterface.core import get_global_tmp_folder
-    from spikeinterface.sortingcomponents.matching.circus import get_scipy_shape
    import string, random, shutil, os
    from pathlib import Path

@@ -591,19 +591,12 @@ def remove_duplicates_via_matching(

    chunk_size = duration + 3 * margin

-    dummy_filter = np.empty((num_chans, duration), dtype=np.float32)
-    dummy_traces = np.empty((num_chans, chunk_size), dtype=np.float32)
-
-    fshape, axes = get_scipy_shape(dummy_filter, dummy_traces, axes=1)
-
    method_kwargs.update(
        {
            "waveform_extractor": waveform_extractor,
            "noise_levels": noise_levels,
            "amplitudes": [0.95, 1.05],
            "omp_min_sps": 0.1,
-            "templates": None,
-            "overlaps": None,
        }
    )

@@ -618,16 +611,31 @@ def remove_duplicates_via_matching(

        method_kwargs.update({"ignored_ids": ignore_ids + [i]})
        spikes, computed = find_spikes_from_templates(
-            sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
+            sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
        )
-        method_kwargs.update(
-            {
-                "overlaps": computed["overlaps"],
-                "templates": computed["templates"],
-                "norms": computed["norms"],
-                "sparsities": computed["sparsities"],
-            }
-        )
+        if method == "circus-omp-svd":
+            method_kwargs.update(
+                {
+                    "overlaps": computed["overlaps"],
+                    "templates": computed["templates"],
+                    "norms": computed["norms"],
+                    "temporal": computed["temporal"],
+                    "spatial": computed["spatial"],
+                    "singular": computed["singular"],
+                    "units_overlaps": computed["units_overlaps"],
+                    "unit_overlaps_indices": computed["unit_overlaps_indices"],
+                    "sparsity_mask": computed["sparsity_mask"],
+                }
+            )
+        elif method == "circus-omp":
+            method_kwargs.update(
+                {
+                    "overlaps": computed["overlaps"],
+                    "templates": computed["templates"],
+                    "norms": computed["norms"],
+                    "sparsities": computed["sparsities"],
+                }
+            )
        valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging)
        if np.sum(valid) > 0:
            if np.sum(valid) == 1:
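The loop above re-uses the expensive intermediates returned by find_spikes_from_templates, so the template SVD is only computed once. A hedged sketch of that caching pattern (not from the commit; waveform_extractor, noise_levels and sub_recordings are assumed to already exist):

from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

# keys returned in `computed` by the "circus-omp-svd" engine, per the diff above
CACHED_KEYS = (
    "overlaps", "templates", "norms", "temporal", "spatial", "singular",
    "units_overlaps", "unit_overlaps_indices", "sparsity_mask",
)

def match_all(sub_recordings, waveform_extractor, noise_levels, job_kwargs):
    method_kwargs = {"waveform_extractor": waveform_extractor, "noise_levels": noise_levels}
    all_spikes = []
    for sub_recording in sub_recordings:
        spikes, computed = find_spikes_from_templates(
            sub_recording, method="circus-omp-svd", method_kwargs=method_kwargs,
            extra_outputs=True, **job_kwargs
        )
        # after the first pass the SVD factorization is already in method_kwargs
        method_kwargs.update({key: computed[key] for key in CACHED_KEYS})
        all_spikes.append(spikes)
    return all_spikes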
114 changes: 63 additions & 51 deletions src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -18,7 +18,9 @@
from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
from spikeinterface.core import NumpySorting
from spikeinterface.core import extract_waveforms
-from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature
+from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
+from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
+from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever


class RandomProjectionClustering:
@@ -34,17 +36,17 @@ class RandomProjectionClustering:
            "cluster_selection_method": "leaf",
        },
        "cleaning_kwargs": {},
+        "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100},
        "radius_um": 100,
-        "max_spikes_per_unit": 200,
        "selection_method": "closest_to_centroid",
-        "nb_projections": {"ptp": 8, "energy": 2},
-        "ms_before": 1.5,
-        "ms_after": 1.5,
+        "nb_projections": 10,
+        "ms_before": 1,
+        "ms_after": 1,
        "random_seed": 42,
-        "shared_memory": False,
-        "min_values": {"ptp": 0, "energy": 0},
+        "smoothing_kwargs": {"window_length_ms": 1},
+        "shared_memory": True,
        "tmp_folder": None,
-        "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True},
+        "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
    }

@classmethod
@@ -74,50 +76,60 @@ def main_function(cls, recording, peaks, params):

        np.random.seed(d["random_seed"])

-        features_params = {}
-        features_list = []
-
-        noise_snippets = None
-
-        for proj_type in ["ptp", "energy"]:
-            if d["nb_projections"][proj_type] > 0:
-                features_list += [f"random_projections_{proj_type}"]
-
-                if d["min_values"][proj_type] == "auto":
-                    if noise_snippets is None:
-                        num_segments = recording.get_num_segments()
-                        num_chunks = 3 * d["max_spikes_per_unit"] // num_segments
-                        noise_snippets = get_random_data_chunks(
-                            recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42
-                        )
-                        noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans)
-
-                    if proj_type == "energy":
-                        data = np.linalg.norm(noise_snippets, axis=1)
-                        min_values = np.median(data, axis=0)
-                    elif proj_type == "ptp":
-                        data = np.ptp(noise_snippets, axis=1)
-                        min_values = np.median(data, axis=0)
-                elif d["min_values"][proj_type] > 0:
-                    min_values = d["min_values"][proj_type]
-                else:
-                    min_values = None
-
-                projections = np.random.randn(num_chans, d["nb_projections"][proj_type])
-                features_params[f"random_projections_{proj_type}"] = {
-                    "radius_um": params["radius_um"],
-                    "projections": projections,
-                    "min_values": min_values,
-                }
-
-        features_data = compute_features_from_peaks(
-            recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"]
+        if params["tmp_folder"] is None:
+            name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
+            tmp_folder = get_global_tmp_folder() / name
+        else:
+            tmp_folder = Path(params["tmp_folder"]).absolute()
+
+        ### Then we extract the SVD features
+        node0 = PeakRetriever(recording, peaks)
+        node1 = ExtractDenseWaveforms(
+            recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
        )

-        if len(features_data) > 1:
-            hdbscan_data = np.hstack((features_data[0], features_data[1]))
-        else:
-            hdbscan_data = features_data[0]
+        node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
+
+        projections = np.random.randn(num_chans, d["nb_projections"])
+        projections -= projections.mean(0)
+        projections /= projections.std(0)
+
+        nbefore = int(params["ms_before"] * fs / 1000)
+        nafter = int(params["ms_after"] * fs / 1000)
+        nsamples = nbefore + nafter
+
+        import scipy
+
+        x = np.random.randn(100, nsamples, num_chans).astype(np.float32)
+        x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1)
+
+        ptps = np.ptp(x, axis=1)
+        a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000))
+        ydata = np.cumsum(a) / a.sum()
+        xdata = b[1:]
+
+        from scipy.optimize import curve_fit
+
+        def sigmoid(x, L, x0, k, b):
+            y = L / (1 + np.exp(-k * (x - x0))) + b
+            return y
+
+        p0 = [max(ydata), np.median(xdata), 1, min(ydata)]  # this is an mandatory initial guess
+        popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)
+
+        node3 = RandomProjectionsFeature(
+            recording,
+            parents=[node0, node2],
+            return_output=True,
+            projections=projections,
+            radius_um=params["radius_um"],
+        )
+
+        pipeline_nodes = [node0, node1, node2, node3]
+
+        hdbscan_data = run_node_pipeline(
+            recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features"
+        )

        import sklearn

@@ -132,7 +144,7 @@ def main_function(cls, recording, peaks, params):

        all_indices = np.arange(0, peak_labels.size)

-        max_spikes = params["max_spikes_per_unit"]
+        max_spikes = params["waveforms"]["max_spikes_per_unit"]
        selection_method = params["selection_method"]

        for unit_ind in labels:
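The refactor above replaces compute_features_from_peaks with a lazy node pipeline: peaks are retrieved, dense waveforms extracted, Savitzky-Golay smoothed, then reduced to random projections in a single pass over the recording. A hedged sketch of that pipeline in isolation (not from the commit; recording and peaks from detect_peaks are assumed to exist, and the numeric values are illustrative):

import numpy as np
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature

num_chans = recording.get_num_channels()
projections = np.random.randn(num_chans, 10)   # nb_projections = 10, as in the new defaults
projections -= projections.mean(0)
projections /= projections.std(0)

node0 = PeakRetriever(recording, peaks)
node1 = ExtractDenseWaveforms(recording, parents=[node0], return_output=False, ms_before=1, ms_after=1)
node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, window_length_ms=1)
node3 = RandomProjectionsFeature(
    recording, parents=[node0, node2], return_output=True, projections=projections, radius_um=100
)

# one chunked pass over the recording; only node3's output is returned
hdbscan_data = run_node_pipeline(
    recording, [node0, node1, node2, node3], {"n_jobs": 1}, job_name="extracting features"
)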
27 changes: 15 additions & 12 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -184,41 +184,44 @@ def __init__(
        return_output=True,
        parents=None,
        projections=None,
-        radius_um=150.0,
-        min_values=None,
+        sigmoid=None,
+        radius_um=None,
    ):
        PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

        self.projections = projections
-        self.radius_um = radius_um
-        self.min_values = min_values
-
+        self.sigmoid = sigmoid
        self.contact_locations = recording.get_channel_locations()
        self.channel_distance = get_channel_distances(recording)
        self.neighbours_mask = self.channel_distance < radius_um
-
-        self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values))
-
+        self.radius_um = radius_um
+        self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
        self._dtype = recording.get_dtype()

    def get_dtype(self):
        return self._dtype

+    def _sigmoid(self, x):
+        L, x0, k, b = self.sigmoid
+        y = L / (1 + np.exp(-k * (x - x0))) + b
+        return y
+
    def compute(self, traces, peaks, waveforms):
        all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype)

        for main_chan in np.unique(peaks["channel_index"]):
            (idx,) = np.nonzero(peaks["channel_index"] == main_chan)
            (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
            local_projections = self.projections[chan_inds, :]
-            wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1)
+            wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)

-            if self.min_values is not None:
-                wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4
+            if self.sigmoid is not None:
+                wf_ptp *= self._sigmoid(wf_ptp)

            denom = np.sum(wf_ptp, axis=1)
            mask = denom != 0

            all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis])

        return all_projections


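To make the feature concrete, here is a small self-contained numpy illustration (not from the commit; shapes and values are made up) of what compute does for the peaks sharing one main channel: per-channel peak-to-peak amplitudes, optional sigmoid weighting, normalization, then projection onto the random vectors.

import numpy as np

num_peaks, num_samples, num_neighbour_chans, num_projections = 4, 60, 3, 10

waveforms = np.random.randn(num_peaks, num_samples, num_neighbour_chans).astype("float32")
local_projections = np.random.randn(num_neighbour_chans, num_projections)

wf_ptp = np.ptp(waveforms, axis=1)                  # (num_peaks, num_neighbour_chans)

# optional weighting with the fitted sigmoid parameters (L, x0, k, b), as in _sigmoid()
L, x0, k, b = 1.0, 1.0, 1.0, 0.0
wf_ptp *= L / (1 + np.exp(-k * (wf_ptp - x0))) + b

denom = np.sum(wf_ptp, axis=1)
mask = denom != 0
features = np.dot(wf_ptp[mask], local_projections) / denom[mask][:, np.newaxis]
print(features.shape)                               # (peaks kept, num_projections)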
