diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index a72808e176..d90a20902d 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -80,7 +80,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar class PeakSource(PipelineNode): - # base class for peak detector + def get_trace_margin(self): raise NotImplementedError @@ -99,6 +99,7 @@ def get_peak_slice( # this is used in sorting components class PeakDetector(PeakSource): + # base class for peak detector or template matching pass diff --git a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py index 333bcdbc32..df6e3821bb 100644 --- a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py +++ b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py @@ -4,12 +4,18 @@ from spikeinterface.sorters import Spykingcircus2Sorter +from pathlib import Path + class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Spykingcircus2Sorter if __name__ == "__main__": + from spikeinterface import set_global_job_kwargs + + set_global_job_kwargs(n_jobs=1, progress_bar=False) test = SpykingCircus2SorterCommonTestSuite() + test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters" test.setUp() test.test_with_run() diff --git a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py index 58d6c15c8d..b256dd1328 100644 --- a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py @@ -4,6 +4,8 @@ from spikeinterface.sorters import Tridesclous2Sorter +from pathlib import Path + class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Tridesclous2Sorter @@ -11,5 +13,6 @@ class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase if __name__ == "__main__": test = Tridesclous2SorterCommonTestSuite() + test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters" test.setUp() test.test_with_run() diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 57755cd759..a180fb4e02 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -226,7 +226,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_method = params["matching"]["method"] matching_params = params["matching"]["method_kwargs"].copy() matching_params["templates"] = templates - matching_params["noise_levels"] = noise_levels + if params["matching"]["method"] in ("tdc-peeler",): + matching_params["noise_levels"] = noise_levels spikes = find_spikes_from_templates( recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs ) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 234be686d0..08a1384333 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -602,21 +602,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, sub_recording = recording.frame_slice(t_start, 
t_stop) local_params.update({"ignore_inds": ignore_inds + [i]}) - spikes, computed = find_spikes_from_templates( + + spikes, more_outputs = find_spikes_from_templates( sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs ) - local_params.update( - { - "overlaps": computed["overlaps"], - "normed_templates": computed["normed_templates"], - "norms": computed["norms"], - "temporal": computed["temporal"], - "spatial": computed["spatial"], - "singular": computed["singular"], - "units_overlaps": computed["units_overlaps"], - "unit_overlaps_indices": computed["unit_overlaps_indices"], - } - ) + local_params["precomputed"] = more_outputs valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin) if np.sum(valid) > 0: diff --git a/src/spikeinterface/sortingcomponents/matching/base.py b/src/spikeinterface/sortingcomponents/matching/base.py new file mode 100644 index 0000000000..0e60a9e864 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/matching/base.py @@ -0,0 +1,48 @@ +import numpy as np +from spikeinterface.core import Templates +from spikeinterface.core.node_pipeline import PeakDetector + +_base_matching_dtype = [ + ("sample_index", "int64"), + ("channel_index", "int64"), + ("cluster_index", "int64"), + ("amplitude", "float64"), + ("segment_index", "int64"), +] + + +class BaseTemplateMatching(PeakDetector): + def __init__(self, recording, templates, return_output=True, parents=None): + # TODO make a sharedmem of template here + # TODO maybe check that channel_id are the same with recording + + assert isinstance( + templates, Templates + ), f"The templates supplied is of type {type(templates)} and must be a Templates" + self.templates = templates + PeakDetector.__init__(self, recording, return_output=return_output, parents=parents) + + def get_dtype(self): + return np.dtype(_base_matching_dtype) + + def get_trace_margin(self): + raise NotImplementedError + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + spikes = self.compute_matching(traces, start_frame, end_frame, segment_index) + spikes["segment_index"] = segment_index + + margin = self.get_trace_margin() + if margin > 0 and spikes.size > 0: + keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin)) + spikes = spikes[keep] + + # node pipeline need to return a tuple + return (spikes,) + + def compute_matching(self, traces, start_frame, end_frame, segment_index): + raise NotImplementedError + + def get_extra_outputs(self): + # can be overwritten if need to ouput some variables with a dict + return None diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ad7391a297..a3624f4296 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -17,7 +17,7 @@ ("segment_index", "int64"), ] -from .main import BaseTemplateMatchingEngine +from .base import BaseTemplateMatching def compress_templates( @@ -89,7 +89,7 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): return new_overlaps -class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): +class CircusOMPSVDPeeler(BaseTemplateMatching): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -121,147 +121,148 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): ----- """ - _default_params = { - "amplitudes": [0.6, np.inf], - "stop_criteria": "max_failures", - 
"max_failures": 10, - "omp_min_sps": 0.1, - "relative_error": 5e-5, - "templates": None, - "rank": 5, - "ignore_inds": [], - "vicinity": 3, - } + _more_output_keys = [ + "norms", + "temporal", + "spatial", + "singular", + "units_overlaps", + "unit_overlaps_indices", + "normed_templates", + "overlaps", + ] + + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + amplitudes=[0.6, np.inf], + stop_criteria="max_failures", + max_failures=10, + omp_min_sps=0.1, + relative_error=5e-5, + rank=5, + ignore_inds=[], + vicinity=3, + precomputed=None, + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) + + self.num_channels = recording.get_num_channels() + self.num_samples = templates.num_samples + self.nbefore = templates.nbefore + self.nafter = templates.nafter + self.sampling_frequency = recording.get_sampling_frequency() + self.vicinity = vicinity * self.num_samples + + self.amplitudes = amplitudes + self.stop_criteria = stop_criteria + self.max_failures = max_failures + self.omp_min_sps = omp_min_sps + self.relative_error = relative_error + self.rank = rank + + self.num_templates = len(templates.unit_ids) + + if precomputed is None: + self._prepare_templates() + else: + for key in self._more_output_keys: + assert precomputed[key] is not None, "If templates are provided, %d should also be there" % key + setattr(self, key, precomputed[key]) + + self.ignore_inds = np.array(ignore_inds) + + self.unit_overlaps_tables = {} + for i in range(self.num_templates): + self.unit_overlaps_tables[i] = np.zeros(self.num_templates, dtype=int) + self.unit_overlaps_tables[i][self.unit_overlaps_indices[i]] = np.arange(len(self.unit_overlaps_indices[i])) - @classmethod - def _prepare_templates(cls, d): - templates = d["templates"] - num_templates = len(d["templates"].unit_ids) + if self.vicinity > 0: + self.margin = self.vicinity + else: + self.margin = 2 * self.num_samples + + def _prepare_templates(self): - assert d["stop_criteria"] in ["max_failures", "omp_min_sps", "relative_error"] + assert self.stop_criteria in ["max_failures", "omp_min_sps", "relative_error"] - sparsity = templates.sparsity.mask + sparsity = self.templates.sparsity.mask units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) - d["units_overlaps"] = units_overlaps > 0 - d["unit_overlaps_indices"] = {} - for i in range(num_templates): - (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) + self.units_overlaps = units_overlaps > 0 + self.unit_overlaps_indices = {} + for i in range(self.num_templates): + self.unit_overlaps_indices[i] = np.flatnonzero(self.units_overlaps[i]) - templates_array = templates.get_dense_templates().copy() + templates_array = self.templates.get_dense_templates().copy() # Then we keep only the strongest components - d["temporal"], d["singular"], d["spatial"], templates_array = compress_templates(templates_array, d["rank"]) + self.temporal, self.singular, self.spatial, templates_array = compress_templates(templates_array, self.rank) - d["normed_templates"] = np.zeros(templates_array.shape, dtype=np.float32) - d["norms"] = np.zeros(num_templates, dtype=np.float32) + self.normed_templates = np.zeros(templates_array.shape, dtype=np.float32) + self.norms = np.zeros(self.num_templates, dtype=np.float32) # And get the norms, saving compressed templates for CC matrix - for count in range(num_templates): + for count in range(self.num_templates): template = 
templates_array[count][:, sparsity[count]] - d["norms"][count] = np.linalg.norm(template) - d["normed_templates"][count][:, sparsity[count]] = template / d["norms"][count] + self.norms[count] = np.linalg.norm(template) + self.normed_templates[count][:, sparsity[count]] = template / self.norms[count] - d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] - d["temporal"] = np.flip(d["temporal"], axis=1) + self.temporal /= self.norms[:, np.newaxis, np.newaxis] + self.temporal = np.flip(self.temporal, axis=1) - d["overlaps"] = [] - d["max_similarity"] = np.zeros((num_templates, num_templates), dtype=np.float32) - for i in range(num_templates): - num_overlaps = np.sum(d["units_overlaps"][i]) - overlapping_units = np.where(d["units_overlaps"][i])[0] + self.overlaps = [] + self.max_similarity = np.zeros((self.num_templates, self.num_templates), dtype=np.float32) + for i in range(self.num_templates): + num_overlaps = np.sum(self.units_overlaps[i]) + overlapping_units = np.flatnonzero(self.units_overlaps[i]) # Reconstruct unit template from SVD Matrices - data = d["temporal"][i] * d["singular"][i][np.newaxis, :] - template_i = np.matmul(data, d["spatial"][i, :, :]) + data = self.temporal[i] * self.singular[i][np.newaxis, :] + template_i = np.matmul(data, self.spatial[i, :, :]) template_i = np.flipud(template_i) - unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) + unit_overlaps = np.zeros([num_overlaps, 2 * self.num_samples - 1], dtype=np.float32) for count, j in enumerate(overlapping_units): overlapped_channels = sparsity[j] visible_i = template_i[:, overlapped_channels] - spatial_filters = d["spatial"][j, :, overlapped_channels] + spatial_filters = self.spatial[j, :, overlapped_channels] spatially_filtered_template = np.matmul(visible_i, spatial_filters) - visible_i = spatially_filtered_template * d["singular"][j] + visible_i = spatially_filtered_template * self.singular[j] for rank in range(visible_i.shape[1]): - unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], self.temporal[j][:, rank], mode="full") - d["max_similarity"][i, j] = np.max(unit_overlaps[count]) + self.max_similarity[i, j] = np.max(unit_overlaps[count]) - d["overlaps"].append(unit_overlaps) + self.overlaps.append(unit_overlaps) - if d["amplitudes"] is None: - distances = np.sort(d["max_similarity"], axis=1)[:, ::-1] + if self.amplitudes is None: + distances = np.sort(self.max_similarity, axis=1)[:, ::-1] distances = 1 - distances[:, 1] / 2 - d["amplitudes"] = np.zeros((num_templates, 2)) - d["amplitudes"][:, 0] = distances - d["amplitudes"][:, 1] = np.inf - - d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) - d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) - d["singular"] = d["singular"].T[:, :, np.newaxis] - return d - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls._default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) + self.amplitudes = np.zeros((self.num_templates, 2)) + self.amplitudes[:, 0] = distances + self.amplitudes[:, 1] = np.inf - d["num_channels"] = recording.get_num_channels() - d["num_samples"] = d["templates"].num_samples - d["nbefore"] = d["templates"].nbefore - d["nafter"] = d["templates"].nafter - d["sampling_frequency"] = recording.get_sampling_frequency() - 
d["vicinity"] *= d["num_samples"] + self.spatial = np.moveaxis(self.spatial, [0, 1, 2], [1, 0, 2]) + self.temporal = np.moveaxis(self.temporal, [0, 1, 2], [1, 2, 0]) + self.singular = self.singular.T[:, :, np.newaxis] - if "overlaps" not in d: - d = cls._prepare_templates(d) - else: - for key in [ - "norms", - "temporal", - "spatial", - "singular", - "units_overlaps", - "unit_overlaps_indices", - ]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key - - d["num_templates"] = len(d["templates"].templates_array) - d["ignore_inds"] = np.array(d["ignore_inds"]) - - d["unit_overlaps_tables"] = {} - for i in range(d["num_templates"]): - d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) - d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - if kwargs["vicinity"] > 0: - margin = kwargs["vicinity"] - else: - margin = 2 * kwargs["num_samples"] - return margin + def get_extra_outputs(self): + output = {} + for key in self._more_output_keys: + output[key] = getattr(self, key) + return output + + def get_trace_margin(self): + return self.margin - @classmethod - def main_function(cls, traces, d): + def compute_matching(self, traces, start_frame, end_frame, segment_index): import scipy.spatial import scipy @@ -269,50 +270,45 @@ def main_function(cls, traces, d): (nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32) - num_templates = d["num_templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - overlaps_array = d["overlaps"] - norms = d["norms"] + overlaps_array = self.overlaps + omp_tol = np.finfo(np.float32).eps - num_samples = d["nafter"] + d["nbefore"] + num_samples = self.nafter + self.nbefore neighbor_window = num_samples - 1 - if isinstance(d["amplitudes"], list): - min_amplitude, max_amplitude = d["amplitudes"] + if isinstance(self.amplitudes, list): + min_amplitude, max_amplitude = self.amplitudes else: - min_amplitude, max_amplitude = d["amplitudes"][:, 0], d["amplitudes"][:, 1] + min_amplitude, max_amplitude = self.amplitudes[:, 0], self.amplitudes[:, 1] min_amplitude = min_amplitude[:, np.newaxis] max_amplitude = max_amplitude[:, np.newaxis] - ignore_inds = d["ignore_inds"] - vicinity = d["vicinity"] num_timesteps = len(traces) num_peaks = num_timesteps - num_samples + 1 - conv_shape = (num_templates, num_peaks) + conv_shape = (self.num_templates, num_peaks) scalar_products = np.zeros(conv_shape, dtype=np.float32) # Filter using overlap-and-add convolution - if len(ignore_inds) > 0: - not_ignored = ~np.isin(np.arange(num_templates), ignore_inds) - spatially_filtered_data = np.matmul(d["spatial"][:, not_ignored, :], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"][:, not_ignored, :] + if len(self.ignore_inds) > 0: + not_ignored = ~np.isin(np.arange(self.num_templates), self.ignore_inds) + spatially_filtered_data = np.matmul(self.spatial[:, not_ignored, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * self.singular[:, not_ignored, :] objective_by_rank = scipy.signal.oaconvolve( - scaled_filtered_data, d["temporal"][:, not_ignored, :], axes=2, mode="valid" + scaled_filtered_data, self.temporal[:, not_ignored, :], axes=2, 
mode="valid" ) scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0) - scalar_products[ignore_inds] = -np.inf + scalar_products[self.ignore_inds] = -np.inf else: - spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"] - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") + spatially_filtered_data = np.matmul(self.spatial, traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * self.singular + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, self.temporal, axes=2, mode="valid") scalar_products += np.sum(objective_by_rank, axis=0) num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - M = np.zeros((num_templates, num_templates), dtype=np.float32) + M = np.zeros((self.num_templates, self.num_templates), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -325,13 +321,13 @@ def main_function(cls, traces, d): all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) - if d["stop_criteria"] == "omp_min_sps": - stop_criteria = d["omp_min_sps"] * np.maximum(d["norms"], np.sqrt(num_channels * num_samples)) - elif d["stop_criteria"] == "max_failures": + if self.stop_criteria == "omp_min_sps": + stop_criteria = self.omp_min_sps * np.maximum(self.norms, np.sqrt(self.num_channels * num_samples)) + elif self.stop_criteria == "max_failures": num_valids = 0 - nb_failures = d["max_failures"] - elif d["stop_criteria"] == "relative_error": - if len(ignore_inds) > 0: + nb_failures = self.max_failures + elif self.stop_criteria == "relative_error": + if len(self.ignore_inds) > 0: new_error = np.linalg.norm(scalar_products[not_ignored]) else: new_error = np.linalg.norm(scalar_products) @@ -350,8 +346,8 @@ def main_function(cls, traces, d): myindices = selection[0, idx] local_overlaps = overlaps_array[best_cluster_ind] - overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] - table = d["unit_overlaps_tables"][best_cluster_ind] + overlapping_templates = self.unit_overlaps_indices[best_cluster_ind] + table = self.unit_overlaps_tables[best_cluster_ind] if num_selection == M.shape[0]: Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) @@ -362,7 +358,7 @@ def main_function(cls, traces, d): a, b = myindices[mask], myline[mask] M[num_selection, idx[mask]] = local_overlaps[table[a], b] - if vicinity == 0: + if self.vicinity == 0: scipy.linalg.solve_triangular( M[:num_selection, :num_selection], M[num_selection, :num_selection], @@ -378,7 +374,7 @@ def main_function(cls, traces, d): break M[num_selection, num_selection] = np.sqrt(Lkk) else: - is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + is_in_vicinity = np.where(np.abs(delta_t) < self.vicinity)[0] if len(is_in_vicinity) > 0: L = M[is_in_vicinity, :][:, is_in_vicinity] @@ -403,15 +399,15 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - if vicinity == 0: + if self.vicinity == 0: all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - all_amplitudes /= norms[selection[0]] + all_amplitudes /= self.norms[selection[0]] else: is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) all_amplitudes = np.append(all_amplitudes, np.float32(1)) L = 
M[is_in_vicinity, :][:, is_in_vicinity] all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) - all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + all_amplitudes[is_in_vicinity] /= self.norms[selection[0][is_in_vicinity]] diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] @@ -419,10 +415,10 @@ def main_function(cls, traces, d): for i in modified: tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i] * norms[tmp_best] + diff_amp = diff_amplitudes[i] * self.norms[tmp_best] local_overlaps = overlaps_array[tmp_best] - overlapping_templates = d["units_overlaps"][tmp_best] + overlapping_templates = self.units_overlaps[tmp_best] if not tmp_peak in neighbors.keys(): idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] @@ -436,44 +432,47 @@ def main_function(cls, traces, d): scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add # We stop when updates do not modify the chosen spikes anymore - if d["stop_criteria"] == "omp_min_sps": + if self.stop_criteria == "omp_min_sps": is_valid = scalar_products > stop_criteria[:, np.newaxis] do_loop = np.any(is_valid) - elif d["stop_criteria"] == "max_failures": + elif self.stop_criteria == "max_failures": is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) new_num_valids = np.sum(is_valid) if (new_num_valids - num_valids) > 0: - nb_failures = d["max_failures"] + nb_failures = self.max_failures else: nb_failures -= 1 num_valids = new_num_valids do_loop = nb_failures > 0 - elif d["stop_criteria"] == "relative_error": + elif self.stop_criteria == "relative_error": previous_error = new_error - if len(ignore_inds) > 0: + if len(self.ignore_inds) > 0: new_error = np.linalg.norm(scalar_products[not_ignored]) else: new_error = np.linalg.norm(scalar_products) delta_error = np.abs(new_error / previous_error - 1) - do_loop = delta_error > d["relative_error"] + do_loop = delta_error > self.relative_error is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) num_spikes = len(valid_indices[0]) - spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["sample_index"][:num_spikes] = valid_indices[1] + self.nbefore spikes["channel_index"][:num_spikes] = 0 spikes["cluster_index"][:num_spikes] = valid_indices[0] spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + print("yep0", spikes.size, num_spikes, spikes.shape, spikes.dtype) spikes = spikes[:num_spikes] - order = np.argsort(spikes["sample_index"]) - spikes = spikes[order] + print("yep1", spikes.size, spikes.shape, spikes.dtype) + if spikes.size > 0: + order = np.argsort(spikes["sample_index"]) + spikes = spikes[order] return spikes -class CircusPeeler(BaseTemplateMatchingEngine): +class CircusPeeler(BaseTemplateMatching): """ Greedy Template-matching ported from the Spyking Circus sorter @@ -519,115 +518,25 @@ class CircusPeeler(BaseTemplateMatchingEngine): """ - _default_params = { - "peak_sign": "neg", - "exclude_sweep_ms": 0.1, - "jitter_ms": 0.1, - "detect_threshold": 5, - "noise_levels": None, - "random_chunk_kwargs": {}, - "max_amplitude": 1.5, - "min_amplitude": 0.5, - "use_sparse_matrix_threshold": 0.25, - "templates": None, - } - - @classmethod - def _prepare_templates(cls, d): - import scipy.spatial - import scipy - - templates = d["templates"] - 
num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] - - d["norms"] = np.zeros(num_templates, dtype=np.float32) - - all_units = d["templates"].unit_ids - - sparsity = templates.sparsity.mask + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + exclude_sweep_ms=0.1, + jitter_ms=0.1, + detect_threshold=5, + noise_levels=None, + random_chunk_kwargs={}, + max_amplitude=1.5, + min_amplitude=0.5, + use_sparse_matrix_threshold=0.25, + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - templates_array = templates.get_dense_templates() - d["sparsities"] = {} - d["normed_templates"] = {} - - for count, unit_id in enumerate(all_units): - (d["sparsities"][count],) = np.nonzero(sparsity[count]) - d["norms"][count] = np.linalg.norm(templates_array[count]) - templates_array[count] /= d["norms"][count] - d["normed_templates"][count] = templates_array[count][:, sparsity[count]] - - templates_array = templates_array.reshape(num_templates, -1) - - nnz = np.sum(templates_array != 0) / (num_templates * num_samples * num_channels) - if nnz <= use_sparse_matrix_threshold: - templates_array = scipy.sparse.csr_matrix(templates_array) - print(f"Templates are automatically sparsified (sparsity level is {nnz})") - d["is_dense"] = False - else: - d["is_dense"] = True - - d["circus_templates"] = templates_array - - return d - - # @classmethod - # def _mcc_error(cls, bounds, good, bad): - # fn = np.sum((good < bounds[0]) | (good > bounds[1])) - # fp = np.sum((bounds[0] <= bad) & (bad <= bounds[1])) - # tp = np.sum((bounds[0] <= good) & (good <= bounds[1])) - # tn = np.sum((bad < bounds[0]) | (bad > bounds[1])) - # denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) - # if denom > 0: - # mcc = 1 - (tp * tn - fp * fn) / np.sqrt(denom) - # else: - # mcc = 1 - # return mcc - - # @classmethod - # def _cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha): - # # We want a minimal error, with the larger bounds that are possible - # cost = alpha * cls._mcc_error(bounds, good, bad) + (1 - alpha) * np.abs( - # (1 - (bounds[1] - bounds[0]) / delta_amplitude) - # ) - # return cost - - # @classmethod - # def _optimize_amplitudes(cls, noise_snippets, d): - # parameters = d - # waveform_extractor = parameters["waveform_extractor"] - # templates = parameters["templates"] - # num_templates = parameters["num_templates"] - # max_amplitude = parameters["max_amplitude"] - # min_amplitude = parameters["min_amplitude"] - # alpha = 0.5 - # norms = parameters["norms"] - # all_units = list(waveform_extractor.sorting.unit_ids) - - # parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) - # noise = templates.dot(noise_snippets) / norms[:, np.newaxis] - - # all_amps = {} - # for count, unit_id in enumerate(all_units): - # waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) - # snippets = waveform.reshape(waveform.shape[0], -1).T - # amps = templates.dot(snippets) / norms[:, np.newaxis] - # good = amps[count, :].flatten() - - # sub_amps = amps[np.concatenate((np.arange(count), np.arange(count + 1, num_templates))), :] - # bad = sub_amps[sub_amps >= good] - # bad = np.concatenate((bad, noise[count])) - # cost_kwargs = [good, bad, max_amplitude - min_amplitude, alpha] - # cost_bounds = [(min_amplitude, 1), (1, max_amplitude)] - # res = 
scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) - # parameters["amplitudes"][count] = res.x - - # return d - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): try: from sklearn.feature_extraction.image import extract_patches_2d @@ -636,108 +545,93 @@ def initialize_and_check_kwargs(cls, recording, kwargs): HAVE_SKLEARN = False assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work" - d = cls._default_params.copy() - d.update(kwargs) - # assert isinstance(d['waveform_extractor'], WaveformExtractor) - for v in ["use_sparse_matrix_threshold"]: - assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + assert (use_sparse_matrix_threshold >= 0) and ( + use_sparse_matrix_threshold <= 1 + ), f"use_sparse_matrix_threshold should be in [0, 1]" - d["num_channels"] = recording.get_num_channels() - d["num_samples"] = d["templates"].num_samples - d["num_templates"] = len(d["templates"].unit_ids) + self.num_channels = recording.get_num_channels() + self.num_samples = templates.num_samples + self.num_templates = len(templates.unit_ids) - if d["noise_levels"] is None: + if noise_levels is None: print("CircusPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + noise_levels = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] + self.abs_threholds = noise_levels * detect_threshold - if "overlaps" not in d: - d = cls._prepare_templates(d) - d["overlaps"] = compute_overlaps( - d["normed_templates"], - d["num_samples"], - d["num_channels"], - d["sparsities"], - ) + self.use_sparse_matrix_threshold = use_sparse_matrix_threshold + self._prepare_templates() + self.overlaps = compute_overlaps( + self.normed_templates, + self.num_samples, + self.num_channels, + self.sparsities, + ) + + self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) + + self.nbefore = templates.nbefore + self.nafter = templates.nafter + self.patch_sizes = (templates.num_samples, self.num_channels) + self.sym_patch = self.nbefore == self.nafter + self.jitter = int(jitter_ms * recording.get_sampling_frequency() / 1000.0) + + self.amplitudes = np.zeros((self.num_templates, 2), dtype=np.float32) + self.amplitudes[:, 0] = min_amplitude + self.amplitudes[:, 1] = max_amplitude + + self.margin = max(self.nbefore, self.nafter) * 2 + self.peak_sign = peak_sign + + def _prepare_templates(self): + import scipy.spatial + import scipy + + self.norms = np.zeros(self.num_templates, dtype=np.float32) + + all_units = self.templates.unit_ids + + sparsity = self.templates.sparsity.mask + + templates_array = self.templates.get_dense_templates() + self.sparsities = {} + self.normed_templates = {} + + for count, unit_id in enumerate(all_units): + self.sparsities[count] = np.flatnonzero(sparsity[count]) + self.norms[count] = np.linalg.norm(templates_array[count]) + templates_array[count] /= self.norms[count] + self.normed_templates[count] = templates_array[count][:, sparsity[count]] + + templates_array = templates_array.reshape(self.num_templates, -1) + + nnz = np.sum(templates_array != 0) / (self.num_templates * self.num_samples * self.num_channels) + if nnz <= self.use_sparse_matrix_threshold: + templates_array = scipy.sparse.csr_matrix(templates_array) + print(f"Templates are automatically sparsified (sparsity level is {nnz})") + self.is_dense = False else: - 
for key in ["circus_templates", "norms"]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key + self.is_dense = True - d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) + self.circus_templates = templates_array - d["nbefore"] = d["templates"].nbefore - d["nafter"] = d["templates"].nafter - d["patch_sizes"] = ( - d["templates"].num_samples, - d["num_channels"], - ) - d["sym_patch"] = d["nbefore"] == d["nafter"] - d["jitter"] = int(d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0) - - d["amplitudes"] = np.zeros((d["num_templates"], 2), dtype=np.float32) - d["amplitudes"][:, 0] = d["min_amplitude"] - d["amplitudes"][:, 1] = d["max_amplitude"] - # num_segments = recording.get_num_segments() - # if d["waveform_extractor"]._params["max_spikes_per_unit"] is None: - # num_snippets = 1000 - # else: - # num_snippets = 2 * d["waveform_extractor"]._params["max_spikes_per_unit"] - - # num_chunks = num_snippets // num_segments - # noise_snippets = get_random_data_chunks( - # recording, num_chunks_per_segment=num_chunks, chunk_size=d["num_samples"], seed=42 - # ) - # noise_snippets = ( - # noise_snippets.reshape(num_chunks, d["num_samples"], d["num_channels"]) - # .reshape(num_chunks, -1) - # .T - # ) - # parameters = cls._optimize_amplitudes(noise_snippets, d) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) - return margin - - @classmethod - def main_function(cls, traces, d): - peak_sign = d["peak_sign"] - abs_threholds = d["abs_threholds"] - exclude_sweep_size = d["exclude_sweep_size"] - templates = d["circus_templates"] - num_templates = d["num_templates"] - overlaps = d["overlaps"] - margin = d["margin"] - norms = d["norms"] - jitter = d["jitter"] - patch_sizes = d["patch_sizes"] - num_samples = d["nafter"] + d["nbefore"] - neighbor_window = num_samples - 1 - amplitudes = d["amplitudes"] - sym_patch = d["sym_patch"] + def get_trace_margin(self): + return self.margin + + def compute_matching(self, traces, start_frame, end_frame, segment_index): + + neighbor_window = self.num_samples - 1 - peak_traces = traces[margin // 2 : -margin // 2, :] + peak_traces = traces[self.margin // 2 : -self.margin // 2, :] peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( - peak_traces, peak_sign, abs_threholds, exclude_sweep_size + peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size ) from sklearn.feature_extraction.image import extract_patches_2d - if jitter > 0: - jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) - jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * jitter) + if self.jitter > 0: + jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-self.jitter, self.jitter) + jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * self.jitter) mask = (jittered_peaks > 0) & (jittered_peaks < len(peak_traces)) jittered_peaks = jittered_peaks[mask] jittered_channels = jittered_channels[mask] @@ -749,26 +643,26 @@ def main_function(cls, traces, d): num_peaks = len(peak_sample_index) - if sym_patch: - snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_index] - peak_sample_index += margin // 2 + if self.sym_patch: + snippets = extract_patches_2d(traces, 
self.patch_sizes)[peak_sample_index] + peak_sample_index += self.margin // 2 else: - peak_sample_index += margin // 2 - snippet_window = np.arange(-d["nbefore"], d["nafter"]) + peak_sample_index += self.margin // 2 + snippet_window = np.arange(-self.nbefore, self.nafter) snippets = traces[peak_sample_index[:, np.newaxis] + snippet_window] if num_peaks > 0: snippets = snippets.reshape(num_peaks, -1) - scalar_products = templates.dot(snippets.T) + scalar_products = self.circus_templates.dot(snippets.T) else: - scalar_products = np.zeros((num_templates, 0), dtype=np.float32) + scalar_products = np.zeros((self.num_templates, 0), dtype=np.float32) num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + idx_lookup = np.arange(scalar_products.size).reshape(self.num_templates, -1) - min_sps = (amplitudes[:, 0] * norms)[:, np.newaxis] - max_sps = (amplitudes[:, 1] * norms)[:, np.newaxis] + min_sps = (self.amplitudes[:, 0] * self.norms)[:, np.newaxis] + max_sps = (self.amplitudes[:, 1] * self.norms)[:, np.newaxis] is_valid = (scalar_products > min_sps) & (scalar_products < max_sps) @@ -787,7 +681,7 @@ def main_function(cls, traces, d): idx_neighbor = peak_data[is_valid_nn[0] : is_valid_nn[1]] + neighbor_window if not best_cluster_ind in cached_overlaps.keys(): - cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() + cached_overlaps[best_cluster_ind] = self.overlaps[best_cluster_ind].toarray() to_add = -best_amplitude * cached_overlaps[best_cluster_ind][:, idx_neighbor] @@ -802,7 +696,7 @@ def main_function(cls, traces, d): is_valid = (scalar_products > min_sps) & (scalar_products < max_sps) - spikes["amplitude"][:num_spikes] /= norms[spikes["cluster_index"][:num_spikes]] + spikes["amplitude"][:num_spikes] /= self.norms[spikes["cluster_index"][:num_spikes]] spikes = spikes[:num_spikes] order = np.argsort(spikes["sample_index"]) diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 6e5267cb70..f423d55e2a 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -3,8 +3,11 @@ from threadpoolctl import threadpool_limits import numpy as np -from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs -from spikeinterface.core import get_chunk_with_margin +# from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs +# from spikeinterface.core import get_chunk_with_margin + +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.node_pipeline import run_node_pipeline def find_spikes_from_templates( @@ -21,7 +24,7 @@ def find_spikes_from_templates( method_kwargs : dict, optional Keyword arguments for the chosen method extra_outputs : bool - If True then method_kwargs is also returned + If True then a dict is also returned is also returned **job_kwargs : dict Parameters for ChunkRecordingExecutor verbose : Bool, default: False @@ -31,9 +34,8 @@ def find_spikes_from_templates( ------- spikes : ndarray Spikes found from templates. - method_kwargs: + outputs: Optionaly returns for debug purpose. 
- """ from .method_list import matching_methods @@ -42,117 +44,19 @@ def find_spikes_from_templates( job_kwargs = fix_job_kwargs(job_kwargs) method_class = matching_methods[method] + node0 = method_class(recording, **method_kwargs) + nodes = [node0] - # initialize - method_kwargs = method_class.initialize_and_check_kwargs(recording, method_kwargs) - - # add - method_kwargs["margin"] = method_class.get_margin(recording, method_kwargs) - - # serialiaze for worker - method_kwargs_seralized = method_class.serialize_method_kwargs(method_kwargs) - - # and run - func = _find_spikes_chunk - init_func = _init_worker_find_spikes - init_args = (recording, method, method_kwargs_seralized) - processor = ChunkRecordingExecutor( + spikes = run_node_pipeline( recording, - func, - init_func, - init_args, - handle_returns=True, + nodes, + job_kwargs, job_name=f"find spikes ({method})", - verbose=verbose, - **job_kwargs, + gather_mode="memory", + squeeze_output=True, ) - spikes = processor.run() - - spikes = np.concatenate(spikes) - if extra_outputs: - return spikes, method_kwargs + outputs = node0.get_extra_outputs() + return spikes, outputs else: return spikes - - -def _init_worker_find_spikes(recording, method, method_kwargs): - """Initialize worker for finding spikes.""" - - from .method_list import matching_methods - - method_class = matching_methods[method] - method_kwargs = method_class.unserialize_in_worker(method_kwargs) - - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["method"] = method - worker_ctx["method_kwargs"] = method_kwargs - worker_ctx["function"] = method_class.main_function - - return worker_ctx - - -def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx): - """Find spikes from a chunk of data.""" - - # recover variables of the worker - recording = worker_ctx["recording"] - method = worker_ctx["method"] - method_kwargs = worker_ctx["method_kwargs"] - margin = method_kwargs["margin"] - - # load trace in memory given some margin - recording_segment = recording._recording_segments[segment_index] - traces, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, margin, add_zeros=True - ) - - function = worker_ctx["function"] - - with threadpool_limits(limits=1): - spikes = function(traces, method_kwargs) - - # remove spikes in margin - if margin > 0: - keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin)) - spikes = spikes[keep] - - spikes["sample_index"] += start_frame - margin - spikes["segment_index"] = segment_index - return spikes - - -# generic class for template engine -class BaseTemplateMatchingEngine: - default_params = {} - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - """This function runs before loops""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def serialize_method_kwargs(cls, kwargs): - """This function serializes kwargs to distribute them to workers""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def unserialize_in_worker(cls, recording, kwargs): - """This function unserializes kwargs in workers""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def get_margin(cls, recording, kwargs): - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def main_function(cls, traces, method_kwargs): - """This function returns the number of samples 
for the chunk margins""" - # need to be implemented in subclass - raise NotImplementedError diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 0dc71d789b..26f093c187 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,115 +4,68 @@ import numpy as np -from spikeinterface.core import get_noise_levels, get_channel_distances, get_random_data_chunks +from spikeinterface.core import get_noise_levels, get_channel_distances from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive -from spikeinterface.core.template import Templates - -spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] - - -from .main import BaseTemplateMatchingEngine - - -class NaiveMatching(BaseTemplateMatchingEngine): - """ - This is a naive template matching that does not resolve collision - and does not take in account sparsity. - It just minimizes the distance to templates for detected peaks. - - It is implemented for benchmarking against this low quality template matching. - And also as an example how to deal with methods_kwargs, margin, intit, func, ... - """ - - default_params = { - "templates": None, - "peak_sign": "neg", - "exclude_sweep_ms": 0.1, - "detect_threshold": 5, - "noise_levels": None, - "radius_um": 100, - "random_chunk_kwargs": {}, - } - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls.default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) - - templates = d["templates"] - if d["noise_levels"] is None: - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] - - channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["radius_um"] +from .base import BaseTemplateMatching, _base_matching_dtype - d["nbefore"] = templates.nbefore - d["nafter"] = templates.nafter - d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) +class NaiveMatching(BaseTemplateMatching): + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + exclude_sweep_ms=0.1, + detect_threshold=5, + noise_levels=None, + radius_um=100.0, + random_chunk_kwargs={}, + ): - return d + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - @classmethod - def get_margin(cls, recording, kwargs): - margin = max(kwargs["nbefore"], kwargs["nafter"]) - return margin + self.templates_array = self.templates.get_dense_templates() - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def main_function(cls, traces, method_kwargs): - peak_sign = method_kwargs["peak_sign"] - abs_threholds = method_kwargs["abs_threholds"] - exclude_sweep_size = method_kwargs["exclude_sweep_size"] - neighbours_mask = method_kwargs["neighbours_mask"] - templates_array = method_kwargs["templates"].get_dense_templates() + if noise_levels is None: + noise_levels = get_noise_levels(recording, 
**random_chunk_kwargs, return_scaled=False) + self.abs_threholds = noise_levels * detect_threshold + self.peak_sign = peak_sign + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) + self.nbefore = self.templates.nbefore + self.nafter = self.templates.nafter + self.margin = max(self.nbefore, self.nafter) - nbefore = method_kwargs["nbefore"] - nafter = method_kwargs["nafter"] + def get_trace_margin(self): + return self.margin - margin = method_kwargs["margin"] + def compute_matching(self, traces, start_frame, end_frame, segment_index): - if margin > 0: - peak_traces = traces[margin:-margin, :] + if self.margin > 0: + peak_traces = traces[self.margin : -self.margin, :] else: peak_traces = traces peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, peak_sign, abs_threholds, exclude_sweep_size, neighbours_mask + peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += margin + peak_sample_ind += self.margin - spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) + spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind - spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template + spikes["channel_index"] = peak_chan_ind # naively take the closest template for i in range(peak_sample_ind.size): - i0 = peak_sample_ind[i] - nbefore - i1 = peak_sample_ind[i] + nafter + i0 = peak_sample_ind[i] - self.nbefore + i1 = peak_sample_ind[i] + self.nafter waveforms = traces[i0:i1, :] - dist = np.sum(np.sum((templates_array - waveforms[None, :, :]) ** 2, axis=1), axis=1) + dist = np.sum(np.sum((self.templates_array - waveforms[None, :, :]) ** 2, axis=1), axis=1) cluster_index = np.argmin(dist) spikes["cluster_index"][i] = cluster_index diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index e66929e2b1..56457fe2fa 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -11,15 +11,8 @@ from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.core.template import Templates -spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] +from .base import BaseTemplateMatching, _base_matching_dtype -from .main import BaseTemplateMatchingEngine try: import numba @@ -30,7 +23,7 @@ HAVE_NUMBA = False -class TridesclousPeeler(BaseTemplateMatchingEngine): +class TridesclousPeeler(BaseTemplateMatching): """ Template-matching ported from Tridesclous sorter. @@ -45,87 +38,73 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): spike collision when templates have high similarity. 
""" - default_params = { - "templates": None, - "peak_sign": "neg", - "peak_shift_ms": 0.2, - "detect_threshold": 5, - "noise_levels": None, - "radius_um": 100, - "num_closest": 5, - "sample_shift": 3, - "ms_before": 0.8, - "ms_after": 1.2, - "num_peeler_loop": 2, - "num_template_try": 1, - } - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - assert HAVE_NUMBA, "TridesclousPeeler needs numba to be installed" - - d = cls.default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + peak_shift_ms=0.2, + detect_threshold=5, + noise_levels=None, + radius_um=100.0, + num_closest=5, + sample_shift=3, + ms_before=0.8, + ms_after=1.2, + num_peeler_loop=2, + num_template_try=1, + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) + + # maybe in base? + self.templates_array = templates.get_dense_templates() - templates = d["templates"] unit_ids = templates.unit_ids - channel_ids = templates.channel_ids + channel_ids = recording.channel_ids + + sr = recording.sampling_frequency - sr = templates.sampling_frequency + self.nbefore = templates.nbefore + self.nafter = templates.nafter - d["nbefore"] = templates.nbefore - d["nafter"] = templates.nafter - templates_array = templates.get_dense_templates() + self.peak_sign = peak_sign - nbefore_short = int(d["ms_before"] * sr / 1000.0) - nafter_short = int(d["ms_before"] * sr / 1000.0) + nbefore_short = int(ms_before * sr / 1000.0) + nafter_short = int(ms_after * sr / 1000.0) assert nbefore_short <= templates.nbefore assert nafter_short <= templates.nafter - d["nbefore_short"] = nbefore_short - d["nafter_short"] = nafter_short + self.nbefore_short = nbefore_short + self.nafter_short = nafter_short s0 = templates.nbefore - nbefore_short s1 = -(templates.nafter - nafter_short) if s1 == 0: s1 = None - templates_short = templates_array[:, slice(s0, s1), :].copy() - d["templates_short"] = templates_short + # TODO check with out copy + self.templates_short = self.templates_array[:, slice(s0, s1), :].copy() - d["peak_shift"] = int(d["peak_shift_ms"] / 1000 * sr) + self.peak_shift = int(peak_shift_ms / 1000 * sr) - if d["noise_levels"] is None: - print("TridesclousPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording) + assert noise_levels is not None, "TridesclousPeeler : noise should be computed outside" - d["abs_thresholds"] = d["noise_levels"] * d["detect_threshold"] + self.abs_thresholds = noise_levels * detect_threshold channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["radius_um"] - - sparsity = compute_sparsity( - templates, method="best_channels" - ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) - template_sparsity_inds = sparsity.unit_id_to_channel_indices - template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - for unit_index, unit_id in enumerate(unit_ids): - chan_inds = template_sparsity_inds[unit_id] - template_sparsity[unit_index, chan_inds] = True + self.neighbours_mask = channel_distance < radius_um - d["template_sparsity"] = template_sparsity + if templates.sparsity is not None: + self.template_sparsity = templates.sparsity.mask + else: + self.template_sparsity = np.ones((unit_ids.size, channel_ids.size), 
dtype=bool) - extremum_channel = get_template_extremum_channel(templates, peak_sign=d["peak_sign"], outputs="index") + extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") # as numpy vector - extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") - d["extremum_channel"] = extremum_channel + self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") channel_locations = templates.probe.contact_positions - - # TODO try it with real locaion - unit_locations = channel_locations[extremum_channel] - # ~ print(unit_locations) + unit_locations = channel_locations[self.extremum_channel] # distance between units import scipy @@ -138,15 +117,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): order = np.argsort(unit_distances[unit_ind, :]) closest_u = np.arange(unit_ids.size)[order].tolist() closest_u.remove(unit_ind) - closest_u = np.array(closest_u[: d["num_closest"]]) + closest_u = np.array(closest_u[:num_closest]) # compute unitary discriminent vector - (chans,) = np.nonzero(d["template_sparsity"][unit_ind, :]) - template_sparse = templates_array[unit_ind, :, :][:, chans] + (chans,) = np.nonzero(self.template_sparsity[unit_ind, :]) + template_sparse = self.templates_array[unit_ind, :, :][:, chans] closest_vec = [] # against N closets for u in closest_u: - vec = templates_array[u, :, :][:, chans] - template_sparse + vec = self.templates_array[u, :, :][:, chans] - template_sparse vec /= np.sum(vec**2) closest_vec.append((u, vec)) # against noise @@ -154,47 +133,38 @@ def initialize_and_check_kwargs(cls, recording, kwargs): closest_units.append(closest_vec) - d["closest_units"] = closest_units + self.closest_units = closest_units # distance channel from unit import scipy distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") - near_cluster_mask = distances < d["radius_um"] + near_cluster_mask = distances < radius_um # nearby cluster for each channel - possible_clusters_by_channel = [] + self.possible_clusters_by_channel = [] for channel_index in range(distances.shape[0]): (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) - possible_clusters_by_channel.append(cluster_inds) + self.possible_clusters_by_channel.append(cluster_inds) - d["possible_clusters_by_channel"] = possible_clusters_by_channel - d["possible_shifts"] = np.arange(-d["sample_shift"], d["sample_shift"] + 1, dtype="int64") + self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") - return d + self.num_peeler_loop = num_peeler_loop + self.num_template_try = num_template_try - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs + self.margin = max(self.nbefore, self.nafter) * 2 - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs + def get_trace_margin(self): + return self.margin - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * (kwargs["nbefore"] + kwargs["nafter"]) - return margin - - @classmethod - def main_function(cls, traces, d): + def compute_matching(self, traces, start_frame, end_frame, segment_index): traces = traces.copy() all_spikes = [] level = 0 while True: - spikes = _tdc_find_spikes(traces, d, level=level) + # spikes = _tdc_find_spikes(traces, d, level=level) + spikes = self._find_spikes_one_level(traces, level=level) keep = spikes["cluster_index"] >= 0 if not np.any(keep): @@ -203,7 +173,7 @@ def main_function(cls, traces, d): 
level += 1 - if level == d["num_peeler_loop"]: + if level == self.num_peeler_loop: break if len(all_spikes) > 0: @@ -211,139 +181,131 @@ def main_function(cls, traces, d): order = np.argsort(all_spikes["sample_index"]) all_spikes = all_spikes[order] else: - all_spikes = np.zeros(0, dtype=spike_dtype) + all_spikes = np.zeros(0, dtype=_base_matching_dtype) return all_spikes + def _find_spikes_one_level(self, traces, level=0): -def _tdc_find_spikes(traces, d, level=0): - peak_sign = d["peak_sign"] - templates = d["templates"] - templates_short = d["templates_short"] - templates_array = templates.get_dense_templates() - - margin = d["margin"] - possible_clusters_by_channel = d["possible_clusters_by_channel"] - - peak_traces = traces[margin // 2 : -margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, peak_sign, d["abs_thresholds"], d["peak_shift"], d["neighbours_mask"] - ) - peak_sample_ind += margin // 2 - - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] - order = np.argsort(np.abs(peak_amplitude))[::-1] - peak_sample_ind = peak_sample_ind[order] - peak_chan_ind = peak_chan_ind[order] - - spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) - spikes["sample_index"] = peak_sample_ind - spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template - - possible_shifts = d["possible_shifts"] - distances_shift = np.zeros(possible_shifts.size) - - for i in range(peak_sample_ind.size): - sample_index = peak_sample_ind[i] - - chan_ind = peak_chan_ind[i] - possible_clusters = possible_clusters_by_channel[chan_ind] - - if possible_clusters.size > 0: - # ~ s0 = sample_index - d['nbefore'] - # ~ s1 = sample_index + d['nafter'] - - # ~ wf = traces[s0:s1, :] - - s0 = sample_index - d["nbefore_short"] - s1 = sample_index + d["nafter_short"] - wf_short = traces[s0:s1, :] - - ## pure numpy with cluster spasity - # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) - - ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) - # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) - - ## numba with cluster+channel spasity - union_channels = np.any(d["template_sparsity"][possible_clusters, :], axis=0) - # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) - distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) - - # DEBUG - # ~ ind = np.argmin(distances) - # ~ cluster_index = possible_clusters[ind] - - for ind in np.argsort(distances)[: d["num_template_try"]]: - cluster_index = possible_clusters[ind] - - chan_sparsity = d["template_sparsity"][cluster_index, :] - template_sparse = templates_array[cluster_index, :, :][:, chan_sparsity] - - # find best shift - - ## pure numpy version - # for s, shift in enumerate(possible_shifts): - # wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity] - # distances_shift[s] = np.sum((template_sparse - wf_shift)**2) - # ind_shift = np.argmin(distances_shift) - # shift = possible_shifts[ind_shift] - - ## numba version - numba_best_shift( - traces, - templates_array[cluster_index, :, :], - sample_index, - d["nbefore"], - possible_shifts, - distances_shift, - chan_sparsity, - ) - ind_shift = np.argmin(distances_shift) - shift = possible_shifts[ind_shift] - - sample_index = sample_index + shift - s0 = sample_index 
- d["nbefore"] - s1 = sample_index + d["nafter"] - wf_sparse = traces[s0:s1, chan_sparsity] - - # accept or not - - centered = wf_sparse - template_sparse - accepted = True - for other_ind, other_vector in d["closest_units"][cluster_index]: - v = np.sum(centered * other_vector) - if np.abs(v) > 0.5: - accepted = False + peak_traces = traces[self.margin // 2 : -self.margin // 2, :] + peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( + peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask + ) + peak_sample_ind += self.margin // 2 + + peak_amplitude = traces[peak_sample_ind, peak_chan_ind] + order = np.argsort(np.abs(peak_amplitude))[::-1] + peak_sample_ind = peak_sample_ind[order] + peak_chan_ind = peak_chan_ind[order] + + spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) + spikes["sample_index"] = peak_sample_ind + spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template + + possible_shifts = self.possible_shifts + distances_shift = np.zeros(possible_shifts.size) + + for i in range(peak_sample_ind.size): + sample_index = peak_sample_ind[i] + + chan_ind = peak_chan_ind[i] + possible_clusters = self.possible_clusters_by_channel[chan_ind] + + if possible_clusters.size > 0: + # ~ s0 = sample_index - d['nbefore'] + # ~ s1 = sample_index + d['nafter'] + + # ~ wf = traces[s0:s1, :] + + s0 = sample_index - self.nbefore_short + s1 = sample_index + self.nafter_short + wf_short = traces[s0:s1, :] + + ## pure numpy with cluster spasity + # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) + + ## pure numpy with cluster+channel spasity + # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) + # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) + + ## numba with cluster+channel spasity + union_channels = np.any(self.template_sparsity[possible_clusters, :], axis=0) + # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) + distances = numba_sparse_dist(wf_short, self.templates_short, union_channels, possible_clusters) + + # DEBUG + # ~ ind = np.argmin(distances) + # ~ cluster_index = possible_clusters[ind] + + for ind in np.argsort(distances)[: self.num_template_try]: + cluster_index = possible_clusters[ind] + + chan_sparsity = self.template_sparsity[cluster_index, :] + template_sparse = self.templates_array[cluster_index, :, :][:, chan_sparsity] + + # find best shift + + ## pure numpy version + # for s, shift in enumerate(possible_shifts): + # wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity] + # distances_shift[s] = np.sum((template_sparse - wf_shift)**2) + # ind_shift = np.argmin(distances_shift) + # shift = possible_shifts[ind_shift] + + ## numba version + numba_best_shift( + traces, + self.templates_array[cluster_index, :, :], + sample_index, + self.nbefore, + possible_shifts, + distances_shift, + chan_sparsity, + ) + ind_shift = np.argmin(distances_shift) + shift = possible_shifts[ind_shift] + + sample_index = sample_index + shift + s0 = sample_index - self.nbefore + s1 = sample_index + self.nafter + wf_sparse = traces[s0:s1, chan_sparsity] + + # accept or not + + centered = wf_sparse - template_sparse + accepted = True + for other_ind, other_vector in self.closest_units[cluster_index]: + v = np.sum(centered * other_vector) + if np.abs(v) > 0.5: + accepted = False + break + + if accepted: + 
                        # ~ if ind != np.argsort(distances)[0]:
+                        # ~ print('not first one', np.argsort(distances), ind)
                        break

                if accepted:
-                    # ~ if ind != np.argsort(distances)[0]:
-                    # ~ print('not first one', np.argsort(distances), ind)
-                    break
+                    amplitude = 1.0

-            if accepted:
-                amplitude = 1.0
+                    # remove template
+                    template = self.templates_array[cluster_index, :, :]
+                    s0 = sample_index - self.nbefore
+                    s1 = sample_index + self.nafter
+                    traces[s0:s1, :] -= template * amplitude

-                # remove template
-                template = templates_array[cluster_index, :, :]
-                s0 = sample_index - d["nbefore"]
-                s1 = sample_index + d["nafter"]
-                traces[s0:s1, :] -= template * amplitude
+                else:
+                    cluster_index = -1
+                    amplitude = 0.0

            else:
                cluster_index = -1
                amplitude = 0.0

-            else:
-                cluster_index = -1
-                amplitude = 0.0
-
-        spikes["cluster_index"][i] = cluster_index
-        spikes["amplitude"][i] = amplitude
+            spikes["cluster_index"][i] = cluster_index
+            spikes["amplitude"][i] = amplitude

-    return spikes
+        return spikes


 if HAVE_NUMBA:

diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py
index 99de6fcd4e..2531a922da 100644
--- a/src/spikeinterface/sortingcomponents/matching/wobble.py
+++ b/src/spikeinterface/sortingcomponents/matching/wobble.py
@@ -4,7 +4,8 @@
 from dataclasses import dataclass
 from typing import List, Tuple, Optional

-from .main import BaseTemplateMatchingEngine
+
+from .base import BaseTemplateMatching, _base_matching_dtype

 from spikeinterface.core.template import Templates

@@ -197,8 +198,9 @@ def from_parameters_and_templates(cls, params, templates):
         return template_meta


+# important : this is different from the spikeinterface.core.Sparsity
 @dataclass
-class Sparsity:
+class WobbleSparsity:
     """Variables that describe channel sparsity.

     Parameters
@@ -226,7 +228,7 @@ def from_parameters_and_templates(cls, params, templates):

         Returns
         -------
-        sparsity : Sparsity
+        sparsity : WobbleSparsity
             Dataclass object for aggregating channel sparsity variables together.
         """
         visible_channels = np.ptp(templates, axis=1) > params.visibility_threshold
@@ -250,7 +252,7 @@ def from_templates(cls, params, templates):

         Returns
         -------
-        sparsity : Sparsity
+        sparsity : WobbleSparsity
             Dataclass object for aggregating channel sparsity variables together.
         """
         visible_channels = templates.sparsity.mask
@@ -297,7 +299,7 @@ def __post_init__(self):
         self.temporal, self.singular, self.spatial, self.temporal_jittered = self.compressed_templates


-class WobbleMatch(BaseTemplateMatchingEngine):
+class WobbleMatch(BaseTemplateMatching):
     """Template matching method from the Paninski lab.

     Templates are jittered or "wobbled" in time and amplitude to capture variability in spike amplitude and
@@ -331,53 +333,30 @@ class WobbleMatch(BaseTemplateMatchingEngine):
     - "peaks" are considered spikes if their amplitude clears the threshold parameter
     """

-    default_params = {
-        "templates": None,
-    }
-    spike_dtype = [
-        ("sample_index", "int64"),
-        ("channel_index", "int64"),
-        ("cluster_index", "int64"),
-        ("amplitude", "float64"),
-        ("segment_index", "int64"),
-    ]
-
-    @classmethod
-    def initialize_and_check_kwargs(cls, recording, kwargs):
-        """Initialize the objective and precompute various useful objects.
+    # default_params = {
+    #     "templates": None,
+    # }

-        Parameters
-        ----------
-        recording : RecordingExtractor
-            The recording extractor object.
-        kwargs : dict
-            Keyword arguments for matching method.
-
-        Returns
-        -------
-        d : dict
-            Updated Keyword arguments.
-        """
-        d = cls.default_params.copy()
+    def __init__(
+        self,
+        recording,
+        return_output=True,
+        parents=None,
+        templates=None,
+        parameters={},
+    ):

-        required_kwargs_keys = ["templates"]
-        for required_key in required_kwargs_keys:
-            assert required_key in kwargs, f"`{required_key}` is a required key in the kwargs"
+        BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output, parents=parents)

-        parameters = kwargs.get("parameters", {})
-        templates = kwargs["templates"]
-        assert isinstance(templates, Templates), (
-            f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates"
-        )
         templates_array = templates.get_dense_templates().astype(np.float32, casting="safe")

         # Aggregate useful parameters/variables for handy access in downstream functions
         params = WobbleParameters(**parameters)
         template_meta = TemplateMetadata.from_parameters_and_templates(params, templates_array)

         if not templates.are_templates_sparse():
-            sparsity = Sparsity.from_parameters_and_templates(params, templates_array)
+            sparsity = WobbleSparsity.from_parameters_and_templates(params, templates_array)
         else:
-            sparsity = Sparsity.from_templates(params, templates)
+            sparsity = WobbleSparsity.from_templates(params, templates)

         # Perform initial computations on templates necessary for computing the objective
         sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates_array, 0)
@@ -394,84 +373,47 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
             norm_squared=norm_squared,
         )

-        # Pack initial data into kwargs
-        kwargs["params"] = params
-        kwargs["template_meta"] = template_meta
-        kwargs["sparsity"] = sparsity
-        kwargs["template_data"] = template_data
-        kwargs["nbefore"] = templates.nbefore
-        kwargs["nafter"] = templates.nafter
-        d.update(kwargs)
-        return d
-
-    @classmethod
-    def serialize_method_kwargs(cls, kwargs):
-        # This function does nothing without a waveform extractor -- candidate for refactor
-        kwargs = dict(kwargs)
-        return kwargs
-
-    @classmethod
-    def unserialize_in_worker(cls, kwargs):
-        # This function does nothing without a waveform extractor -- candidate for refactor
-        return kwargs
-
-    @classmethod
-    def get_margin(cls, recording, kwargs):
-        """Get margin for chunking recording.
+        self.params = params
+        self.template_meta = template_meta
+        self.sparsity = sparsity
+        self.template_data = template_data
+        self.nbefore = templates.nbefore
+        self.nafter = templates.nafter

-        Parameters
-        ----------
-        recording : RecordingExtractor
-            The recording extractor object.
-        kwargs : dict
-            Keyword arguments for matching method.
-
-        Returns
-        -------
-        margin : int
-            Buffer in samples on each side of a chunk.
-        """
-        buffer_ms = 10
-        # margin = int(buffer_ms*1e-3 * recording.sampling_frequency)
-        margin = 300  # To ensure equivalence with spike-psvae version of the algorithm
-        return margin
+        # buffer_ms = 10
+        # self.margin = int(buffer_ms*1e-3 * recording.sampling_frequency)
+        self.margin = 300  # To ensure equivalence with spike-psvae version of the algorithm

-    @classmethod
-    def main_function(cls, traces, method_kwargs):
-        """Detect spikes in traces using the template matching algorithm.
+    def get_trace_margin(self):
+        return self.margin

-        Parameters
-        ----------
-        traces : ndarray (chunk_len + 2*margin, num_channels)
-            Voltage traces for a chunk of the recording.
-        method_kwargs : dict
-            Keyword arguments for matching method.
+ def compute_matching(self, traces, start_frame, end_frame, segment_index): - Returns - ------- - spikes : ndarray (num_spikes,) - Resulting spike train. - """ # Unpack method_kwargs - nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] - template_meta = method_kwargs["template_meta"] - params = method_kwargs["params"] - sparsity = method_kwargs["sparsity"] - template_data = method_kwargs["template_data"] + # nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] + # template_meta = method_kwargs["template_meta"] + # params = method_kwargs["params"] + # sparsity = method_kwargs["sparsity"] + # template_data = method_kwargs["template_data"] # Check traces assert traces.dtype == np.float32, "traces must be specified as np.float32" # Compute objective - objective = compute_objective(traces, template_data, params.approx_rank) - objective_normalized = 2 * objective - template_data.norm_squared[:, np.newaxis] + objective = compute_objective(traces, self.template_data, self.params.approx_rank) + objective_normalized = 2 * objective - self.template_data.norm_squared[:, np.newaxis] # Compute spike train spike_trains, scalings, distance_metrics = [], [], [] - for i in range(params.max_iter): + for i in range(self.params.max_iter): # find peaks - spike_train, scaling, distance_metric = cls.find_peaks( - objective, objective_normalized, np.array(spike_trains), params, template_data, template_meta + spike_train, scaling, distance_metric = self.find_peaks( + objective, + objective_normalized, + np.array(spike_trains), + self.params, + self.template_data, + self.template_meta, ) if len(spike_train) == 0: break @@ -482,15 +424,22 @@ def main_function(cls, traces, method_kwargs): distance_metrics.extend(list(distance_metric)) # subtract newly detected spike train from traces (via the objective) - objective, objective_normalized = cls.subtract_spike_train( - spike_train, scaling, template_data, objective, objective_normalized, params, template_meta, sparsity + objective, objective_normalized = self.subtract_spike_train( + spike_train, + scaling, + self.template_data, + objective, + objective_normalized, + self.params, + self.template_meta, + self.sparsity, ) spike_train = np.array(spike_trains) scalings = np.array(scalings) distance_metric = np.array(distance_metrics) if len(spike_train) == 0: # no spikes found - return np.zeros(0, dtype=cls.spike_dtype) + return np.zeros(0, dtype=_base_matching_dtype) # order spike times index = np.argsort(spike_train[:, 0]) @@ -499,8 +448,8 @@ def main_function(cls, traces, method_kwargs): distance_metric = distance_metric[index] # adjust spike_train - spike_train[:, 0] += nbefore # beginning of template --> center of template - spike_train[:, 1] //= params.jitter_factor # jittered_index --> template_index + spike_train[:, 0] += self.nbefore # beginning of template --> center of template + spike_train[:, 1] //= self.params.jitter_factor # jittered_index --> template_index # TODO : Benchmark spike amplitudes # Find spike amplitudes / channels @@ -512,7 +461,7 @@ def main_function(cls, traces, method_kwargs): channel_inds.append(best_ch) # assign result to spikes array - spikes = np.zeros(spike_train.shape[0], dtype=cls.spike_dtype) + spikes = np.zeros(spike_train.shape[0], dtype=_base_matching_dtype) spikes["sample_index"] = spike_train[:, 0] spikes["cluster_index"] = spike_train[:, 1] spikes["channel_index"] = channel_inds @@ -622,7 +571,7 @@ def subtract_spike_train( Dataclass object for aggregating the parameters together. 
template_meta : TemplateMetadata Dataclass object for aggregating template metadata together. - sparsity : Sparsity + sparsity : WobbleSparsity Dataclass object for aggregating channel sparsity variables together. Returns diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index dab19809be..cbf1d29932 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -10,6 +10,7 @@ job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) +# job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True) def get_sorting_analyzer(): @@ -40,19 +41,25 @@ def test_find_spikes_from_templates(method, sorting_analyzer): noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() # sorting_analyzer - method_kwargs_all = {"templates": templates, "noise_levels": noise_levels} + method_kwargs_all = { + "templates": templates, + } method_kwargs = {} + if method in ("naive", "tdc-peeler", "circus"): + method_kwargs["noise_levels"] = noise_levels + # method_kwargs["wobble"] = { # "templates": waveform_extractor.get_all_templates(), # "nbefore": waveform_extractor.nbefore, # "nafter": waveform_extractor.nafter, # } - sampling_frequency = recording.get_sampling_frequency() + method_kwargs.update(method_kwargs_all) + spikes, info = find_spikes_from_templates( + recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + ) - method_kwargs_ = method_kwargs.get(method, {}) - method_kwargs_.update(method_kwargs_all) - spikes = find_spikes_from_templates(recording, method=method, method_kwargs=method_kwargs_, **job_kwargs) + # print(info) # DEBUG = True @@ -65,15 +72,15 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # gt_sorting = sorting_analyzer.sorting - # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) + # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency) - # metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) + # ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) # fig, ax = plt.subplots() # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) # si.plot_agreement_matrix(comp, ax=ax) # ax.set_title(method) - # plt.show() + # plt.show() if __name__ == "__main__": @@ -81,6 +88,6 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # method = "naive" # method = "tdc-peeler" # method = "circus" - # method = "circus-omp-svd" - method = "wobble" + method = "circus-omp-svd" + # method = "wobble" test_find_spikes_from_templates(method, sorting_analyzer) diff --git a/src/spikeinterface/sortingcomponents/tests/test_wobble.py b/src/spikeinterface/sortingcomponents/tests/test_wobble.py index 5e6be02409..d6d1e1e0b9 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_wobble.py +++ b/src/spikeinterface/sortingcomponents/tests/test_wobble.py @@ -143,7 +143,7 @@ def test_convolve_templates(): ) unit_overlap = unit_overlap > 0 unit_overlap = np.repeat(unit_overlap, jitter_factor, axis=0) - sparsity = wobble.Sparsity(visible_channels, unit_overlap) + sparsity = wobble.WobbleSparsity(visible_channels, unit_overlap) # Act: run convolve_templates pairwise_convolution = wobble.convolve_templates(
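
As a rough usage sketch (illustration only, not part of the patch above), the updated test exercises the refactored entry point like this: `templates` is the only entry every matcher needs in `method_kwargs`, `noise_levels` is added only for the methods that still take it, and `extra_outputs=True` returns an extra info dict alongside the structured spike array. The `recording`, `templates` and `noise_levels` objects are assumed to come from a ground-truth recording / SortingAnalyzer as in the test.

    # hypothetical usage sketch mirroring the updated test, not canonical API documentation
    from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

    job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=False)

    method = "circus-omp-svd"
    method_kwargs = {"templates": templates}  # templates: a spikeinterface.core.Templates object (assumed to exist)
    if method in ("naive", "tdc-peeler", "circus"):
        # in the updated test, only these methods still receive noise_levels directly
        method_kwargs["noise_levels"] = noise_levels

    spikes, info = find_spikes_from_templates(
        recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
    )
    # spikes is a structured array with the fields of _base_matching_dtype:
    # (sample_index, channel_index, cluster_index, amplitude, segment_index)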
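With this refactor each matching engine is an ordinary node-pipeline class rather than a bag of classmethods, so a new engine only needs a constructor plus `get_trace_margin()` and `compute_matching()`, as the WobbleMatch changes above illustrate. A bare-bones skeleton following that shape (the class name, margin value and empty result are placeholders, not part of the patch):

    import numpy as np
    from spikeinterface.sortingcomponents.matching.base import BaseTemplateMatching, _base_matching_dtype

    class DummyMatcher(BaseTemplateMatching):
        def __init__(self, recording, return_output=True, parents=None, templates=None):
            BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output, parents=parents)
            self.margin = 300  # samples of padding kept on each side of a chunk

        def get_trace_margin(self):
            return self.margin

        def compute_matching(self, traces, start_frame, end_frame, segment_index):
            # a real engine would detect and subtract template occurrences in `traces`;
            # returning an empty structured array keeps the skeleton valid but trivial
            return np.zeros(0, dtype=_base_matching_dtype)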