Skip to content

Commit

Permalink
Merge pull request #3447 from samuelgarcia/refactor_matching
Browse files Browse the repository at this point in the history
Refactor matching with node pipeline
  • Loading branch information
samuelgarcia authored Oct 7, 2024
2 parents 69bf6e4 + 80e6b40 commit 298e57a
Show file tree
Hide file tree
Showing 13 changed files with 636 additions and 918 deletions.
3 changes: 2 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar


class PeakSource(PipelineNode):
# base class for peak detector

def get_trace_margin(self):
raise NotImplementedError

Expand All @@ -99,6 +99,7 @@ def get_peak_slice(

# this is used in sorting components
class PeakDetector(PeakSource):
# base class for peak detector or template matching
pass


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@

from spikeinterface.sorters import Spykingcircus2Sorter

from pathlib import Path


class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
SorterClass = Spykingcircus2Sorter


if __name__ == "__main__":
from spikeinterface import set_global_job_kwargs

set_global_job_kwargs(n_jobs=1, progress_bar=False)
test = SpykingCircus2SorterCommonTestSuite()
test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
test.setUp()
test.test_with_run()
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

from spikeinterface.sorters import Tridesclous2Sorter

from pathlib import Path


class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
SorterClass = Tridesclous2Sorter


if __name__ == "__main__":
test = Tridesclous2SorterCommonTestSuite()
test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
test.setUp()
test.test_with_run()
3 changes: 2 additions & 1 deletion src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_method = params["matching"]["method"]
matching_params = params["matching"]["method_kwargs"].copy()
matching_params["templates"] = templates
matching_params["noise_levels"] = noise_levels
if params["matching"]["method"] in ("tdc-peeler",):
matching_params["noise_levels"] = noise_levels
spikes = find_spikes_from_templates(
recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,21 +602,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None,

sub_recording = recording.frame_slice(t_start, t_stop)
local_params.update({"ignore_inds": ignore_inds + [i]})
spikes, computed = find_spikes_from_templates(

spikes, more_outputs = find_spikes_from_templates(
sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs
)
local_params.update(
{
"overlaps": computed["overlaps"],
"normed_templates": computed["normed_templates"],
"norms": computed["norms"],
"temporal": computed["temporal"],
"spatial": computed["spatial"],
"singular": computed["singular"],
"units_overlaps": computed["units_overlaps"],
"unit_overlaps_indices": computed["unit_overlaps_indices"],
}
)
local_params["precomputed"] = more_outputs
valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin)

if np.sum(valid) > 0:
Expand Down
48 changes: 48 additions & 0 deletions src/spikeinterface/sortingcomponents/matching/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
from spikeinterface.core import Templates
from spikeinterface.core.node_pipeline import PeakDetector

# Structured-array dtype shared by every template-matching method.
# Each matched spike carries its sample position, detection channel,
# matched template (cluster) index, fitted amplitude and segment index.
_base_matching_dtype = [
    ("sample_index", "int64"),
    ("channel_index", "int64"),
    ("cluster_index", "int64"),
    ("amplitude", "float64"),
    ("segment_index", "int64"),
]


class BaseTemplateMatching(PeakDetector):
    """Base pipeline node for template-matching spike detection.

    Subclasses must implement ``compute_matching`` (returning a structured
    array with the ``_base_matching_dtype`` fields) and ``get_trace_margin``.
    The node plugs into the core node pipeline as a peak source.
    """

    def __init__(self, recording, templates, return_output=True, parents=None):
        # TODO make a sharedmem of template here
        # TODO maybe check that channel_id are the same with recording
        assert isinstance(
            templates, Templates
        ), f"The templates supplied is of type {type(templates)} and must be a Templates"
        self.templates = templates
        super().__init__(recording, return_output=return_output, parents=parents)

    def get_dtype(self):
        """Return the structured dtype of the matched-spike array."""
        return np.dtype(_base_matching_dtype)

    def get_trace_margin(self):
        # Subclasses must report how many edge samples of each chunk
        # cannot be matched reliably.
        raise NotImplementedError

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
        matched = self.compute_matching(traces, start_frame, end_frame, segment_index)
        matched["segment_index"] = segment_index

        # Drop spikes inside the unreliable chunk margins, if any.
        margin = self.get_trace_margin()
        if margin > 0 and matched.size > 0:
            in_bounds = (matched["sample_index"] >= margin) & (
                matched["sample_index"] < (traces.shape[0] - margin)
            )
            matched = matched[in_bounds]

        # node pipeline need to return a tuple
        return (matched,)

    def compute_matching(self, traces, start_frame, end_frame, segment_index):
        # Subclasses implement the actual matching algorithm here.
        raise NotImplementedError

    def get_extra_outputs(self):
        # Can be overwritten if a method needs to output extra variables as a dict.
        return None
Loading

0 comments on commit 298e57a

Please sign in to comment.