SC2 fixes
* Fixes

* Patches

* Fixes for SC2 and for split clustering

* Debugging clustering

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* WIP

* Default params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adding gather_func to find_spikes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Gathering mode more explicit for matching

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* WIP

* Fixes for SC2

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP

* Simplifications

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Naming for Sam

* Optimize circus matching engine

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Optimizations

* Remove the limit to chunk sizes in circus-omp-svd

* Naming

* Patch imports

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yger and pre-commit-ci[bot] authored Oct 9, 2024
1 parent badc4c5 commit 275e501
Showing 6 changed files with 153 additions and 135 deletions.
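
The diff below mainly retunes the SpyKING CIRCUS 2 defaults (radius_um 100 -> 75, matching "wobble" -> "circus-omp-svd", motion preset "dredge_fast", a new "whitening" parameter block) and simplifies how job kwargs reach detection and matching. The following is a minimal usage sketch, not part of the commit: it assumes the run_sorter API and the generate_ground_truth_recording helper from recent spikeinterface releases, and the folder name and parameter overrides are illustrative only.

    # Minimal sketch (assumptions noted above); values are illustrative.
    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.sorters import run_sorter

    # Toy recording standing in for real data.
    recording, _ = generate_ground_truth_recording(durations=[30.0], num_channels=16, seed=42)

    # The defaults changed in this commit (radius_um=75, circus-omp-svd matching,
    # dredge_fast motion preset, configurable whitening) are picked up automatically;
    # any of them can still be overridden per call, as shown here for "whitening".
    sorting = run_sorter(
        "spykingcircus2",
        recording,
        folder="sc2_output",
        apply_motion_correction=False,  # toy probe, nothing to correct
        whitening={"mode": "local", "regularize": False},
    )
    print(sorting)
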
41 changes: 18 additions & 23 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -24,9 +24,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
sorter_name = "spykingcircus2"

_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
"whitening": {"mode": "local", "regularize": True},
"detection": {"peak_sign": "neg", "detect_threshold": 4},
"selection": {
"method": "uniform",
@@ -36,7 +37,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"seed": 42,
},
"apply_motion_correction": True,
"motion_correction": {"preset": "nonrigid_fast_and_accurate"},
"motion_correction": {"preset": "dredge_fast"},
"merging": {
"similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2},
"correlograms_kwargs": {},
@@ -46,7 +47,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
},
},
"clustering": {"legacy": True},
"matching": {"method": "wobble"},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
"matched_filtering": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
@@ -62,6 +63,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
and also the radius_um used to be considered during clustering",
"sparsity": "A dictionary to be passed to all the calls to sparsify the templates",
"filtering": "A dictionary for the high_pass filter to be used during preprocessing",
"whitening": "A dictionary for the whitening option to be used during preprocessing",
"detection": "A dictionary for the peak detection node (locally_exclusive)",
"selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\
and 5000 peaks per electrode on average.",
@@ -109,8 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
from spikeinterface.sortingcomponents.tools import get_prototype_spike

job_kwargs = params["job_kwargs"]
job_kwargs = fix_job_kwargs(job_kwargs)
job_kwargs = fix_job_kwargs(params["job_kwargs"])
job_kwargs.update({"progress_bar": verbose})

recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
@@ -119,7 +120,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
num_channels = recording.get_num_channels()
ms_before = params["general"].get("ms_before", 2)
ms_after = params["general"].get("ms_after", 2)
radius_um = params["general"].get("radius_um", 100)
radius_um = params["general"].get("radius_um", 75)
exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2)

## First, we are filtering the data
@@ -143,14 +144,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
print("Motion correction activated (probe geometry compatible)")
motion_folder = sorter_output_folder / "motion"
params["motion_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["motion_correction"])
recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs)
else:
motion_folder = None

## We need to whiten before the template matching step, to boost the results
# TODO add , regularize=True chen ready
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)
whitening_kwargs = params["whitening"].copy()
whitening_kwargs["dtype"] = "float32"
whitening_kwargs["radius_um"] = radius_um
if num_channels == 1:
whitening_kwargs["regularize"] = False

recording_w = whiten(recording_f, **whitening_kwargs)
noise_levels = get_noise_levels(recording_w, return_scaled=False)

if recording_w.check_serializability("json"):
@@ -172,20 +178,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)

peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)

if params["matched_filtering"]:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000)
prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before

for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in detection_params:
detection_params.pop(value)

detection_params["chunk_duration"] = "100ms"

peaks = detect_peaks(recording_w, "matched_filtering", **detection_params)
else:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)

if verbose:
print("We found %d peaks in total" % len(peaks))
@@ -196,7 +196,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## We subselect a subset of all the peaks, by making the distributions of SNRs over all
## channels as flat as possible
selection_params = params["selection"]
selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels)
selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

selection_params.update({"noise_levels": noise_levels})
@@ -281,11 +281,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_job_params = job_kwargs.copy()

if matching_method is not None:
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params[value] = None
matching_job_params["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_w, matching_method, method_kwargs=matching_params, **matching_job_params
)
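
The bullet "Remove the limit to chunk sizes in circus-omp-svd" corresponds to the block deleted just above the call to find_spikes_from_templates: matching no longer overrides the user's chunk settings with a fixed "100ms" chunk duration. The short sketch below illustrates the job-kwargs flow; it assumes fix_job_kwargs lives in spikeinterface.core.job_tools as in current releases, and the values are illustrative.

    # Sketch (assumption): how job kwargs are normalised before template matching.
    from spikeinterface.core.job_tools import fix_job_kwargs

    user_job_kwargs = {"n_jobs": 4, "chunk_duration": "1s"}
    job_kwargs = fix_job_kwargs(user_job_kwargs)  # fills in missing defaults
    job_kwargs.update({"progress_bar": True})

    # Before this commit the circus matching step replaced chunk_duration with "100ms";
    # now the chunking requested above is forwarded unchanged to
    # find_spikes_from_templates(recording_w, "circus-omp-svd", ...).
    print(job_kwargs)
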
18 changes: 8 additions & 10 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -136,6 +136,7 @@ def main_function(cls, recording, peaks, params):
pipeline_nodes = [node0, node1, node2]

if len(params["recursive_kwargs"]) == 0:
from sklearn.decomposition import PCA

all_pc_data = run_node_pipeline(
recording,
@@ -152,9 +153,9 @@ def main_function(cls, recording, peaks, params):
sub_data = sub_data.reshape(len(sub_data), -1)

if all_pc_data.shape[1] > params["n_svd"][1]:
tsvd = TruncatedSVD(params["n_svd"][1])
tsvd = PCA(params["n_svd"][1], whiten=True)
else:
tsvd = TruncatedSVD(all_pc_data.shape[1])
tsvd = PCA(all_pc_data.shape[1], whiten=True)

hdbscan_data = tsvd.fit_transform(sub_data)
try:
@@ -184,14 +185,16 @@ def main_function(cls, recording, peaks, params):
)

sparse_mask = node1.neighbours_mask
neighbours_mask = get_channel_distances(recording) < radius_um
neighbours_mask = get_channel_distances(recording) <= radius_um

# np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

original_labels = peaks["channel_index"]
from spikeinterface.sortingcomponents.clustering.split import split_clusters

min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)

peak_labels, _ = split_clusters(
original_labels,
recording,
Expand All @@ -202,7 +205,7 @@ def main_function(cls, recording, peaks, params):
feature_name="sparse_tsvd",
neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
min_size_split=50,
min_size_split=min_size,
clusterer_kwargs=d["hdbscan_kwargs"],
n_pca_features=params["n_svd"][1],
scale_n_pca_by_depth=True,
@@ -233,7 +236,7 @@ def main_function(cls, recording, peaks, params):
if d["rank"] is not None:
from spikeinterface.sortingcomponents.matching.circus import compress_templates

_, _, _, templates_array = compress_templates(templates_array, 5)
_, _, _, templates_array = compress_templates(templates_array, d["rank"])

templates = Templates(
templates_array=templates_array,
@@ -258,13 +261,8 @@ def main_function(cls, recording, peaks, params):
print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))

cleaning_matching_params = params["job_kwargs"].copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in cleaning_matching_params:
cleaning_matching_params.pop(value)
cleaning_matching_params["chunk_duration"] = "100ms"
cleaning_matching_params["n_jobs"] = 1
cleaning_matching_params["progress_bar"] = False

cleaning_params = params["cleaning_kwargs"].copy()

labels, peak_labels = remove_duplicates_via_matching(
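
In this clustering file the dimensionality reduction switches from TruncatedSVD to a whitened PCA before HDBSCAN, and template compression now honours the "rank" parameter instead of a hard-coded 5. Below is a standalone sketch of the reduce-then-cluster pattern; it assumes scikit-learn and hdbscan are installed, and the random data only stands in for the sparse per-peak features built by the pipeline nodes.

    # Sketch (assumption): whitened PCA followed by HDBSCAN, as the updated code does.
    import numpy as np
    from sklearn.decomposition import PCA
    import hdbscan

    rng = np.random.default_rng(42)
    features = rng.normal(size=(1000, 30))  # stand-in for flattened SVD features

    n_svd = 5
    pca = PCA(n_components=min(n_svd, features.shape[1]), whiten=True)
    reduced = pca.fit_transform(features)

    clusterer = hdbscan.HDBSCAN(min_cluster_size=50, allow_single_cluster=True)
    labels = clusterer.fit_predict(reduced)  # -1 marks noise peaks
    print(np.unique(labels))
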
22 changes: 13 additions & 9 deletions src/spikeinterface/sortingcomponents/clustering/split.py
@@ -24,7 +24,7 @@ def split_clusters(
peak_labels,
recording,
features_dict_or_folder,
method="hdbscan_on_local_pca",
method="local_feature_clustering",
method_kwargs={},
recursive=False,
recursive_depth=None,
@@ -81,7 +81,6 @@ def split_clusters(
) as pool:
labels_set = np.setdiff1d(peak_labels, [-1])
current_max_label = np.max(labels_set) + 1

jobs = []
for label in labels_set:
peak_indices = np.flatnonzero(peak_labels == label)
@@ -95,15 +94,14 @@ def split_clusters(

for res in iterator:
is_split, local_labels, peak_indices = res.result()
# print(is_split, local_labels, peak_indices)
if not is_split:
continue

mask = local_labels >= 0
peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
peak_labels[peak_indices[~mask]] = local_labels[~mask]

split_count[peak_indices] += 1

current_max_label += np.max(local_labels[mask]) + 1

if recursive:
@@ -120,6 +118,7 @@
for label in new_labels_set:
peak_indices = np.flatnonzero(peak_labels == label)
if peak_indices.size > 0:
# print('Relaunched', label, len(peak_indices), recursion_level)
jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))
if progress_bar:
iterator.total += 1
@@ -187,7 +186,7 @@ def split(
min_size_split=25,
n_pca_features=2,
scale_n_pca_by_depth=False,
minimum_common_channels=2,
minimum_overlap_ratio=0.25,
):
local_labels = np.zeros(peak_indices.size, dtype=np.int64)

@@ -199,19 +198,22 @@
# target channel subset is done intersect local channels + neighbours
local_chans = np.unique(peaks["channel_index"][peak_indices])

target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0))
num_intersection = len(target_intersection_channels)
num_union = len(target_union_channels)

# TODO: fix this in a better way; this happens when clusters have too few overlapping channels
if target_channels.size < minimum_common_channels:
if (num_intersection / num_union) < minimum_overlap_ratio:
return False, None

aligned_wfs, dont_have_channels = aggregate_sparse_features(
peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels
peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_intersection_channels
)

local_labels[dont_have_channels] = -2
kept = np.flatnonzero(~dont_have_channels)

# print(recursion_level, kept.size, min_size_split)
if kept.size < min_size_split:
return False, None

@@ -222,6 +224,8 @@
if flatten_features.shape[1] > n_pca_features:
from sklearn.decomposition import PCA

# from sklearn.decomposition import TruncatedSVD

if scale_n_pca_by_depth:
# tsvd = TruncatedSVD(n_pca_features * recursion_level)
tsvd = PCA(n_pca_features * recursion_level, whiten=True)
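
The split criterion in split.py changes from requiring a minimum number of channels shared by all peaks (minimum_common_channels) to an intersection-over-union test on the neighbourhood masks (minimum_overlap_ratio, default 0.25). Below is a standalone numpy sketch of that test; the toy probe geometry and channel indices are illustrative only.

    # Sketch (assumption): the intersection/union overlap test used to decide
    # whether a cluster has enough common channels to be split further.
    import numpy as np

    num_channels = 8
    radius_um = 75.0
    positions = np.arange(num_channels)[:, None] * 30.0  # toy linear probe, 30 um pitch
    distances = np.abs(positions - positions.T)
    neighbours_mask = distances <= radius_um

    local_chans = np.array([1, 2, 5])  # channels carrying this cluster's peaks
    intersection = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
    union = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0))

    minimum_overlap_ratio = 0.25
    if len(intersection) / len(union) < minimum_overlap_ratio:
        print("too little overlap: skip splitting this cluster")
    else:
        print(f"splitting on {len(intersection)} common channels")
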
