Commit 80e6b40

Merge branch 'refactor_matching' of github.com:samuelgarcia/spikeinterface into refactor_matching

samuelgarcia committed Oct 7, 2024
2 parents f5bcb6a + 3de7cad
Showing 11 changed files with 102 additions and 80 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -80,7 +80,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar


class PeakSource(PipelineNode):

    def get_trace_margin(self):
        raise NotImplementedError

@@ -6,12 +6,14 @@

from pathlib import Path


class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
    SorterClass = Spykingcircus2Sorter


if __name__ == "__main__":
    from spikeinterface import set_global_job_kwargs

    set_global_job_kwargs(n_jobs=1, progress_bar=False)
    test = SpykingCircus2SorterCommonTestSuite()
    test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/tridesclous2.py
@@ -226,7 +226,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
        matching_method = params["matching"]["method"]
        matching_params = params["matching"]["method_kwargs"].copy()
        matching_params["templates"] = templates
-        if params["matching"]["method"] in ("tdc-peeler", ):
+        if params["matching"]["method"] in ("tdc-peeler",):
            matching_params["noise_levels"] = noise_levels
        spikes = find_spikes_from_templates(
            recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
@@ -602,7 +602,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,

        sub_recording = recording.frame_slice(t_start, t_stop)
        local_params.update({"ignore_inds": ignore_inds + [i]})

        spikes, more_outputs = find_spikes_from_templates(
            sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs
        )
15 changes: 8 additions & 7 deletions src/spikeinterface/sortingcomponents/matching/base.py
@@ -10,22 +10,23 @@
("segment_index", "int64"),
]


class BaseTemplateMatching(PeakDetector):
def __init__(self, recording, templates, return_output=True, parents=None):
# TODO make a sharedmem of template here
# TODO maybe check that channel_id are the same with recording

assert isinstance(templates, Templates), (
f"The templates supplied is of type {type(templates)} and must be a Templates"
)
assert isinstance(
templates, Templates
), f"The templates supplied is of type {type(templates)} and must be a Templates"
self.templates = templates
PeakDetector.__init__(self, recording, return_output=return_output, parents=parents)

def get_dtype(self):
return np.dtype(_base_matching_dtype)

def get_trace_margin(self):
raise NotImplementedError
raise NotImplementedError

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
spikes = self.compute_matching(traces, start_frame, end_frame, segment_index)
@@ -37,11 +38,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
            spikes = spikes[keep]

        # the node pipeline needs to return a tuple
-        return (spikes, )
+        return (spikes,)

    def compute_matching(self, traces, start_frame, end_frame, segment_index):
        raise NotImplementedError

    def get_extra_outputs(self):
        # can be overwritten if some variables need to be output in a dict
        return None
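
The base class above fixes the contract every matcher implements: get_trace_margin() tells the node pipeline how much padding each chunk needs, and compute_matching() returns a structured array using _base_matching_dtype. As a minimal sketch of that contract (hypothetical code, not part of this commit; the ThresholdMatching class and its threshold logic are invented for illustration, and the dtype fields used are the ones appearing in this diff):

import numpy as np

from spikeinterface.sortingcomponents.matching.base import BaseTemplateMatching, _base_matching_dtype


class ThresholdMatching(BaseTemplateMatching):
    """Toy matcher: emits a spike (cluster 0) wherever the trace minimum crosses a threshold."""

    def __init__(self, recording, templates=None, threshold=-50.0, return_output=True, parents=None):
        BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output, parents=parents)
        self.threshold = threshold
        self.margin = templates.nbefore + templates.nafter

    def get_trace_margin(self):
        # padding added to each chunk so spikes near chunk borders are not lost
        return self.margin

    def compute_matching(self, traces, start_frame, end_frame, segment_index):
        (sample_inds,) = np.nonzero(np.min(traces, axis=1) < self.threshold)
        spikes = np.zeros(sample_inds.size, dtype=_base_matching_dtype)
        spikes["sample_index"] = sample_inds
        spikes["cluster_index"] = 0
        spikes["amplitude"] = traces[sample_inds, :].min(axis=1)
        spikes["segment_index"] = segment_index
        return spikes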
78 changes: 39 additions & 39 deletions src/spikeinterface/sortingcomponents/matching/circus.py
@@ -88,6 +88,7 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities):

    return new_overlaps


class CircusOMPSVDPeeler(BaseTemplateMatching):
"""
Orthogonal Matching Pursuit inspired from Spyking Circus sorter
Expand Down Expand Up @@ -121,17 +122,21 @@ class CircusOMPSVDPeeler(BaseTemplateMatching):
"""

    _more_output_keys = [
        "norms",
        "temporal",
        "spatial",
        "singular",
        "units_overlaps",
        "unit_overlaps_indices",
        "normed_templates",
        "overlaps",
    ]

-    def __init__(self, recording, return_output=True, parents=None,
+    def __init__(
+        self,
+        recording,
+        return_output=True,
+        parents=None,
        templates=None,
        amplitudes=[0.6, np.inf],
        stop_criteria="max_failures",
@@ -142,7 +147,7 @@ def __init__(self, recording, return_output=True, parents=None,
        ignore_inds=[],
        vicinity=3,
        precomputed=None,
    ):

        BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None)

@@ -169,7 +174,6 @@ def __init__(self, recording, return_output=True, parents=None,
                assert precomputed[key] is not None, "If templates are provided, %s should also be there" % key
                setattr(self, key, precomputed[key])

        self.ignore_inds = np.array(ignore_inds)

        self.unit_overlaps_tables = {}
@@ -182,7 +186,6 @@ def __init__(self, recording, return_output=True, parents=None,
        else:
            self.margin = 2 * self.num_samples

    def _prepare_templates(self):

        assert self.stop_criteria in ["max_failures", "omp_min_sps", "relative_error"]
@@ -256,11 +259,8 @@ def get_extra_outputs(self):
            output[key] = getattr(self, key)
        return output

    def get_trace_margin(self):
        return self.margin

    def compute_matching(self, traces, start_frame, end_frame, segment_index):
        import scipy.spatial
@@ -468,10 +468,8 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
        if spikes.size > 0:
            order = np.argsort(spikes["sample_index"])
            spikes = spikes[order]

        return spikes
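
To make the orthogonal matching pursuit idea in the docstring above concrete, here is a small self-contained toy (an illustration only, independent of SpikeInterface and of the peeler itself): greedily pick the dictionary atom most correlated with the residual, then re-fit by least squares on the selected support.

import numpy as np

rng = np.random.default_rng(42)
D = rng.normal(size=(200, 50))
D /= np.linalg.norm(D, axis=0)            # unit-norm atoms, analogous to normed templates
x_true = np.zeros(50)
x_true[[3, 17, 31]] = [2.0, -1.5, 1.0]    # sparse ground truth
y = D @ x_true

residual, support = y.copy(), []
for _ in range(3):                        # assume the sparsity is known in this toy
    support.append(int(np.argmax(np.abs(D.T @ residual))))
    coefs, *_ = np.linalg.lstsq(D[:, support], y, rcond=None)
    residual = y - D[:, support] @ coefs

print(sorted(support))                    # expected to recover [3, 17, 31]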


class CircusPeeler(BaseTemplateMatching):
@@ -519,19 +517,23 @@ class CircusPeeler(BaseTemplateMatching):
    """
-    def __init__(self, recording, return_output=True, parents=None,
+    def __init__(
+        self,
+        recording,
+        return_output=True,
+        parents=None,
        templates=None,
        peak_sign="neg",
        exclude_sweep_ms=0.1,
        jitter_ms=0.1,
        detect_threshold=5,
        noise_levels=None,
        random_chunk_kwargs={},
        max_amplitude=1.5,
        min_amplitude=0.5,
        use_sparse_matrix_threshold=0.25,
    ):

        BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None)

@@ -544,7 +546,9 @@ def __init__(self, recording, return_output=True, parents=None,

        assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work"

-        assert (use_sparse_matrix_threshold >= 0) and (use_sparse_matrix_threshold <= 1), f"use_sparse_matrix_threshold should be in [0, 1]"
+        assert (use_sparse_matrix_threshold >= 0) and (
+            use_sparse_matrix_threshold <= 1
+        ), f"use_sparse_matrix_threshold should be in [0, 1]"

        self.num_channels = recording.get_num_channels()
        self.num_samples = templates.num_samples
@@ -580,8 +584,6 @@ def __init__(self, recording, return_output=True, parents=None,
        self.margin = max(self.nbefore, self.nafter) * 2
        self.peak_sign = peak_sign

    def _prepare_templates(self):
        import scipy.spatial
        import scipy
@@ -617,7 +619,6 @@ def _prepare_templates(self):
    def get_trace_margin(self):
        return self.margin

    def compute_matching(self, traces, start_frame, end_frame, segment_index):

        neighbor_window = self.num_samples - 1
@@ -702,4 +703,3 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
            spikes = spikes[order]

        return spikes

1 change: 0 additions & 1 deletion src/spikeinterface/sortingcomponents/matching/main.py
@@ -10,7 +10,6 @@
from spikeinterface.core.node_pipeline import run_node_pipeline



def find_spikes_from_templates(
    recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs
) -> np.ndarray | tuple[np.ndarray, dict]:
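
For reference, a hedged usage sketch of this entry point; recording and templates are placeholders assumed to come from earlier pipeline steps (they are not defined in this snippet), and the method names are the ones appearing in this commit.

from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

# `recording` and `templates` are assumed to exist from earlier steps
spikes = find_spikes_from_templates(
    recording,
    method="naive",
    method_kwargs={"templates": templates, "detect_threshold": 5},
    n_jobs=1,
)

# with extra_outputs=True, a (spikes, extra_outputs_dict) tuple is returned instead
spikes, extra = find_spikes_from_templates(
    recording,
    method="circus-omp-svd",
    method_kwargs={"templates": templates},
    extra_outputs=True,
)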
17 changes: 9 additions & 8 deletions src/spikeinterface/sortingcomponents/matching/naive.py
@@ -8,19 +8,23 @@
from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive



from .base import BaseTemplateMatching, _base_matching_dtype


class NaiveMatching(BaseTemplateMatching):
-    def __init__(self, recording, return_output=True, parents=None,
+    def __init__(
+        self,
+        recording,
+        return_output=True,
+        parents=None,
        templates=None,
        peak_sign="neg",
        exclude_sweep_ms=0.1,
        detect_threshold=5,
        noise_levels=None,
-        radius_um=100.,
+        radius_um=100.0,
        random_chunk_kwargs={},
    ):

        BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None)

@@ -37,15 +41,13 @@ def __init__(self, recording, return_output=True, parents=None,
        self.nafter = self.templates.nafter
        self.margin = max(self.nbefore, self.nafter)

    def get_trace_margin(self):
        return self.margin

    def compute_matching(self, traces, start_frame, end_frame, segment_index):

        if self.margin > 0:
-            peak_traces = traces[self.margin:-self.margin, :]
+            peak_traces = traces[self.margin : -self.margin, :]
        else:
            peak_traces = traces
        peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(
@@ -70,4 +72,3 @@
spikes["amplitude"][i] = 0.0

return spikes
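
One plausible reading of the "naive" strategy, detect peaks and then label each snippet with the nearest template, can be sketched as below. This is an invented illustration under that assumption, not NaiveMatching's actual assignment code, and the helper name is hypothetical.

import numpy as np


def assign_nearest_template(traces, peak_sample_inds, templates_array, nbefore, nafter):
    # templates_array: dense templates with shape (num_units, num_samples, num_channels)
    labels = np.zeros(peak_sample_inds.size, dtype="int64")
    for i, peak in enumerate(peak_sample_inds):
        snippet = traces[peak - nbefore : peak + nafter, :]
        # squared distance between the snippet and every template
        dists = np.sum((templates_array - snippet[None, :, :]) ** 2, axis=(1, 2))
        labels[i] = int(np.argmin(dists))
    return labels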

25 changes: 13 additions & 12 deletions src/spikeinterface/sortingcomponents/matching/tdc.py
@@ -37,24 +37,29 @@ class TridesclousPeeler(BaseTemplateMatching):
    This method is quite fast but does not give excellent results at resolving
    spike collisions when templates have high similarity.
    """
-    def __init__(self, recording, return_output=True, parents=None,
+    def __init__(
+        self,
+        recording,
+        return_output=True,
+        parents=None,
        templates=None,
        peak_sign="neg",
        peak_shift_ms=0.2,
        detect_threshold=5,
        noise_levels=None,
-        radius_um=100.,
+        radius_um=100.0,
        num_closest=5,
        sample_shift=3,
        ms_before=0.8,
        ms_after=1.2,
        num_peeler_loop=2,
        num_template_try=1,
    ):

        BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None)

        # maybe in base?
        self.templates_array = templates.get_dense_templates()

        unit_ids = templates.unit_ids
@@ -64,7 +69,7 @@ def __init__(self, recording, return_output=True, parents=None,

        self.nbefore = templates.nbefore
        self.nafter = templates.nafter

        self.peak_sign = peak_sign

        nbefore_short = int(ms_before * sr / 1000.0)
@@ -103,6 +108,7 @@

        # distance between units
        import scipy.spatial

        unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean")

        # search for the closest units and a unitary discriminant vector
@@ -111,7 +117,7 @@
            order = np.argsort(unit_distances[unit_ind, :])
            closest_u = np.arange(unit_ids.size)[order].tolist()
            closest_u.remove(unit_ind)
-            closest_u = np.array(closest_u[: num_closest])
+            closest_u = np.array(closest_u[:num_closest])

            # compute a unitary discriminant vector
            (chans,) = np.nonzero(self.template_sparsity[unit_ind, :])
@@ -298,10 +304,8 @@ def _find_spikes_one_level(self, traces, level=0):

spikes["cluster_index"][i] = cluster_index
spikes["amplitude"][i] = amplitude

return spikes


return spikes


if HAVE_NUMBA:
@@ -346,6 +350,3 @@ def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, d
        distances_shift[i] = sum_dist

    return distances_shift


