diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index aaca718096..e12d83810a 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -30,7 +30,7 @@ Here, we present a MATLAB code that creates a random dataset and writes it to a Loading Data in SpikeInterface ------------------------------ -After executing the above MATLAB code, a binary file named `your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path. +After executing the above MATLAB code, a binary file named :code:`your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path. Use the following Python script to load the binary data into SpikeInterface: @@ -55,7 +55,7 @@ Use the following Python script to load the binary data into SpikeInterface: # Load data using SpikeInterface recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, - num_channels=num_channels, dtype=dtype) + num_channels=num_channels, dtype=dtype) # Confirm that the data was loaded correctly by comparing the data shapes and see they match the MATLAB data print(recording.get_num_frames(), recording.get_num_channels()) @@ -65,18 +65,18 @@ Follow the steps above to seamlessly import your MATLAB data into SpikeInterface Common Pitfalls & Tips ---------------------- -1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use `time_axis=1` in `si.read_binary()`. +1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use :code:`time_axis=1` in :code:`si.read_binary()`. 2. **File Path**: Always double-check the Python file path. 3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's `double` is equivalent to Numpy's `float64`. 4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface. -5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's [Numpy for MATLAB Users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) guide. +5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's `Numpy for MATLAB Users <https://numpy.org/doc/stable/user/numpy-for-matlab-users.html>`_ guide. Using gains and offsets for integer data ---------------------------------------- Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units, you can apply a gain and an offset. -In SpikeInterface, you can use the `gain_to_uV` and `offset_to_uV` parameters, since traces are handled in microvolts (uV). Both parameters can be integrated into the `read_binary` function. -If your data in MATLAB is stored as `int16`, and you know the gain and offset, you can use the following code to load the data: +In SpikeInterface, you can use the :code:`gain_to_uV` and :code:`offset_to_uV` parameters, since traces are handled in microvolts (uV). Both parameters can be integrated into the :code:`read_binary` function. +If your data in MATLAB is stored as :code:`int16`, and you know the gain and offset, you can use the following code to load the data: .. 
code-block:: python @@ -90,7 +90,8 @@ If your data in MATLAB is stored as `int16`, and you know the gain and offset, y num_channels=num_channels, dtype=dtype_int, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV) - recording.get_traces(return_scaled=True) # Return traces in micro volts (uV) + recording.get_traces() # Return traces in original units [type: int] + recording.get_traces(return_scaled=True) # Return traces in micro volts (uV) [type: float] This will equip your recording object with capabilities to convert the data to float values in uV using the :code:`get_traces()` method with the :code:`return_scaled` parameter set to :code:`True`. diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index b40b998103..31241a4147 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -40,7 +40,6 @@ def __init__( sampling_frequency: float | None = None, session_info_file_path: str | Path | None = None, spikes_matfile_path: str | Path | None = None, - session_info_matfile_path: str | Path | None = None, ): try: from pymatreader import read_mat @@ -67,26 +66,6 @@ def __init__( ) file_path = spikes_matfile_path if file_path is None else file_path - if session_info_matfile_path is not None: - # Raise an error if the warning period has expired - deprecation_issued = datetime.datetime(2023, 4, 1) - deprecation_deadline = deprecation_issued + datetime.timedelta(days=180) - if datetime.datetime.now() > deprecation_deadline: - raise ValueError( - "The session_info_matfile_path argument is no longer supported in. Use session_info_file_path instead." - ) - - # Otherwise, issue a DeprecationWarning - else: - warnings.warn( - "The session_info_matfile_path argument is deprecated and will be removed in six months. 
" - "Use session_info_file_path instead.", - DeprecationWarning, - ) - session_info_file_path = ( - session_info_matfile_path if session_info_file_path is None else session_info_file_path - ) - self.spikes_cellinfo_path = Path(file_path) self.session_path = self.spikes_cellinfo_path.parent self.session_id = self.spikes_cellinfo_path.stem.split(".")[0] diff --git a/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py b/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py index 35de8a23e2..c4c8d0c993 100644 --- a/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py +++ b/src/spikeinterface/extractors/tests/test_cellexplorerextractor.py @@ -26,7 +26,7 @@ class CellExplorerSortingTest(SortingCommonTestSuite, unittest.TestCase): ( "cellexplorer/dataset_2/20170504_396um_0um_merge.spikes.cellinfo.mat", { - "session_info_matfile_path": local_folder + "session_info_file_path": local_folder / "cellexplorer/dataset_2/20170504_396um_0um_merge.sessionInfo.mat" }, ), diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 66c008c3fc..710c4f76f4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -20,7 +20,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, @@ -151,7 +151,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( - recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params + recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params ) if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b87bbc7cee..28a1a63065 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -539,6 +539,7 @@ def remove_duplicates_via_matching( method_kwargs={}, job_kwargs={}, tmp_folder=None, + method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels @@ -546,7 +547,6 @@ def remove_duplicates_via_matching( from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms from spikeinterface.core import get_global_tmp_folder - from spikeinterface.sortingcomponents.matching.circus import get_scipy_shape import string, random, shutil, os from pathlib import Path @@ -591,19 +591,12 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - dummy_filter = np.empty((num_chans, duration), dtype=np.float32) - dummy_traces = np.empty((num_chans, chunk_size), dtype=np.float32) - - fshape, axes = get_scipy_shape(dummy_filter, dummy_traces, axes=1) - method_kwargs.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], "omp_min_sps": 0.1, - "templates": None, - "overlaps": None, } ) @@ -618,16 
+611,31 @@ def remove_duplicates_via_matching( method_kwargs.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs - ) - method_kwargs.update( - { - "overlaps": computed["overlaps"], - "templates": computed["templates"], - "norms": computed["norms"], - "sparsities": computed["sparsities"], - } + sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs ) + if method == "circus-omp-svd": + method_kwargs.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "temporal": computed["temporal"], + "spatial": computed["spatial"], + "singular": computed["singular"], + "units_overlaps": computed["units_overlaps"], + "unit_overlaps_indices": computed["unit_overlaps_indices"], + "sparsity_mask": computed["sparsity_mask"], + } + ) + elif method == "circus-omp": + method_kwargs.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "sparsities": computed["sparsities"], + } + ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) if np.sum(valid) > 0: if np.sum(valid) == 1: diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index be8ecd6702..864548e7d4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -18,7 +18,9 @@ from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms -from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature +from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser +from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature +from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever class RandomProjectionClustering: @@ -34,17 +36,17 @@ class RandomProjectionClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, + "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, - "max_spikes_per_unit": 200, "selection_method": "closest_to_centroid", - "nb_projections": {"ptp": 8, "energy": 2}, - "ms_before": 1.5, - "ms_after": 1.5, + "nb_projections": 10, + "ms_before": 1, + "ms_after": 1, "random_seed": 42, - "shared_memory": False, - "min_values": {"ptp": 0, "energy": 0}, + "smoothing_kwargs": {"window_length_ms": 1}, + "shared_memory": True, "tmp_folder": None, - "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @classmethod @@ -74,50 +76,60 @@ def main_function(cls, recording, peaks, params): np.random.seed(d["random_seed"]) - features_params = {} - features_list = [] - - noise_snippets = None - - for proj_type in ["ptp", "energy"]: - if d["nb_projections"][proj_type] > 0: - features_list += [f"random_projections_{proj_type}"] - - if d["min_values"][proj_type] == "auto": - if noise_snippets is None: - num_segments = 
recording.get_num_segments() - num_chunks = 3 * d["max_spikes_per_unit"] // num_segments - noise_snippets = get_random_data_chunks( - recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42 - ) - noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans) - - if proj_type == "energy": - data = np.linalg.norm(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif proj_type == "ptp": - data = np.ptp(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif d["min_values"][proj_type] > 0: - min_values = d["min_values"][proj_type] - else: - min_values = None - - projections = np.random.randn(num_chans, d["nb_projections"][proj_type]) - features_params[f"random_projections_{proj_type}"] = { - "radius_um": params["radius_um"], - "projections": projections, - "min_values": min_values, - } - - features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]).absolute() + + ### Then we extract the random projection features + node0 = PeakRetriever(recording, peaks) + node1 = ExtractDenseWaveforms( + recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"] ) - if len(features_data) > 1: - hdbscan_data = np.hstack((features_data[0], features_data[1])) - else: - hdbscan_data = features_data[0] + node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) + + projections = np.random.randn(num_chans, d["nb_projections"]) + projections -= projections.mean(0) + projections /= projections.std(0) + + nbefore = int(params["ms_before"] * fs / 1000) + nafter = int(params["ms_after"] * fs / 1000) + nsamples = nbefore + nafter + + import scipy + + x = np.random.randn(100, nsamples, num_chans).astype(np.float32) + x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) + + ptps = np.ptp(x, axis=1) + a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) + ydata = np.cumsum(a) / a.sum() + xdata = b[1:] + + from scipy.optimize import curve_fit + + def sigmoid(x, L, x0, k, b): + y = L / (1 + np.exp(-k * (x - x0))) + b + return y + + p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is a mandatory initial guess + popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) + + node3 = RandomProjectionsFeature( + recording, + parents=[node0, node2], + return_output=True, + projections=projections, + radius_um=params["radius_um"], + ) + + pipeline_nodes = [node0, node1, node2, node3] + + hdbscan_data = run_node_pipeline( + recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features" + ) import sklearn @@ -132,7 +144,7 @@ def main_function(cls, recording, peaks, params): all_indices = np.arange(0, peak_labels.size) - max_spikes = params["max_spikes_per_unit"] + max_spikes = params["waveforms"]["max_spikes_per_unit"] selection_method = params["selection_method"] for unit_ind in labels: diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index bd82ffa0a6..b534c2356d 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -184,41 +184,44 @@ def __init__( return_output=True, 
parents=None, projections=None, - radius_um=150.0, - min_values=None, + sigmoid=None, + radius_um=None, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.projections = projections - self.radius_um = radius_um - self.min_values = min_values - + self.sigmoid = sigmoid self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance < radius_um - - self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values)) - + self.radius_um = radius_um + self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um)) self._dtype = recording.get_dtype() def get_dtype(self): return self._dtype + def _sigmoid(self, x): + L, x0, k, b = self.sigmoid + y = L / (1 + np.exp(-k * (x - x0))) + b + return y + def compute(self, traces, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) + for main_chan in np.unique(peaks["channel_index"]): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] - wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1) + wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) - if self.min_values is not None: - wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4 + if self.sigmoid is not None: + wf_ptp *= self._sigmoid(wf_ptp) denom = np.sum(wf_ptp, axis=1) mask = denom != 0 - all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis]) + return all_projections diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index a19e7b71b5..358691cd25 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -33,9 +33,6 @@ from .main import BaseTemplateMatchingEngine -################# -# Circus peeler # -################# from scipy.fft._helper import _init_nd_shape_and_axes @@ -478,6 +475,366 @@ def main_function(cls, traces, d): return spikes +class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): + """ + Orthogonal Matching Pursuit inspired from Spyking Circus sorter + + https://elifesciences.org/articles/34518 + + This is an Orthogonal Template Matching algorithm. For speed and + memory optimization, templates are automatically sparsified. Signal + is convolved with the templates, and as long as some scalar products + are higher than a given threshold, we use a Cholesky decomposition + to compute the optimal amplitudes needed to reconstruct the signal. + + IMPORTANT NOTE: small chunks are more efficient for such Peeler, + consider using 100ms chunk + + Parameters + ---------- + amplitude: tuple + (Minimal, Maximal) amplitudes allowed for every template + omp_min_sps: float + Stopping criteria of the OMP algorithm, in percentage of the norm + noise_levels: array + The noise levels, for every channels. If None, they will be automatically + computed + random_chunk_kwargs: dict + Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. 
+ ----- + """ + + _default_params = { + "amplitudes": [0.6, 2], + "omp_min_sps": 0.1, + "waveform_extractor": None, + "random_chunk_kwargs": {}, + "noise_levels": None, + "rank": 5, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, + "ignored_ids": [], + "vicinity": 0, + } + + @classmethod + def _prepare_templates(cls, d): + waveform_extractor = d["waveform_extractor"] + num_templates = len(d["waveform_extractor"].sorting.unit_ids) + + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + + d["sparsity_mask"] = sparsity + units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) + d["units_overlaps"] = units_overlaps > 0 + d["unit_overlaps_indices"] = {} + for i in range(num_templates): + (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) + + templates = waveform_extractor.get_all_templates(mode="median").copy() + + # First, we set masked channels to 0 + for count in range(num_templates): + templates[count][:, ~d["sparsity_mask"][count]] = 0 + + # Then we keep only the strongest components + rank = d["rank"] + temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) + d["temporal"] = temporal[:, :, :rank] + d["singular"] = singular[:, :rank] + d["spatial"] = spatial[:, :rank, :] + + # We reconstruct the approximated templates + templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) + + d["templates"] = {} + d["norms"] = np.zeros(num_templates, dtype=np.float32) + + # And get the norms, saving compressed templates for CC matrix + for count in range(num_templates): + template = templates[count][:, d["sparsity_mask"][count]] + d["norms"][count] = np.linalg.norm(template) + d["templates"][count] = template / d["norms"][count] + + d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] + d["temporal"] = np.flip(d["temporal"], axis=1) + + d["overlaps"] = [] + for i in range(num_templates): + num_overlaps = np.sum(d["units_overlaps"][i]) + overlapping_units = np.where(d["units_overlaps"][i])[0] + + # Reconstruct unit template from SVD Matrices + data = d["temporal"][i] * d["singular"][i][np.newaxis, :] + template_i = np.matmul(data, d["spatial"][i, :, :]) + template_i = np.flipud(template_i) + + unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) + + for count, j in enumerate(overlapping_units): + overlapped_channels = d["sparsity_mask"][j] + visible_i = template_i[:, overlapped_channels] + + spatial_filters = d["spatial"][j, :, overlapped_channels] + spatially_filtered_template = np.matmul(visible_i, spatial_filters) + visible_i = spatially_filtered_template * d["singular"][j] + + for rank in range(visible_i.shape[1]): + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") + + d["overlaps"].append(unit_overlaps) + + d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) + d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) + d["singular"] = d["singular"].T[:, :, np.newaxis] + return d + + @classmethod + def initialize_and_check_kwargs(cls, recording, kwargs): + d = cls._default_params.copy() + d.update(kwargs) + + # assert isinstance(d['waveform_extractor'], WaveformExtractor) + + for v in ["omp_min_sps"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = 
d["waveform_extractor"].nsamples + d["nbefore"] = d["waveform_extractor"].nbefore + d["nafter"] = d["waveform_extractor"].nafter + d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] + + if d["noise_levels"] is None: + print("CircusOMPSVDPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + + if "templates" not in d: + d = cls._prepare_templates(d) + else: + for key in [ + "norms", + "temporal", + "spatial", + "singular", + "units_overlaps", + "sparsity_mask", + "unit_overlaps_indices", + ]: + assert d[key] is not None, "If templates are provided, %s should also be there" % key + + d["num_templates"] = len(d["templates"]) + d["ignored_ids"] = np.array(d["ignored_ids"]) + + d["unit_overlaps_tables"] = {} + for i in range(d["num_templates"]): + d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) + d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) + + omp_min_sps = d["omp_min_sps"] + # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) + + return d + + @classmethod + def serialize_method_kwargs(cls, kwargs): + kwargs = dict(kwargs) + # remove waveform_extractor + kwargs.pop("waveform_extractor") + return kwargs + + @classmethod + def unserialize_in_worker(cls, kwargs): + return kwargs + + @classmethod + def get_margin(cls, recording, kwargs): + margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + return margin + + @classmethod + def main_function(cls, traces, d): + templates = d["templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + num_samples = d["num_samples"] + overlaps = d["overlaps"] + norms = d["norms"] + nbefore = d["nbefore"] + nafter = d["nafter"] + omp_tol = np.finfo(np.float32).eps + num_samples = d["nafter"] + d["nbefore"] + neighbor_window = num_samples - 1 + min_amplitude, max_amplitude = d["amplitudes"] + ignored_ids = d["ignored_ids"] + stop_criteria = d["stop_criteria"][:, np.newaxis] + vicinity = d["vicinity"] + rank = d["rank"] + + num_timesteps = len(traces) + + num_peaks = num_timesteps - num_samples + 1 + conv_shape = (num_templates, num_peaks) + scalar_products = np.zeros(conv_shape, dtype=np.float32) + + # Filter using overlap-and-add convolution + if len(ignored_ids) > 0: + mask = ~np.isin(np.arange(num_templates), ignored_ids) + spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] + objective_by_rank = scipy.signal.oaconvolve( + scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid" + ) + scalar_products[mask] += np.sum(objective_by_rank, axis=0) + scalar_products[ignored_ids] = -np.inf + else: + spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"] + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") + scalar_products += np.sum(objective_by_rank, axis=0) + + num_spikes = 0 + + spikes = np.empty(scalar_products.size, dtype=spike_dtype) + idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + + M = np.zeros((num_templates, num_templates), dtype=np.float32) + + 
all_selections = np.empty((2, scalar_products.size), dtype=np.int32) + final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) + num_selection = 0 + + full_sps = scalar_products.copy() + + neighbors = {} + cached_overlaps = {} + + is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) + + while np.any(is_valid): + best_amplitude_ind = scalar_products[is_valid].argmax() + best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + + if num_selection > 0: + delta_t = selection[1] - peak_index + idx = np.where((delta_t < num_samples) & (delta_t > -num_samples))[0] + myline = neighbor_window + delta_t[idx] + myindices = selection[0, idx] + + local_overlaps = overlaps[best_cluster_ind] + overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] + table = d["unit_overlaps_tables"][best_cluster_ind] + + if num_selection == M.shape[0]: + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z[:num_selection, :num_selection] = M + M = Z + + mask = np.isin(myindices, overlapping_templates) + a, b = myindices[mask], myline[mask] + M[num_selection, idx[mask]] = local_overlaps[table[a], b] + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 + else: + M[0, 0] = 1 + + all_selections[:, num_selection] = [best_cluster_ind, peak_index] + num_selection += 1 + + selection = all_selections[:, :num_selection] + res_sps = full_sps[selection[0], selection[1]] + + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + + diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] + modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] + final_amplitudes[selection[0], selection[1]] = all_amplitudes + + for i in modified: + tmp_best, tmp_peak = selection[:, i] + diff_amp = diff_amplitudes[i] * norms[tmp_best] + + local_overlaps = overlaps[tmp_best] + overlapping_templates = d["units_overlaps"][tmp_best] + + if not tmp_peak in neighbors.keys(): + idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] + tdx = [neighbor_window + idx[0] - tmp_peak, 
num_samples + idx[1] - tmp_peak - 1] + neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + + idx = neighbors[tmp_peak]["idx"] + tdx = neighbors[tmp_peak]["tdx"] + + to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] + scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add + + is_valid = scalar_products > stop_criteria + + is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) + valid_indices = np.where(is_valid) + + num_spikes = len(valid_indices[0]) + spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["channel_index"][:num_spikes] = 0 + spikes["cluster_index"][:num_spikes] = valid_indices[0] + spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + + spikes = spikes[:num_spikes] + order = np.argsort(spikes["sample_index"]) + spikes = spikes[order] + + return spikes + + class CircusPeeler(BaseTemplateMatchingEngine): """ diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index bedc04a9d5..d982943126 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,6 +1,6 @@ from .naive import NaiveMatching from .tdc import TridesclousPeeler -from .circus import CircusPeeler, CircusOMPPeeler +from .circus import CircusPeeler, CircusOMPPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { @@ -8,5 +8,6 @@ "tridesclous": TridesclousPeeler, "circus": CircusPeeler, "circus-omp": CircusOMPPeeler, + "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, }
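
For readers following the :code:`load_matlab_data.rst` hunks above, here is a minimal end-to-end sketch of the documented loading step. The file path, channel count, and sampling rate below are illustrative placeholders, not values from the patch:

.. code-block:: python

    import spikeinterface as si

    # All values here are placeholders; they must match what the MATLAB script wrote.
    recording = si.read_binary(
        "/path/to/your_data_as_a_binary.bin",  # hypothetical path
        sampling_frequency=30000.0,  # Hz, same as in the MATLAB code
        num_channels=384,            # placeholder channel count
        dtype="float64",             # MATLAB double corresponds to numpy float64
        time_axis=0,                 # use time_axis=1 if MATLAB wrote channels x samples
    )
    # Compare against the MATLAB matrix shape to confirm the load
    print(recording.get_num_frames(), recording.get_num_channels())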
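The gains-and-offsets hunk in the same file covers scaled integer data. A sketch under the same assumptions, with a made-up gain value (real gains and offsets must come from your acquisition system):

.. code-block:: python

    import spikeinterface as si

    recording = si.read_binary(
        "/path/to/your_data_as_a_binary.bin",  # hypothetical path
        sampling_frequency=30000.0,
        num_channels=384,
        dtype="int16",
        gain_to_uV=0.195,   # made-up LSB size, purely illustrative
        offset_to_uV=0.0,
    )
    traces_int = recording.get_traces()                   # original int16 units
    traces_uV = recording.get_traces(return_scaled=True)  # float traces in microvolts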
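The :code:`spyking_circus2.py` and :code:`method_list.py` hunks register :code:`circus-omp-svd` as a new matching engine. A hedged, runnable sketch of calling it directly; the :code:`toy_example`/:code:`extract_waveforms` scaffolding is only there to make the snippet self-contained, and the kwargs mirror the new engine's :code:`_default_params`:

.. code-block:: python

    from spikeinterface.core import extract_waveforms
    from spikeinterface.extractors import toy_example
    from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

    # Toy scaffolding so the sketch runs end to end
    recording, sorting = toy_example(num_channels=8, duration=10, seed=42)
    we = extract_waveforms(recording, sorting, mode="memory", sparse=False)

    method_kwargs = {
        "waveform_extractor": we,  # templates are SVD-compressed in _prepare_templates
        "amplitudes": [0.6, 2],    # accepted amplitude range per template
        "rank": 5,                 # number of SVD components kept per template
        "omp_min_sps": 0.1,        # OMP stopping criterion
    }
    spikes = find_spikes_from_templates(
        recording, method="circus-omp-svd", method_kwargs=method_kwargs,
        chunk_duration="100ms",  # small chunks are recommended for this peeler
    )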
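In :code:`random_projections.py`, feature extraction now runs through the node pipeline instead of :code:`compute_features_from_peaks`. A condensed sketch of the same wiring, with toy peak detection standing in for the caller's peaks:

.. code-block:: python

    import numpy as np
    from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
    from spikeinterface.extractors import toy_example
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
    from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
    from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature

    recording, _ = toy_example(num_channels=8, duration=10, seed=42)
    peaks = detect_peaks(recording, method="by_channel", peak_sign="neg", detect_threshold=5)

    node0 = PeakRetriever(recording, peaks)  # replay already-detected peaks
    node1 = ExtractDenseWaveforms(recording, parents=[node0], return_output=False, ms_before=1, ms_after=1)
    node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, window_length_ms=1)

    num_chans = recording.get_num_channels()
    projections = np.random.randn(num_chans, 10)  # one column per random projection
    projections -= projections.mean(0)
    projections /= projections.std(0)

    node3 = RandomProjectionsFeature(
        recording, parents=[node0, node2], return_output=True,
        projections=projections, radius_um=100,
    )
    features = run_node_pipeline(
        recording, [node0, node1, node2, node3], {"n_jobs": 1}, job_name="extracting features"
    )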
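The sigmoid fitted on simulated peak-to-peak amplitudes in that same :code:`main_function` can be reproduced standalone; synthetic gamma-distributed amplitudes stand in for the denoised waveforms. The fitted :code:`popt` tuple has exactly the :code:`(L, x0, k, b)` shape that the new :code:`sigmoid` argument of :code:`RandomProjectionsFeature` unpacks:

.. code-block:: python

    import numpy as np
    from scipy.optimize import curve_fit

    def sigmoid(x, L, x0, k, b):
        return L / (1 + np.exp(-k * (x - x0))) + b

    rng = np.random.default_rng(42)
    ptps = rng.gamma(shape=3.0, scale=8.0, size=5000)  # synthetic PTP amplitudes
    counts, bins = np.histogram(ptps, np.linspace(0, 100, 1000))
    ydata = np.cumsum(counts) / counts.sum()  # empirical CDF of the amplitudes
    xdata = bins[1:]

    p0 = [max(ydata), np.median(xdata), 1, min(ydata)]  # mandatory initial guess
    popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)
    print(popt)  # (L, x0, k, b)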
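Template compression in :code:`_prepare_templates` keeps only the top :code:`rank` SVD components of each template, which is what makes the per-rank 1D convolutions cheap. A self-contained illustration of the truncation and of the exact reconstruction formula used there; random "templates" are shape-only stand-ins (real templates are smooth, so low rank loses far less than it does here):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    templates = rng.standard_normal((20, 60, 16)).astype(np.float32)  # units x samples x channels

    rank = 5
    temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False)
    temporal = temporal[:, :, :rank]
    singular = singular[:, :rank]
    spatial = spatial[:, :rank, :]

    # Same reconstruction as in _prepare_templates
    approx = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
    rel_error = np.linalg.norm(approx - templates) / np.linalg.norm(templates)
    print(f"relative error at rank {rank}: {rel_error:.3f}")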
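Finally, the scalar products in the new peeler's :code:`main_function` are computed rank by rank with :code:`scipy.signal.oaconvolve`. A shape-only sketch on random data, using the post-:code:`moveaxis` layouts produced by :code:`_prepare_templates`:

.. code-block:: python

    import numpy as np
    import scipy.signal

    rank, num_units, num_chans, num_samples, num_timesteps = 5, 20, 16, 60, 1000
    rng = np.random.default_rng(0)
    traces = rng.standard_normal((num_timesteps, num_chans)).astype(np.float32)
    spatial = rng.standard_normal((rank, num_units, num_chans)).astype(np.float32)
    temporal = rng.standard_normal((rank, num_units, num_samples)).astype(np.float32)
    singular = rng.standard_normal((rank, num_units, 1)).astype(np.float32)

    # Spatial projection, scaling by singular values, then temporal convolution per rank
    spatially_filtered = np.matmul(spatial, traces.T[np.newaxis, :, :])  # (rank, units, time)
    scaled = spatially_filtered * singular
    objective = scipy.signal.oaconvolve(scaled, temporal, axes=2, mode="valid").sum(0)
    print(objective.shape)  # (num_units, num_timesteps - num_samples + 1)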