From f06115c04ecbd4858ec582d5db0e4f99da9112ec Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 10 Jan 2024 21:29:07 +0100 Subject: [PATCH 001/192] Initial implementaion of SortingResult --- src/spikeinterface/core/__init__.py | 2 + src/spikeinterface/core/sortingresult.py | 979 ++++++++++++++++++ .../core/tests/test_sortingresult.py | 121 +++ 3 files changed, 1102 insertions(+) create mode 100644 src/spikeinterface/core/sortingresult.py create mode 100644 src/spikeinterface/core/tests/test_sortingresult.py diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index f490887334..6c811d8fff 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -132,3 +132,5 @@ # channel sparsity from .sparsity import ChannelSparsity, compute_sparsity + +from .sortingresult import SortingResult diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py new file mode 100644 index 0000000000..f92fa214ed --- /dev/null +++ b/src/spikeinterface/core/sortingresult.py @@ -0,0 +1,979 @@ +from __future__ import annotations +from typing import Literal, Optional + +from pathlib import Path +import os +import json +import pickle +import weakref +import shutil + +import numpy as np + +import probeinterface + +from .baserecording import BaseRecording +from .basesorting import BaseSorting + +from .base import load_extractor +from .recording_tools import check_probe_do_not_overlap, get_rec_attributes +from .core_tools import check_json +from .numpyextractors import SharedMemorySorting +from .sparsity import ChannelSparsity +from .sortingfolder import NumpyFolderSorting + + +# TODO +# * make info.json that contain some version info of spikeinterface +# * same for zarr +# * sample spikes and propagate in compute with option + + + +# high level function +def start_sorting_result(sorting, recording, format="memory", folder=None, + sparse=True, sparsity=None, + # **kwargs + ): + """ + Create a SortingResult by pairing a Sorting and the corresponding Recording. + + This object will handle a list of ResultExtension for all the post processing steps like: waveforms, + templates, unit locations, spike locations, quality mertics ... + + This object will be also use used for ploting purpose. + + + Parameters + ---------- + sorting: Sorting + The sorting object + recording: Recording + The recording object + folder: str or Path or None, default: None + The folder where waveforms are cached + format: "memory | "binary_folder" | "zarr", default: "memory" + The mode to store waveforms. If "folder", waveforms are stored on disk in the specified folder. + The "folder" argument must be specified in case of mode "folder". + If "memory" is used, the waveforms are stored in RAM. Use this option carefully! + sparse: bool, default: True + If True, then a sparsity mask is computed usingthe `precompute_sparsity()` function is run using + a few spikes to get an estimate of dense templates to create a ChannelSparsity object. + Then, the sparsity will be propagated to all ResultExtention that handle sparsity (like wavforms, pca, ...) + sparsity: ChannelSparsity or None, default: None + The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Default None. + sparsity_temp_folder: str or Path or None, default: None + If sparse is True, this is the temporary folder where the dense waveforms are temporarily saved. + If None, dense waveforms are extracted in memory in batches (which can be controlled by the `unit_batch_size` + parameter. With a large number of units (e.g., > 400), it is advisable to use a temporary folder. + num_spikes_for_sparsity: int, default: 100 + The number of spikes to use to estimate sparsity (if sparse=True). + unit_batch_size: int, default: 200 + The number of units to process at once when extracting dense waveforms (if sparse=True and sparsity_temp_folder + is None). + + sparsity kwargs: + {} + + + job kwargs: + {} + + + Returns + ------- + sorting_result: SortingResult + The SortingResult object + + Examples + -------- + >>> import spikeinterface as si + + >>> # Extract dense waveforms and save to disk with binary_folder format. + >>> sortres = si.start_sorting_result(sorting, recording, format="binary_folder", folder="/path/to_my/result") + + """ + + + # handle sparsity + if sparsity is not None: + assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object" + unit_id_to_channel_ids = sparsity.unit_id_to_channel_ids + assert all(u in sorting.unit_ids for u in unit_id_to_channel_ids), "Invalid unit ids in sparsity" + for channels in unit_id_to_channel_ids.values(): + assert all(ch in recording.channel_ids for ch in channels), "Invalid channel ids in sparsity" + elif sparse: + # TODO + # raise NotImplementedError() + sparsity = None + # estimate_kwargs, job_kwargs = split_job_kwargs(kwargs) + # sparsity = precompute_sparsity( + # recording, + # sorting, + # ms_before=ms_before, + # ms_after=ms_after, + # num_spikes_for_sparsity=num_spikes_for_sparsity, + # unit_batch_size=unit_batch_size, + # temp_folder=sparsity_temp_folder, + # allow_unfiltered=allow_unfiltered, + # **estimate_kwargs, + # **job_kwargs, + # ) + else: + sparsity = None + + sorting_result = SortingResult.create( + sorting, recording, format=format, folder=folder, sparsity=sparsity) + + return sorting_result + +def load_sorting_result(folder, load_extensions=True, format="auto"): + """ + Load a SortingResult object from disk. + + Parameters + ---------- + folder : str or Path + The folder / zarr folder where the waveform extractor is stored + load_extensions : bool, default: True + Load all extensions or not. + format: "auto" | "binary_folder" | "zarr" + The format of the folder. + + Returns + ------- + sorting_result: SortingResult + The loaded SortingResult + + """ + + return SortingResult.load(folder, load_extensions=load_extensions, format=format) + + +class SortingResult: + """ + Class to make a pair of Recording-Sorting which will be used used for all post postprocessing, + visualization and quality metric computation. + + This internaly maintain a list of computed ResultExtention (waveform, pca, unit position, spike poisition, ...). + + This can live in memory and/or can be be persistent to disk in 2 internal formats (folder/json/npz or zarr). + + This handle unit sparsity that can be propagated to ResultExtention. + + This handle spike sampling that can be propagated to ResultExtention : work only on a subset of spikes. + + This internally save a copy of the Sorting and extract main recording attributes (without traces) so + the SortingResult object can be reload even if references to the original sorting and/or to the original recording + are lost. + """ + def __init__(self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None): + # very fast init because checks are done in load and create + self.sorting = sorting + # self.recorsding will be a property + self._recording = recording + self.rec_attributes = rec_attributes + self.format = format + self.sparsity = sparsity + + # extensions are not loaded at init + self.extensions = dict() + + ## create and load zone + + @classmethod + def create(cls, + sorting: BaseSorting, + recording: BaseRecording, + format: Literal["memory", "binary_folder", "zarr", ] = "memory", + folder=None, + sparsity=None, + ): + # some checks + assert sorting.sampling_frequency == recording.sampling_frequency + # check that multiple probes are non-overlapping + all_probes = recording.get_probegroup().probes + check_probe_do_not_overlap(all_probes) + + if format == "memory": + rec_attributes = get_rec_attributes(recording) + rec_attributes["probegroup"] = recording.get_probegroup() + # a copy of sorting is created directly in shared memory format to avoid further duplication of spikes. + sorting_copy = SharedMemorySorting.from_sorting(sorting) + sortres = SortingResult(sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, format=format, sparsity=sparsity) + elif format == "binary_folder": + cls.create_binary_folder(folder, sorting, recording, sparsity, rec_attributes=None) + sortres = cls.load_from_binary_folder(folder, recording=recording) + elif format == "zarr": + cls.create_zarr(folder, sorting, recording, sparsity, rec_attributes=None) + sortres = cls.load_from_zarr(folder, recording=recording) + else: + raise ValueError("SortingResult.create: wrong format") + + return sortres + + @classmethod + def load(cls, folder, recording=None, load_extensions=True, format="auto"): + """ + Load folder or zarr. + The recording can be given if the recording location has changed. + Otherwise the recording is loaded when possible. + """ + folder = Path(folder) + assert folder.is_dir(), "Waveform folder does not exists" + if format == "auto": + # make better assumption and check for auto guess format + if folder.suffix == ".zarr": + format = "zarr" + else: + format = "binary_folder" + + if format == "binary_folder": + sortres = SortingResult.load_from_binary_folder(folder, recording=recording) + elif format == "zarr": + sortres = SortingResult.load_from_zarr(folder, recording=recording) + + + sortres.folder = folder + + if load_extensions: + sortres.load_all_saved_extension() + + return sortres + + @classmethod + def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attributes): + # used by create and save + + folder = Path(folder) + if folder.is_dir(): + raise ValueError(f"Folder already exists {folder}") + folder.mkdir(parents=True) + + # save a copy of the sorting + NumpyFolderSorting.write_sorting(sorting, folder / "sorting") + + # save recording and sorting provenance + if recording.check_serializability("json"): + recording.dump(folder / "recording.json", relative_to=folder) + elif recording.check_serializability("pickle"): + recording.dump(folder / "recording.pickle", relative_to=folder) + + if sorting.check_serializability("json"): + sorting.dump(folder / "sorting_provenance.json", relative_to=folder) + elif sorting.check_serializability("pickle"): + sorting.dump(folder / "sorting_provenance.pickle", relative_to=folder) + + # dump recording attributes + probegroup = None + rec_attributes_file = folder / "recording_info" / "recording_attributes.json" + rec_attributes_file.parent.mkdir() + if rec_attributes is None: + assert recording is not None + rec_attributes = get_rec_attributes(recording) + rec_attributes_file.write_text(json.dumps(check_json(rec_attributes), indent=4), encoding="utf8") + probegroup = recording.get_probegroup() + else: + rec_attributes_copy = rec_attributes.copy() + probegroup = rec_attributes_copy.pop("probegroup") + rec_attributes_file.write_text(json.dumps(check_json(rec_attributes_copy), indent=4), encoding="utf8") + + if probegroup is not None: + probegroup_file = folder / "recording_info" / "probegroup.json" + probeinterface.write_probeinterface(probegroup_file, probegroup) + + if sparsity is not None: + with open(folder / "sparsity.json", mode="w") as f: + json.dump(check_json(sparsity.to_dict()), f) + + @classmethod + def load_from_binary_folder(cls, folder, recording=None): + folder = Path(folder) + assert folder.is_dir(), f"This folder does not exists {folder}" + + # load internal sorting copy and make it sharedmem + sorting = SharedMemorySorting.from_sorting(NumpyFolderSorting(folder / "sorting")) + + # load recording if possible + if recording is None: + # try to load the recording if not provided + for type in ("json", "pickle"): + filename = folder / f"recording.{type}" + if filename.exists(): + try: + recording = load_extractor(filename, base_folder=folder) + break + except: + recording = None + else: + # TODO maybe maybe not??? : do we need to check attributes match internal rec_attributes + # Note this will make the loading too slow + pass + + # recording attributes + rec_attributes_file = folder / "recording_info" / "recording_attributes.json" + if not rec_attributes_file.exists(): + raise ValueError("This folder is not a SortingResult folder") + with open(rec_attributes_file, "r") as f: + rec_attributes = json.load(f) + # the probe is handle ouside the main json + probegroup_file = folder / "recording_info" / "probegroup.json" + print(probegroup_file, probegroup_file.is_file()) + if probegroup_file.is_file(): + rec_attributes["probegroup"] = probeinterface.read_probeinterface(probegroup_file) + else: + rec_attributes["probegroup"] = None + + # sparsity + sparsity_file = folder / "sparsity.json" + if sparsity_file.is_file(): + with open(sparsity_file, mode="r") as f: + sparsity = ChannelSparsity.from_dict(json.load(f)) + else: + sparsity = None + + sortres = SortingResult( + sorting=sorting, + recording=recording, + rec_attributes=rec_attributes, + format="binary_folder", + sparsity=sparsity) + + return sortres + + def _get_zarr_root(self, mode="r+"): + import zarr + zarr_root = zarr.open(self.folder, mode=mode) + return zarr_root + + @classmethod + def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): + raise NotImplementedError + + @classmethod + def load_from_zarr(cls, folder, recording=None): + raise NotImplementedError + + + def save_as( + self, folder=None, format="binary_folder", + ) -> "SortingResult": + """ + Save SortingResult object into another format. + Uselfull for memory to zarr or memory to binray. + + Note that the recording provenance or sorting provenance can be lost. + + Mainly propagate the copied sorting and recording property. + + Parameters + ---------- + folder : str or Path + The output waveform folder + format : "binary_folder" | "zarr", default: "binary_folder" + The backend to use for saving the waveforms + """ + + if self.has_recording(): + recording = self.recording + else: + recording = None + + # Note that the sorting is a copy we need to go back to the orginal sorting (if available) + sorting_provenance = self.get_sorting_provenance() + if sorting_provenance is None: + # if the original sorting objetc is not available anymore (kilosort folder deleted, ....), take the copy + sorting_provenance = self.sorting + + if format == "memory": + # This make a copy of actual SortingResult + # TODO + raise NotImplementedError + elif format == "binary_folder": + # create a new folder + SortingResult.create_binary_folder(folder, sorting_provenance, recording, self.sparsity, self.rec_attributes) + new_sortres = SortingResult.load_from_binary_folder(folder) + new_sortres.folder = folder + + elif format == "zarr": + # TODO + raise NotImplementedError + else: + raise ValueError("SortingResult.save: wrong format") + + # make a copy of extensions + for extension_name, extension in self.extensions.items(): + new_sortres.extensions[extension_name] = extension.copy(new_sortres) + + return new_sortres + + + def is_read_only(self) -> bool: + if self.format == "memory": + return False + return not os.access(self.folder, os.W_OK) + + + ## map attribute and property zone + + @property + def recording(self) -> BaseRecording: + if not self.has_recording(): + raise ValueError("SortingResult could not load the recording") + return self._recording + + @property + def channel_ids(self) -> np.ndarray: + return np.array(self.rec_attributes["channel_ids"]) + + @property + def sampling_frequency(self) -> float: + return self.sorting.get_sampling_frequency() + + @property + def unit_ids(self) -> np.ndarray: + return self.sorting.unit_ids + + def has_recording(self) -> bool: + return self._recording is not None + + def is_sparse(self) -> bool: + return self.sparsity is not None + + def get_sorting_provenance(self): + """ + Get the original sorting if possible otherwise return None + """ + if self.format == "memory": + # the orginal sorting provenance is not keps in that case + sorting_provenance = None + + elif self.format == "binary_folder": + for type in ("json", "pickle"): + filename = self.folder / f"sorting_provenance.{type}" + if filename.exists(): + try: + sorting_provenance = load_extractor(filename, base_folder=self.folder) + break + except: + sorting_provenance = None + + elif self.format == "zarr": + # TODO + raise NotImplementedError + + return sorting_provenance + + # def is_read_only(self) -> bool: + # return self._is_read_only + + def get_num_samples(self, segment_index: Optional[int] = None) -> int: + # we use self.sorting to check segment_index + segment_index = self.sorting._check_segment_index(segment_index) + return self.rec_attributes["num_samples"][segment_index] + + def get_total_samples(self) -> int: + s = 0 + for segment_index in range(self.get_num_segments()): + s += self.get_num_samples(segment_index) + return s + + def get_total_duration(self) -> float: + duration = self.get_total_samples() / self.sampling_frequency + return duration + + def get_num_channels(self) -> int: + return self.rec_attributes["num_channels"] + + def get_num_segments(self) -> int: + return self.sorting.get_num_segments() + + def get_probegroup(self): + return self.rec_attributes["probegroup"] + + def get_probe(self): + probegroup = self.get_probegroup() + assert len(probegroup.probes) == 1, "There are several probes. Use `get_probegroup()`" + return probegroup.probes[0] + + def get_channel_locations(self) -> np.ndarray: + # important note : contrary to recording + # this give all channel locations, so no kwargs like channel_ids and axes + all_probes = self.get_probegroup().probes + all_positions = np.vstack([probe.contact_positions for probe in all_probes]) + return all_positions + + def channel_ids_to_indices(self, channel_ids) -> np.ndarray: + all_channel_ids = self.rec_attributes["channel_ids"] + indices = np.array([all_channel_ids.index(id) for id in channel_ids], dtype=int) + return indices + + def __repr__(self) -> str: + clsname = self.__class__.__name__ + nseg = self.get_num_segments() + nchan = self.get_num_channels() + nunits = self.sorting.get_num_units() + txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}" + if self.is_sparse(): + txt += " - sparse" + return txt + + ## extensions zone + def compute(self, extension_name, **params): + """ + Compute one extension + + Parameters + ---------- + extension_name + + **params + + Returns + ------- + sorting_result: SortingResult + The SortingResult object + + Examples + -------- + + >>> extension = sortres.compute("unit_location", **some_params) + >>> unit_location = extension.get_data() + + """ + # TODO check extension dependency + + extension_class = get_extension_class(extension_name) + extension_instance = extension_class(self) + extension_instance.set_params(**params) + extension_instance.run() + + self.extensions[extension_name] = extension_instance + + return extension_instance + + def get_saved_extension_names(self): + """ + Get extension saved in folder or zarr that can be loaded. + """ + assert self.format != "memory" + global _possible_extensions + + saved_extension_names = [] + for extension_class in _possible_extensions: + extension_name = extension_class.extension_name + if self.format == "binary_folder": + is_saved = (self.folder / extension_name).is_dir() and (self.folder / extension_name / "params.json").is_file() + elif self.format == "zarr": + zarr_root = self._get_zarr_root(mode="r") + is_saved = extension_name in zarr_root.keys() and "params" in zarr_root[extension_name].attrs.keys() + if is_saved: + saved_extension_names.append(extension_class.extension_name) + return saved_extension_names + + def get_extension(self, extension_name: str): + """ + Get a ResultExtension. + If not loaded then load it before. + + + """ + if extension_name in self.extensions: + return self.extensions[extension_name] + + if self.has_extension(extension_name): + self.load_extension(extension_name) + return self.extensions[extension_name] + + return None + + def load_extension(self, extension_name: str): + """ + Load an extensionet from folder or zarr into the `ResultSorting.extensions` dict. + + Parameters + ---------- + extension_name: str + The extension name. + + Returns + ------- + ext_instanace: + The loaded instance of the extension + + """ + assert self.format != "memory" + + extension_class = get_extension_class(extension_name) + + extension_instance = extension_class(self) + extension_instance.load_prams() + extension_instance.load_data() + + return extension_instance + + def load_all_saved_extension(self): + """ + Load all saved extension in memory. + """ + for extension_name in self.get_saved_extension_names(): + self.load_extension(extension_name) + + def delete_extension(self, extension_name) -> None: + """ + Delete the extension from the dict and also in the persistent zarr or folder. + """ + pass + + def get_loaded_extension_names(self): + """ + Return the loaded or already computed extensions names. + """ + return list(self.extensions.keys()) + + def has_extension(self, extension_name: str) -> bool: + """ + Check if the extension exists in memory (dict) or in the folder or in zarr. + + If force_load=True (the default) then the extension is automatically loaded if available. + """ + if extension_name in self.extensions: + return True + elif extension_name in self.get_saved_extension_names(): + return True + else: + return False + + +global _possible_extensions +_possible_extensions = [] + +def register_result_extension(extension_class): + """ + This maintains a list of possible extensions that are available. + It depends on the imported submodules (e.g. for postprocessing module). + + For instance with: + import spikeinterface as si + only one extension will be available + but with + import spikeinterface.postprocessing + more extensions will be available + """ + assert issubclass(extension_class, ResultExtension) + assert extension_class.extension_name is not None, "extension_name must not be None" + global _possible_extensions + + already_registered = any(extension_class is ext for ext in _possible_extensions) + if not already_registered: + assert all( + extension_class.extension_name != ext.extension_name for ext in _possible_extensions + ), "Extension name already exists" + + _possible_extensions.append(extension_class) + + +def get_extension_class(extension_name: str): + """ + Get extension class from name and check if registered. + + Parameters + ---------- + extension_name: str + The extension name. + + Returns + ------- + ext_class: + The class of the extension. + """ + global _possible_extensions + extensions_dict = {ext.extension_name: ext for ext in _possible_extensions} + assert extension_name in extensions_dict, "Extension is not registered, please import related module before" + ext_class = extensions_dict[extension_name] + return ext_class + + +class ResultExtension: + """ + This the base class to extend the SortingResult. + It can handle persistency to disk any computations related + + For instance: + * waveforms + * principal components + * spike amplitudes + * quality metrics + + Possible extension can be register on the fly at import time with register_result_extension() mechanism. + It also enables any custum computation on top on SortingResult to be implemented by the user. + + An extension needs to inherit from this class and implement some abstract methods: + * _set_params + * _run + * + + The subclass must also set an `extension_name` class attribute which is not None by default. + + The subclass must also hanle an attribute `__data` which is a dict contain the results after the `run()`. + """ + extension_name = None + + def __init__(self, sorting_result): + self._sorting_result = weakref.ref(sorting_result) + + self._params = None + self._data = dict() + + @property + def sorting_result(self): + # Important : to avoid the SortingResult referencing a ResultExtension + # and ResultExtension referencing a SortingResult we need a weakref. + # Otherwise the garbage collector is not working properly. + # and so the SortingResult + its recording are still alive even after deleting explicitly + # the SortingResult which makes it impossible to delete the folder when using memmap. + sorting_result = self._sorting_result() + if sorting_result is None: + raise ValueError(f"The extension {self.extension_name} has lost its SortingResult") + return sorting_result + + # some attribuites come from sorting_result + @property + def format(self): + return self.sorting_result.format + + @property + def sparsity(self): + return self.sorting_result.sparsity + + @property + def folder(self): + return self.sorting_result.folder + + def _get_binary_extension_folder(self): + extension_folder = self.folder / "saved_extensions" /self.extension_name + return extension_folder + + + def _get_zarr_extension_group(self, mode='r+'): + zarr_root = self.sorting_result._get_zarr_root(mode=mode) + assert self.extension_name in zarr_root.keys(), ( + f"SortingResult: extension {self.extension_name} " f"is not in folder {self.folder}" + ) + extension_group = zarr_root[self.extension_name] + return extension_group + + + @classmethod + def load(cls, sorting_result): + ext = cls(sorting_result) + ext.load_params() + ext.load_data() + return ext + + def load_params(self): + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + params_file = extension_folder / "params.json" + assert params_file.is_file(), f"No params file in extension {self.extension_name} folder" + with open(str(params_file), "r") as f: + params = json.load(f) + + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode='r') + assert "params" in extension_group.attrs, f"No params file in extension {self.extension_name} folder" + params = extension_group.attrs["params"] + + self._params = params + + def load_data(self): + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + for ext_data_file in extension_folder: + if ext_data_file.name == "params.json": + continue + ext_data_name = ext_data_file.stem + if ext_data_file.suffix == ".json": + ext_data = json.load(ext_data_file.open("r")) + elif ext_data_file.suffix == ".npy": + # The lazy loading of an extension is complicated because if we compute again + # and have a link to the old buffer on windows then it fails + # ext_data = np.load(ext_data_file, mmap_mode="r") + # so we go back to full loading + ext_data = np.load(ext_data_file) + elif ext_data_file.suffix == ".csv": + import pandas as pd + ext_data = pd.read_csv(ext_data_file, index_col=0) + elif ext_data_file.suffix == ".pkl": + ext_data = pickle.load(ext_data_file.open("rb")) + else: + continue + self._data[ext_data_name] = ext_data + + elif self.format == "zarr": + raise NotImplementedError + # TODO: decide if we make a copy or not + # extension_group = self._get_zarr_extension_group(mode='r') + # for ext_data_name in extension_group.keys(): + # ext_data_ = extension_group[ext_data_name] + # if "dict" in ext_data_.attrs: + # ext_data = ext_data_[0] + # elif "dataframe" in ext_data_.attrs: + # import xarray + # ext_data = xarray.open_zarr( + # ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" + # ).to_pandas() + # ext_data.index.rename("", inplace=True) + # else: + # ext_data = ext_data_ + # self._data[ext_data_name] = ext_data + + def copy(self, new_sorting_result): + new_extension = self.__class__(new_sorting_result) + new_extension._params = self._params.copy() + new_extension._data = self._data + new_extension._save() + + def run(self, **kwargs): + self._run(**kwargs) + if not self.sorting_result.is_read_only(): + self._save(**kwargs) + + def _run(self, **kwargs): + # must be implemented in subclass + # must populate the self._data dictionary + raise NotImplementedError + + def save(self, **kwargs): + self._save(**kwargs) + + def _save(self, **kwargs): + if self.format == "memory": + return + + if self.sorting_result.is_read_only(): + raise ValueError("The SortingResult is read only save is not possible") + + # delete already saved + self._reset_folder() + self._save_params() + + + if self.format == "binary_folder": + import pandas as pd + + extension_folder = self._get_binary_extension_folder() + + for ext_data_name, ext_data in self._data.items(): + if isinstance(ext_data, dict): + with (extension_folder / f"{ext_data_name}.json").open("w") as f: + json.dump(ext_data, f) + elif isinstance(ext_data, np.ndarray): + np.save(extension_folder / f"{ext_data_name}.npy", ext_data) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_csv(extension_folder / f"{ext_data_name}.csv", index=True) + else: + try: + with (extension_folder / f"{ext_data_name}.pkl").open("wb") as f: + pickle.dump(ext_data, f) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") + elif self.format == "zarr": + from .zarrextractors import get_default_zarr_compressor + import pandas as pd + import numcodecs + + extension_group = self._get_zarr_extension_group(mode="r+") + + compressor = kwargs.get("compressor", None) + if compressor is None: + compressor = get_default_zarr_compressor() + + for ext_data_name, ext_data in self._data.items(): + if ext_data_name in extension_group: + del extension_group[ext_data_name] + if isinstance(ext_data, dict): + extension_group.create_dataset( + name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() + ) + extension_group[ext_data_name].attrs["dict"] = True + elif isinstance(ext_data, np.ndarray): + extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_xarray().to_zarr( + store=extension_group.store, + group=f"{extension_group.name}/{ext_data_name}", + mode="a", + ) + extension_group[ext_data_name].attrs["dataframe"] = True + else: + try: + extension_group.create_dataset( + name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() + ) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") + + def _reset_folder(self): + """ + Delete the extension in folder (binary or zarr) and create an empty one. + """ + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + if extension_folder.is_dir(): + shutil.rmtree(extension_folder) + extension_folder.mkdir(exist_ok=False, parents=True) + + elif self.format == "zarr": + import zarr + + zarr_root = zarr.open(self.folder, mode="r+") + self.extension_group = zarr_root.create_group(self.extension_name, overwrite=True) + + def reset(self): + """ + Reset the waveform extension. + Delete the sub folder and create a new empty one. + """ + self._reset_folder() + self._params = None + self._data = dict() + + def _select_extension_data(self, unit_ids): + # must be implemented in subclass + raise NotImplementedError + + def set_params(self, **params): + """ + Set parameters for the extension and + make it persistent in json. + """ + params = self._set_params(**params) + self._params = params + + print(self.sorting_result.is_read_only()) + if self.sorting_result.is_read_only(): + return + + self._save_params() + + def _save_params(self): + params_to_save = self._params.copy() + if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: + assert isinstance( + params_to_save["sparsity"], ChannelSparsity + ), "'sparsity' parameter must be a ChannelSparsity object!" + params_to_save["sparsity"] = params_to_save["sparsity"].to_dict() + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + extension_folder.mkdir(exist_ok=True) + param_file = extension_folder / "params.json" + param_file.write_text(json.dumps(check_json(params_to_save), indent=4), encoding="utf8") + elif self.format == "zarr": + self.extension_group.attrs["params"] = check_json(params_to_save) + + def _set_params(self, **params): + # must be implemented in subclass + # must return a cleaned version of params dict + raise NotImplementedError + diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py new file mode 100644 index 0000000000..9a296c118b --- /dev/null +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -0,0 +1,121 @@ +import pytest +from pathlib import Path + +import shutil + +from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.core import SortingResult +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension + +import numpy as np + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "core" +else: + cache_folder = Path("cache_folder") / "core" + + +def get_dataset(): + recording, sorting = generate_ground_truth_recording( + durations=[30.0], sampling_frequency=16000.0, num_channels=10, num_units=5, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + return recording, sorting + + + +def test_SortingResult_memory(): + recording, sorting = get_dataset() + sortres = SortingResult.create(sorting, recording, format="memory", sparsity=None) + print(sortres.rec_attributes.keys()) + _check_sorting_results(sortres) + + # save to binrary_folder + folder = cache_folder / "test_SortingResult_saved_binary_folder" + if folder.exists(): + shutil.rmtree(folder) + + sortres2 = sortres.save_as(folder, format="binary_folder") + _check_sorting_results(sortres2) + + # save to zarr + # folder = cache_folder / "test_SortingResult_saved_zarr.zarr" + # if folder.exists(): + # shutil.rmtree(folder) + # sortres2 = sortres.save_as(folder, format="zarr") + # _check_sorting_results(sortres2) + + + +def test_SortingResult_folder(): + recording, sorting = get_dataset() + + folder = cache_folder / "test_SortingResult_folder" + if folder.exists(): + shutil.rmtree(folder) + + sortres = SortingResult.create(sorting, recording, format="binary_folder", folder=folder, sparsity=None) + sortres = SortingResult.load(folder) + + print(sortres.folder) + + _check_sorting_results(sortres) + +def _check_sorting_results(sortres): + register_result_extension(DummyResultExtension) + + print() + print(sortres) + print(sortres.sampling_frequency) + print(sortres.channel_ids) + print(sortres.unit_ids) + print(sortres.get_probe()) + print(sortres.sparsity) + + sortres.compute("dummy", param1=5.5) + ext = sortres.get_extension("dummy") + assert ext is not None + assert ext._params["param1"] == 5.5 + sortres.compute("dummy", param1=5.5) + + sortres.delete_extension("dummy") + ext = sortres.get_extension("dummy") + assert ext is None + + +class DummyResultExtension(ResultExtension): + extension_name = "dummy" + + def _set_params(self, param0="yep", param1=1.2, param2=[1,2, 3.]): + params = dict(param0=param0, param1=param1, param2=param2) + params["more_option"] = "yep" + return params + + def _run(self, **kwargs): + # print("dummy run") + self._data["result_one"] = "abcd" + self._data["result_two"] = np.zeros(3) + + +class DummyResultExtension2(ResultExtension): + extension_name = "dummy" + + +def test_extension(): + register_result_extension(DummyResultExtension) + # can be register twice + register_result_extension(DummyResultExtension) + + # same name should trigger an error + with pytest.raises(AssertionError): + register_result_extension(DummyResultExtension2) + + +if __name__ == "__main__": + test_SortingResult_memory() + + # test_SortingResult_folder() + + # test_extension() \ No newline at end of file From b23c772ebef03496e01c4c835119925b5aeb050e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 11 Jan 2024 16:29:47 +0100 Subject: [PATCH 002/192] SortingResult more test with ResultExtension --- src/spikeinterface/core/sortingresult.py | 216 ++++++++++++------ .../core/tests/test_sortingresult.py | 111 ++++++--- 2 files changed, 234 insertions(+), 93 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index f92fa214ed..f9ab0d5a1b 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -92,8 +92,20 @@ def start_sorting_result(sorting, recording, format="memory", folder=None, >>> # Extract dense waveforms and save to disk with binary_folder format. >>> sortres = si.start_sorting_result(sorting, recording, format="binary_folder", folder="/path/to_my/result") - """ + >>> # Can be reload + >>> sortres = si.load_sorting_result(folder="/path/to_my/result") + + >>> # Can run extension + >>> sortres = si.compute("unit_locations", ...) + + >>> # Can be copy to another format (extensions are propagated) + >>> sortres2 = sortres.save_as(format="memory") + >>> sortres3 = sortres.save_as(format="zarr", folder="/path/to_my/result.zarr") + >>> # Can make a copy with a subset of units (extensions are propagated for the unit subset) + >>> sortres4 = sortres.select_units(unit_ids=sorting.units_ids[:5], format="memory") + >>> sortres5 = sortres.select_units(unit_ids=sorting.units_ids[:5], format="binary_folder", folder="/result_5units") + """ # handle sparsity if sparsity is not None: @@ -158,6 +170,7 @@ class SortingResult: This internaly maintain a list of computed ResultExtention (waveform, pca, unit position, spike poisition, ...). This can live in memory and/or can be be persistent to disk in 2 internal formats (folder/json/npz or zarr). + A SortingResult can be transfer to another format using `save_as()` This handle unit sparsity that can be propagated to ResultExtention. @@ -196,11 +209,7 @@ def create(cls, check_probe_do_not_overlap(all_probes) if format == "memory": - rec_attributes = get_rec_attributes(recording) - rec_attributes["probegroup"] = recording.get_probegroup() - # a copy of sorting is created directly in shared memory format to avoid further duplication of spikes. - sorting_copy = SharedMemorySorting.from_sorting(sorting) - sortres = SortingResult(sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, format=format, sparsity=sparsity) + sortres = cls.create_memory(sorting, recording, sparsity, rec_attributes=None) elif format == "binary_folder": cls.create_binary_folder(folder, sorting, recording, sparsity, rec_attributes=None) sortres = cls.load_from_binary_folder(folder, recording=recording) @@ -241,9 +250,27 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): return sortres + @classmethod + def create_memory(cls, sorting, recording, sparsity, rec_attributes): + # used by create and save_as + + if rec_attributes is None: + assert recording is not None + rec_attributes = get_rec_attributes(recording) + rec_attributes["probegroup"] = recording.get_probegroup() + else: + # a copy is done to avoid shared dict between instances (which can block garbage collector) + rec_attributes = rec_attributes.copy() + + # a copy of sorting is created directly in shared memory format to avoid further duplication of spikes. + sorting_copy = SharedMemorySorting.from_sorting(sorting) + sortres = SortingResult(sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, + format="memory", sparsity=sparsity) + return sortres + @classmethod def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attributes): - # used by create and save + # used by create and save_as folder = Path(folder) if folder.is_dir(): @@ -318,7 +345,7 @@ def load_from_binary_folder(cls, folder, recording=None): rec_attributes = json.load(f) # the probe is handle ouside the main json probegroup_file = folder / "recording_info" / "probegroup.json" - print(probegroup_file, probegroup_file.is_file()) + if probegroup_file.is_file(): rec_attributes["probegroup"] = probeinterface.read_probeinterface(probegroup_file) else: @@ -355,23 +382,9 @@ def load_from_zarr(cls, folder, recording=None): raise NotImplementedError - def save_as( - self, folder=None, format="binary_folder", - ) -> "SortingResult": + def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingResult": """ - Save SortingResult object into another format. - Uselfull for memory to zarr or memory to binray. - - Note that the recording provenance or sorting provenance can be lost. - - Mainly propagate the copied sorting and recording property. - - Parameters - ---------- - folder : str or Path - The output waveform folder - format : "binary_folder" | "zarr", default: "binary_folder" - The backend to use for saving the waveforms + Internal used by both save_as(), copy() and select_units() which are more or less the same. """ if self.has_recording(): @@ -384,29 +397,81 @@ def save_as( if sorting_provenance is None: # if the original sorting objetc is not available anymore (kilosort folder deleted, ....), take the copy sorting_provenance = self.sorting + + if unit_ids is not None: + # when only some unit_ids then the sorting must be sliced + sorting_provenance = sorting_provenance.select_units(unit_ids) if format == "memory": # This make a copy of actual SortingResult - # TODO - raise NotImplementedError + new_sortres = SortingResult.create_memory(sorting_provenance, recording, self.sparsity, self.rec_attributes) + elif format == "binary_folder": # create a new folder + assert folder is not None, "For format='binary_folder' folder must be provided" SortingResult.create_binary_folder(folder, sorting_provenance, recording, self.sparsity, self.rec_attributes) new_sortres = SortingResult.load_from_binary_folder(folder) new_sortres.folder = folder elif format == "zarr": - # TODO + assert folder is not None, "For format='zarr' folder must be provided" raise NotImplementedError else: raise ValueError("SortingResult.save: wrong format") # make a copy of extensions + # note that the copy of extension handle itself the slicing of units when necessary for extension_name, extension in self.extensions.items(): - new_sortres.extensions[extension_name] = extension.copy(new_sortres) + new_sortres.extensions[extension_name] = extension.copy(new_sortres, unit_ids=unit_ids) return new_sortres + def save_as(self, format="binary_folder", folder=None) -> "SortingResult": + """ + Save SortingResult object into another format. + Uselfull for memory to zarr or memory to binray. + + Note that the recording provenance or sorting provenance can be lost. + + Mainly propagate the copied sorting and recording property. + + Parameters + ---------- + folder : str or Path + The output waveform folder + format : "binary_folder" | "zarr", default: "binary_folder" + The backend to use for saving the waveforms + """ + return self._save_or_select(format=format, folder=folder, unit_ids=None) + + + def select_units(self, unit_ids, folder=None, format="binary_folder") -> "SortingResult": + """ + This method is equivalent to `save_as()`but with a subset of units. + Filters units by creating a new waveform extractor object in a new folder. + + Extensions are also updated to filter the selected unit ids. + + Parameters + ---------- + unit_ids : list or array + The unit ids to keep in the new WaveformExtractor object + folder : Path or None + The new folder where selected waveforms are copied + format: + a + Returns + ------- + we : WaveformExtractor + The newly create waveform extractor with the selected units + """ + return self._save_or_select(format=format, folder=folder, unit_ids=unit_ids) + + def copy(self): + """ + Create a a copy of SortingResult with format "memory". + """ + return self._save_or_select(format="binary_folder", folder=None, unit_ids=None) def is_read_only(self) -> bool: if self.format == "memory": @@ -451,12 +516,14 @@ def get_sorting_provenance(self): elif self.format == "binary_folder": for type in ("json", "pickle"): filename = self.folder / f"sorting_provenance.{type}" + sorting_provenance = None if filename.exists(): try: sorting_provenance = load_extractor(filename, base_folder=self.folder) break except: - sorting_provenance = None + pass + # sorting_provenance = None elif self.format == "zarr": # TODO @@ -574,8 +641,9 @@ def get_saved_extension_names(self): def get_extension(self, extension_name: str): """ Get a ResultExtension. - If not loaded then load it before. + If not loaded then load is automatic. + Return None if the extension is not computed yet (this avoid the use of has_extension() and then get it) """ if extension_name in self.extensions: @@ -607,7 +675,7 @@ def load_extension(self, extension_name: str): extension_class = get_extension_class(extension_name) extension_instance = extension_class(self) - extension_instance.load_prams() + extension_instance.load_params() extension_instance.load_data() return extension_instance @@ -623,7 +691,15 @@ def delete_extension(self, extension_name) -> None: """ Delete the extension from the dict and also in the persistent zarr or folder. """ - pass + + # delete from folder or zarr + if self.format != "memory" and self.has_extension(extension_name): + # need a reload to reset the folder + ext = self.load_extension(extension_name) + ext.reset() + + # remove from dict + self.extensions.pop(extension_name, None) def get_loaded_extension_names(self): """ @@ -634,11 +710,11 @@ def get_loaded_extension_names(self): def has_extension(self, extension_name: str) -> bool: """ Check if the extension exists in memory (dict) or in the folder or in zarr. - - If force_load=True (the default) then the extension is automatically loaded if available. """ if extension_name in self.extensions: return True + elif self.format == "memory": + return False elif extension_name in self.get_saved_extension_names(): return True else: @@ -711,7 +787,7 @@ class ResultExtension: An extension needs to inherit from this class and implement some abstract methods: * _set_params * _run - * + * _select_extension_data The subclass must also set an `extension_name` class attribute which is not None by default. @@ -722,8 +798,27 @@ class ResultExtension: def __init__(self, sorting_result): self._sorting_result = weakref.ref(sorting_result) - self._params = None - self._data = dict() + self.params = None + self.data = dict() + + ####### + # This 3 methods must be implemented in the subclass!!! + # See DummyResultExtension in test_sortingresult.py as a simple example + def _run(self, **kwargs): + # must be implemented in subclass + # must populate the self.data dictionary + raise NotImplementedError + + def _set_params(self, **params): + # must be implemented in subclass + # must return a cleaned version of params dict + raise NotImplementedError + + def _select_extension_data(self, unit_ids): + # must be implemented in subclass + raise NotImplementedError + # + ####### @property def sorting_result(self): @@ -784,12 +879,12 @@ def load_params(self): assert "params" in extension_group.attrs, f"No params file in extension {self.extension_name} folder" params = extension_group.attrs["params"] - self._params = params + self.params = params def load_data(self): if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() - for ext_data_file in extension_folder: + for ext_data_file in extension_folder.iterdir(): if ext_data_file.name == "params.json": continue ext_data_name = ext_data_file.stem @@ -808,7 +903,7 @@ def load_data(self): ext_data = pickle.load(ext_data_file.open("rb")) else: continue - self._data[ext_data_name] = ext_data + self.data[ext_data_name] = ext_data elif self.format == "zarr": raise NotImplementedError @@ -826,24 +921,24 @@ def load_data(self): # ext_data.index.rename("", inplace=True) # else: # ext_data = ext_data_ - # self._data[ext_data_name] = ext_data + # self.data[ext_data_name] = ext_data - def copy(self, new_sorting_result): + def copy(self, new_sorting_result, unit_ids=None): + # alessio : please note that this also replace the old BaseWaveformExtractorExtension.select_units!!! new_extension = self.__class__(new_sorting_result) - new_extension._params = self._params.copy() - new_extension._data = self._data + new_extension.params = self.params.copy() + if unit_ids is None: + new_extension.data = self.data + else: + new_extension.data = self._select_extension_data(unit_ids) new_extension._save() + return new_extension def run(self, **kwargs): self._run(**kwargs) if not self.sorting_result.is_read_only(): self._save(**kwargs) - def _run(self, **kwargs): - # must be implemented in subclass - # must populate the self._data dictionary - raise NotImplementedError - def save(self, **kwargs): self._save(**kwargs) @@ -864,7 +959,7 @@ def _save(self, **kwargs): extension_folder = self._get_binary_extension_folder() - for ext_data_name, ext_data in self._data.items(): + for ext_data_name, ext_data in self.data.items(): if isinstance(ext_data, dict): with (extension_folder / f"{ext_data_name}.json").open("w") as f: json.dump(ext_data, f) @@ -889,7 +984,7 @@ def _save(self, **kwargs): if compressor is None: compressor = get_default_zarr_compressor() - for ext_data_name, ext_data in self._data.items(): + for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: del extension_group[ext_data_name] if isinstance(ext_data, dict): @@ -936,12 +1031,9 @@ def reset(self): Delete the sub folder and create a new empty one. """ self._reset_folder() - self._params = None - self._data = dict() + self.params = None + self.data = dict() - def _select_extension_data(self, unit_ids): - # must be implemented in subclass - raise NotImplementedError def set_params(self, **params): """ @@ -949,16 +1041,15 @@ def set_params(self, **params): make it persistent in json. """ params = self._set_params(**params) - self._params = params + self.params = params - print(self.sorting_result.is_read_only()) if self.sorting_result.is_read_only(): return self._save_params() def _save_params(self): - params_to_save = self._params.copy() + params_to_save = self.params.copy() if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: assert isinstance( params_to_save["sparsity"], ChannelSparsity @@ -966,14 +1057,11 @@ def _save_params(self): params_to_save["sparsity"] = params_to_save["sparsity"].to_dict() if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() - extension_folder.mkdir(exist_ok=True) + extension_folder.mkdir(exist_ok=True, parents=True) param_file = extension_folder / "params.json" param_file.write_text(json.dumps(check_json(params_to_save), indent=4), encoding="utf8") elif self.format == "zarr": self.extension_group.attrs["params"] = check_json(params_to_save) - def _set_params(self, **params): - # must be implemented in subclass - # must return a cleaned version of params dict - raise NotImplementedError + diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 9a296c118b..10b09eb49f 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -29,23 +29,16 @@ def get_dataset(): def test_SortingResult_memory(): recording, sorting = get_dataset() sortres = SortingResult.create(sorting, recording, format="memory", sparsity=None) - print(sortres.rec_attributes.keys()) - _check_sorting_results(sortres) - # save to binrary_folder - folder = cache_folder / "test_SortingResult_saved_binary_folder" - if folder.exists(): - shutil.rmtree(folder) + _check_sorting_results(sortres, sorting) - sortres2 = sortres.save_as(folder, format="binary_folder") - _check_sorting_results(sortres2) - # save to zarr + # save to zarr: not done yet!!! # folder = cache_folder / "test_SortingResult_saved_zarr.zarr" # if folder.exists(): # shutil.rmtree(folder) - # sortres2 = sortres.save_as(folder, format="zarr") - # _check_sorting_results(sortres2) + # sortres2 = sortres.save_as(format="zarr", folder=folder) + # _check_sorting_results(sortres2, sorting) @@ -58,33 +51,77 @@ def test_SortingResult_folder(): sortres = SortingResult.create(sorting, recording, format="binary_folder", folder=folder, sparsity=None) sortres = SortingResult.load(folder) - - print(sortres.folder) + _check_sorting_results(sortres, sorting) - _check_sorting_results(sortres) -def _check_sorting_results(sortres): - register_result_extension(DummyResultExtension) +def _check_sorting_results(sortres, original_sorting): print() print(sortres) - print(sortres.sampling_frequency) - print(sortres.channel_ids) - print(sortres.unit_ids) - print(sortres.get_probe()) - print(sortres.sparsity) + register_result_extension(DummyResultExtension) + + assert "channel_ids" in sortres.rec_attributes + assert "sampling_frequency" in sortres.rec_attributes + assert "num_samples" in sortres.rec_attributes + + probe = sortres.get_probe() + sparsity = sortres.sparsity + + # compute sortres.compute("dummy", param1=5.5) ext = sortres.get_extension("dummy") assert ext is not None - assert ext._params["param1"] == 5.5 + assert ext.params["param1"] == 5.5 + # recompute sortres.compute("dummy", param1=5.5) - + # and delete sortres.delete_extension("dummy") ext = sortres.get_extension("dummy") assert ext is None + # save to several format + for format in ("memory", "binary_folder", ): # "zarr" + if format != "memory": + folder = cache_folder / f"test_SortingResult_save_as_{format}" + if folder.exists(): + shutil.rmtree(folder) + else: + folder = None + + # compute one extension to check the save + sortres.compute("dummy") + + sortres2 = sortres.save_as(format=format, folder=folder) + ext = sortres2.get_extension("dummy") + assert ext is not None + + data = sortres2.get_extension("dummy").data + assert "result_one" in data + assert data["result_two"].size == original_sorting.to_spike_vector().size + + # select unit_ids to several format + for format in ("memory", "binary_folder", ): # "zarr" + if format != "memory": + folder = cache_folder / f"test_SortingResult_select_units_with{format}" + if folder.exists(): + shutil.rmtree(folder) + else: + folder = None + # compute one extension to check the slice + sortres.compute("dummy") + keep_unit_ids = original_sorting.unit_ids[::2] + sortres2 = sortres.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) + + # check propagation of result data and correct sligin + assert np.array_equal(keep_unit_ids, sortres2.unit_ids) + data = sortres2.get_extension("dummy").data + assert data["result_one"] == sortres.get_extension("dummy").data["result_one"] + # unit 1, 3, ... should be removed + assert np.all(~np.isin(data["result_two"], [-1, -3])) + + class DummyResultExtension(ResultExtension): extension_name = "dummy" @@ -95,8 +132,24 @@ def _set_params(self, param0="yep", param1=1.2, param2=[1,2, 3.]): def _run(self, **kwargs): # print("dummy run") - self._data["result_one"] = "abcd" - self._data["result_two"] = np.zeros(3) + self.data["result_one"] = "abcd" + # the result two has the same size of the spike vector!! + # and represent nothing (the trick is to use unit_index for testing slice) + spikes = self.sorting_result.sorting.to_spike_vector() + self.data["result_two"] = spikes["unit_index"] * -1 + + def _select_extension_data(self, unit_ids): + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + + spikes = self.sorting_result.sorting.to_spike_vector() + keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + # here the first key do not depend on unit_id + # but the second need to be sliced!! + new_data = dict() + new_data["result_one"] = self.data["result_one"] + new_data["result_two"] = self.data["result_two"][keep_spike_mask] + + return new_data class DummyResultExtension2(ResultExtension): @@ -105,10 +158,10 @@ class DummyResultExtension2(ResultExtension): def test_extension(): register_result_extension(DummyResultExtension) - # can be register twice + # can be register twice without error register_result_extension(DummyResultExtension) - # same name should trigger an error + # other extension with same name should trigger an error with pytest.raises(AssertionError): register_result_extension(DummyResultExtension2) @@ -116,6 +169,6 @@ def test_extension(): if __name__ == "__main__": test_SortingResult_memory() - # test_SortingResult_folder() + test_SortingResult_folder() - # test_extension() \ No newline at end of file + test_extension() \ No newline at end of file From 1b90437b52ef1e97860df1d38e8444bfdc727b54 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 11 Jan 2024 17:17:48 +0100 Subject: [PATCH 003/192] clean tests --- src/spikeinterface/core/sortingresult.py | 4 +- .../core/tests/test_sortingresult.py | 44 ++++++++++--------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index f9ab0d5a1b..bae0a8f0fb 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -340,7 +340,7 @@ def load_from_binary_folder(cls, folder, recording=None): # recording attributes rec_attributes_file = folder / "recording_info" / "recording_attributes.json" if not rec_attributes_file.exists(): - raise ValueError("This folder is not a SortingResult folder") + raise ValueError("This folder is not a SortingResult with format='binary_folder'") with open(rec_attributes_file, "r") as f: rec_attributes = json.load(f) # the probe is handle ouside the main json @@ -1063,5 +1063,3 @@ def _save_params(self): elif self.format == "zarr": self.extension_group.attrs["params"] = check_json(params_to_save) - - diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 10b09eb49f..ddcef0c454 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -4,7 +4,7 @@ import shutil from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import SortingResult +from spikeinterface.core import SortingResult, start_sorting_result, load_sorting_result from spikeinterface.core.sortingresult import register_result_extension, ResultExtension import numpy as np @@ -28,32 +28,37 @@ def get_dataset(): def test_SortingResult_memory(): recording, sorting = get_dataset() - sortres = SortingResult.create(sorting, recording, format="memory", sparsity=None) - + sortres = start_sorting_result(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_results(sortres, sorting) - # save to zarr: not done yet!!! - # folder = cache_folder / "test_SortingResult_saved_zarr.zarr" - # if folder.exists(): - # shutil.rmtree(folder) - # sortres2 = sortres.save_as(format="zarr", folder=folder) - # _check_sorting_results(sortres2, sorting) - - - -def test_SortingResult_folder(): +def test_SortingResult_binary_folder(): recording, sorting = get_dataset() - folder = cache_folder / "test_SortingResult_folder" + folder = cache_folder / "test_SortingResult_binary_folder" if folder.exists(): shutil.rmtree(folder) - sortres = SortingResult.create(sorting, recording, format="binary_folder", folder=folder, sparsity=None) - sortres = SortingResult.load(folder) + sortres = start_sorting_result(sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None) + sortres = load_sorting_result(folder, format="auto") + _check_sorting_results(sortres, sorting) +# def test_SortingResult_zarr(): +# recording, sorting = get_dataset() + +# folder = cache_folder / "test_SortingResult_zarr.zarr" +# if folder.exists(): +# shutil.rmtree(folder) + +# sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) +# sortres = load_sorting_result(folder, format="auto") + +# _check_sorting_results(sortres, sorting) + + + def _check_sorting_results(sortres, original_sorting): print() @@ -168,7 +173,6 @@ def test_extension(): if __name__ == "__main__": test_SortingResult_memory() - - test_SortingResult_folder() - - test_extension() \ No newline at end of file + test_SortingResult_binary_folder() + # test_SortingResult_zarr() + test_extension() From 8f98ce81b978f3aa82f88bf035924a1e17dbb600 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 12 Jan 2024 09:40:53 +0100 Subject: [PATCH 004/192] Start zarr for SortingResult --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/sortingresult.py | 111 ++++++++++++++++-- .../core/tests/test_sortingresult.py | 24 ++-- 3 files changed, 116 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 6c811d8fff..25f5d6ef53 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -133,4 +133,4 @@ # channel sparsity from .sparsity import ChannelSparsity, compute_sparsity -from .sortingresult import SortingResult +from .sortingresult import SortingResult, start_sorting_result, load_sorting_result diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index bae0a8f0fb..e025c79a88 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -7,6 +7,7 @@ import pickle import weakref import shutil +import warnings import numpy as np @@ -21,7 +22,7 @@ from .numpyextractors import SharedMemorySorting from .sparsity import ChannelSparsity from .sortingfolder import NumpyFolderSorting - +from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor # TODO # * make info.json that contain some version info of spikeinterface @@ -367,7 +368,7 @@ def load_from_binary_folder(cls, folder, recording=None): sparsity=sparsity) return sortres - + def _get_zarr_root(self, mode="r+"): import zarr zarr_root = zarr.open(self.folder, mode=mode) @@ -375,11 +376,108 @@ def _get_zarr_root(self, mode="r+"): @classmethod def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): - raise NotImplementedError + # used by create and save_as + import zarr + + + + folder = Path(folder) + # force zarr sufix + if folder.suffix != ".zarr": + folder = folder.parent / f"{folder.stem}.zarr" + + if folder.is_dir(): + raise ValueError(f"Folder already exists {folder}") + + + + + zarr_root = zarr.open(folder, mode="w") + + # the recording and sorting provenance can be used only if compatible with json + if recording.check_serializability("json"): + rec_dict = recording.to_dict(relative_to=folder, recursive=True) + zarr_root.attrs["recording"] = check_json(rec_dict) + else: + warnings.warn("SortingResult with zarr : the Recording is not json serializable, the recording link will be lost for futur load") + + if sorting.check_serializability("json"): + sort_dict = sorting.to_dict(relative_to=folder, recursive=True) + zarr_root.attrs["sorting_provenance"] = check_json(sort_dict) + # else: + # warnings.warn("SortingResult with zarr : the sorting provenance is not json serializable, the sorting provenance link will be lost for futur load") + + recording_info = zarr_root.create_group("recording_info") + + if rec_attributes is None: + assert recording is not None + rec_attributes = get_rec_attributes(recording) + probegroup = recording.get_probegroup() + else: + rec_attributes_copy = rec_attributes.copy() + probegroup = rec_attributes_copy.pop("probegroup") + recording_info.attrs["recording_attributes"] = check_json(rec_attributes) + + if probegroup is not None: + recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) + + if sparsity is not None: + zarr_root.attrs["sparsity"] = check_json(sparsity.to_dict()) + + # write sorting copy + from .zarrextractors import add_sorting_to_zarr_group + # Alessio : we need to find a way to propagate compressor for all steps. + # kwargs = dict(compressor=...) + zarr_kwargs = dict() + add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **zarr_kwargs) + @classmethod def load_from_zarr(cls, folder, recording=None): - raise NotImplementedError + import zarr + folder = Path(folder) + assert folder.is_dir(), f"This folder does not exists {folder}" + + zarr_root = zarr.open(folder, mode="r") + + # load internal sorting copy and make it sharedmem + # TODO + # sorting = ZarrSortingExtractor... + sorting = SharedMemorySorting.from_sorting(NumpyFolderSorting(folder / "sorting")) + + # load recording if possible + if recording is None: + try: + recording = load_extractor(zarr_root.attrs["recording"], base_folder=folder) + except: + recording = None + else: + # TODO maybe maybe not??? : do we need to check attributes match internal rec_attributes + # Note this will make the loading too slow + pass + + # recording attributes + rec_attributes = zarr_root.require_group("recording_info").attrs["recording_attributes"] + if "probegroup" in zarr_root.require_group("recording_info").attrs: + probegroup_dict = zarr_root.require_group("recording_info").attrs["probegroup"] + rec_attributes["probegroup"] = probeinterface.Probe.from_dict(probegroup_dict) + else: + rec_attributes["probegroup"] = None + + # sparsity + if "sparsity" in zarr_root.attrs: + sparsity = zarr_root.attrs["sparsity"] + else: + sparsity = None + + sortres = SortingResult( + sorting=sorting, + recording=recording, + rec_attributes=rec_attributes, + format="zarr", + sparsity=sparsity) + + return sortres def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingResult": @@ -531,9 +629,6 @@ def get_sorting_provenance(self): return sorting_provenance - # def is_read_only(self) -> bool: - # return self._is_read_only - def get_num_samples(self, segment_index: Optional[int] = None) -> int: # we use self.sorting to check segment_index segment_index = self.sorting._check_segment_index(segment_index) @@ -974,7 +1069,7 @@ def _save(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - from .zarrextractors import get_default_zarr_compressor + import pandas as pd import numcodecs diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index ddcef0c454..799ed58676 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -45,17 +45,17 @@ def test_SortingResult_binary_folder(): _check_sorting_results(sortres, sorting) -# def test_SortingResult_zarr(): -# recording, sorting = get_dataset() +def test_SortingResult_zarr(): + recording, sorting = get_dataset() -# folder = cache_folder / "test_SortingResult_zarr.zarr" -# if folder.exists(): -# shutil.rmtree(folder) + folder = cache_folder / "test_SortingResult_zarr.zarr" + if folder.exists(): + shutil.rmtree(folder) -# sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) -# sortres = load_sorting_result(folder, format="auto") + sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) + # sortres = load_sorting_result(folder, format="auto") -# _check_sorting_results(sortres, sorting) + # _check_sorting_results(sortres, sorting) @@ -172,7 +172,7 @@ def test_extension(): if __name__ == "__main__": - test_SortingResult_memory() - test_SortingResult_binary_folder() - # test_SortingResult_zarr() - test_extension() + # test_SortingResult_memory() + # test_SortingResult_binary_folder() + test_SortingResult_zarr() + # test_extension() From 9841deb2cf656d8db16b49eb2fc2f42be02f2c15 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 12 Jan 2024 17:52:31 +0100 Subject: [PATCH 005/192] zarr implementation for SortingResult --- src/spikeinterface/core/sortingresult.py | 162 ++++++++++++------ .../core/tests/test_sortingresult.py | 19 +- 2 files changed, 118 insertions(+), 63 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index e025c79a88..8fd3df059c 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -140,6 +140,7 @@ def start_sorting_result(sorting, recording, format="memory", folder=None, return sorting_result + def load_sorting_result(folder, load_extensions=True, format="auto"): """ Load a SortingResult object from disk. @@ -214,9 +215,11 @@ def create(cls, elif format == "binary_folder": cls.create_binary_folder(folder, sorting, recording, sparsity, rec_attributes=None) sortres = cls.load_from_binary_folder(folder, recording=recording) + sortres.folder = folder elif format == "zarr": cls.create_zarr(folder, sorting, recording, sparsity, rec_attributes=None) sortres = cls.load_from_zarr(folder, recording=recording) + sortres.folder = folder else: raise ValueError("SortingResult.create: wrong format") @@ -273,6 +276,10 @@ def create_memory(cls, sorting, recording, sparsity, rec_attributes): def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attributes): # used by create and save_as + # TODO add a spikeinterface_info.json folder with SI version and object type + + assert recording is not None, "To create a SortingResult you need recording not None" + folder = Path(folder) if folder.is_dir(): raise ValueError(f"Folder already exists {folder}") @@ -378,8 +385,9 @@ def _get_zarr_root(self, mode="r+"): def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # used by create and save_as import zarr + import numcodecs - + # TODO add an attribute with SI version and object type folder = Path(folder) # force zarr sufix @@ -389,21 +397,32 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): if folder.is_dir(): raise ValueError(f"Folder already exists {folder}") - - - zarr_root = zarr.open(folder, mode="w") - # the recording and sorting provenance can be used only if compatible with json + # the recording + rec_dict = recording.to_dict(relative_to=folder, recursive=True) + zarr_rec = np.array([rec_dict], dtype=object) if recording.check_serializability("json"): - rec_dict = recording.to_dict(relative_to=folder, recursive=True) - zarr_root.attrs["recording"] = check_json(rec_dict) + # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) + zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) + elif recording.check_serializability("pickle"): + # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) + zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) else: warnings.warn("SortingResult with zarr : the Recording is not json serializable, the recording link will be lost for futur load") + + # sorting provenance + sort_dict = sorting.to_dict(relative_to=folder, recursive=True) if sorting.check_serializability("json"): - sort_dict = sorting.to_dict(relative_to=folder, recursive=True) - zarr_root.attrs["sorting_provenance"] = check_json(sort_dict) + # zarr_root.attrs["sorting_provenance"] = check_json(sort_dict) + zarr_sort = np.array([sort_dict], dtype=object) + zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON()) + elif sorting.check_serializability("pickle"): + # zarr_root.create_dataset("sorting_provenance", data=sort_dict, object_codec=numcodecs.Pickle()) + zarr_sort = np.array([sort_dict], dtype=object) + zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) + # else: # warnings.warn("SortingResult with zarr : the sorting provenance is not json serializable, the sorting provenance link will be lost for futur load") @@ -414,15 +433,19 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): rec_attributes = get_rec_attributes(recording) probegroup = recording.get_probegroup() else: - rec_attributes_copy = rec_attributes.copy() - probegroup = rec_attributes_copy.pop("probegroup") + rec_attributes = rec_attributes.copy() + probegroup = rec_attributes.pop("probegroup") + recording_info.attrs["recording_attributes"] = check_json(rec_attributes) + # recording_info.create_dataset("recording_attributes", data=check_json(rec_attributes), object_codec=numcodecs.JSON()) if probegroup is not None: recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) + # recording_info.create_dataset("probegroup", data=check_json(probegroup.to_dict()), object_codec=numcodecs.JSON()) if sparsity is not None: zarr_root.attrs["sparsity"] = check_json(sparsity.to_dict()) + # zarr_root.create_dataset("sparsity", data=check_json(sparsity.to_dict()), object_codec=numcodecs.JSON()) # write sorting copy from .zarrextractors import add_sorting_to_zarr_group @@ -431,6 +454,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): zarr_kwargs = dict() add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **zarr_kwargs) + recording_info = zarr_root.create_group("extensions") + @classmethod def load_from_zarr(cls, folder, recording=None): @@ -440,15 +465,16 @@ def load_from_zarr(cls, folder, recording=None): zarr_root = zarr.open(folder, mode="r") - # load internal sorting copy and make it sharedmem - # TODO - # sorting = ZarrSortingExtractor... - sorting = SharedMemorySorting.from_sorting(NumpyFolderSorting(folder / "sorting")) + # load internal sorting and make it sharedmem + # TODO propagate storage_options + sorting = SharedMemorySorting.from_sorting(ZarrSortingExtractor(folder, zarr_group="sorting")) # load recording if possible if recording is None: + rec_dict = zarr_root["recording"][0] try: - recording = load_extractor(zarr_root.attrs["recording"], base_folder=folder) + + recording = load_extractor(rec_dict, base_folder=folder) except: recording = None else: @@ -457,16 +483,19 @@ def load_from_zarr(cls, folder, recording=None): pass # recording attributes - rec_attributes = zarr_root.require_group("recording_info").attrs["recording_attributes"] - if "probegroup" in zarr_root.require_group("recording_info").attrs: - probegroup_dict = zarr_root.require_group("recording_info").attrs["probegroup"] - rec_attributes["probegroup"] = probeinterface.Probe.from_dict(probegroup_dict) + rec_attributes = zarr_root["recording_info"].attrs["recording_attributes"] + # rec_attributes = zarr_root["recording_info"]["recording_attributes"] + if "probegroup" in zarr_root["recording_info"].attrs: + probegroup_dict = zarr_root["recording_info"].attrs["probegroup"] + # probegroup_dict = zarr_root["recording_info"]["probegroup"] + rec_attributes["probegroup"] = probeinterface.ProbeGroup.from_dict(probegroup_dict) else: rec_attributes["probegroup"] = None # sparsity if "sparsity" in zarr_root.attrs: - sparsity = zarr_root.attrs["sparsity"] + # sparsity = zarr_root.attrs["sparsity"] + sparsity = zarr_root["sparsity"] else: sparsity = None @@ -508,12 +537,14 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" SortingResult.create_binary_folder(folder, sorting_provenance, recording, self.sparsity, self.rec_attributes) - new_sortres = SortingResult.load_from_binary_folder(folder) + new_sortres = SortingResult.load_from_binary_folder(folder, recording=recording) new_sortres.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - raise NotImplementedError + SortingResult.create_zarr(folder, sorting_provenance, recording, self.sparsity, self.rec_attributes) + new_sortres = SortingResult.load_from_zarr(folder, recording=recording) + new_sortres.folder = folder else: raise ValueError("SortingResult.save: wrong format") @@ -624,8 +655,12 @@ def get_sorting_provenance(self): # sorting_provenance = None elif self.format == "zarr": - # TODO - raise NotImplementedError + zarr_root = self._get_zarr_root(mode="r") + if "sorting_provenance" in zarr_root.keys(): + sort_dict = zarr_root["sorting_provenance"][0] + sorting_provenance = load_extractor(sort_dict, base_folder=self.folder) + else: + sorting_provenance = None return sorting_provenance @@ -721,17 +756,27 @@ def get_saved_extension_names(self): assert self.format != "memory" global _possible_extensions + if self.format == "zarr": + zarr_root = self._get_zarr_root(mode="r") + if "extensions" in zarr_root.keys(): + extension_group = zarr_root["extensions"] + else: + extension_group = None + saved_extension_names = [] for extension_class in _possible_extensions: extension_name = extension_class.extension_name if self.format == "binary_folder": is_saved = (self.folder / extension_name).is_dir() and (self.folder / extension_name / "params.json").is_file() elif self.format == "zarr": - zarr_root = self._get_zarr_root(mode="r") - is_saved = extension_name in zarr_root.keys() and "params" in zarr_root[extension_name].attrs.keys() + if extension_group is not None: + is_saved = extension_name in extension_group.keys() and "params" in extension_group[extension_name].attrs.keys() + else: + is_saved = False if is_saved: saved_extension_names.append(extension_class.extension_name) - return saved_extension_names + + return saved_extension_names def get_extension(self, extension_name: str): """ @@ -947,12 +992,8 @@ def _get_binary_extension_folder(self): def _get_zarr_extension_group(self, mode='r+'): zarr_root = self.sorting_result._get_zarr_root(mode=mode) - assert self.extension_name in zarr_root.keys(), ( - f"SortingResult: extension {self.extension_name} " f"is not in folder {self.folder}" - ) - extension_group = zarr_root[self.extension_name] + extension_group = zarr_root["extensions"][self.extension_name] return extension_group - @classmethod def load(cls, sorting_result): @@ -1001,22 +1042,24 @@ def load_data(self): self.data[ext_data_name] = ext_data elif self.format == "zarr": - raise NotImplementedError - # TODO: decide if we make a copy or not - # extension_group = self._get_zarr_extension_group(mode='r') - # for ext_data_name in extension_group.keys(): - # ext_data_ = extension_group[ext_data_name] - # if "dict" in ext_data_.attrs: - # ext_data = ext_data_[0] - # elif "dataframe" in ext_data_.attrs: - # import xarray - # ext_data = xarray.open_zarr( - # ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" - # ).to_pandas() - # ext_data.index.rename("", inplace=True) - # else: - # ext_data = ext_data_ - # self.data[ext_data_name] = ext_data + # Alessio + # TODO: we need decide if we make a copy to memory or keep the lazy loading. For binary_folder it used to be lazy with memmap + # but this make the garbage complicated when a data is hold by a plot but the o SortingResult is delete + # lets talk + extension_group = self._get_zarr_extension_group(mode='r') + for ext_data_name in extension_group.keys(): + ext_data_ = extension_group[ext_data_name] + if "dict" in ext_data_.attrs: + ext_data = ext_data_[0] + elif "dataframe" in ext_data_.attrs: + import xarray + ext_data = xarray.open_zarr( + ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" + ).to_pandas() + ext_data.index.rename("", inplace=True) + else: + ext_data = ext_data_ + self.data[ext_data_name] = ext_data def copy(self, new_sorting_result, unit_ids=None): # alessio : please note that this also replace the old BaseWaveformExtractorExtension.select_units!!! @@ -1042,13 +1085,12 @@ def _save(self, **kwargs): return if self.sorting_result.is_read_only(): - raise ValueError("The SortingResult is read only save is not possible") + raise ValueError(f"The SortingResult is read only save extension {self.extension_name} is not possible") # delete already saved - self._reset_folder() + self._reset_extension_folder() self._save_params() - if self.format == "binary_folder": import pandas as pd @@ -1104,7 +1146,7 @@ def _save(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") - def _reset_folder(self): + def _reset_extension_folder(self): """ Delete the extension in folder (binary or zarr) and create an empty one. """ @@ -1116,16 +1158,15 @@ def _reset_folder(self): elif self.format == "zarr": import zarr - zarr_root = zarr.open(self.folder, mode="r+") - self.extension_group = zarr_root.create_group(self.extension_name, overwrite=True) + extension_group = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) def reset(self): """ Reset the waveform extension. Delete the sub folder and create a new empty one. """ - self._reset_folder() + self._reset_extension_folder() self.params = None self.data = dict() @@ -1135,6 +1176,10 @@ def set_params(self, **params): Set parameters for the extension and make it persistent in json. """ + # this ensure data is also deleted and corresponf to params + # this also ensure the group is created + self._reset_extension_folder() + params = self._set_params(**params) self.params = params @@ -1150,11 +1195,14 @@ def _save_params(self): params_to_save["sparsity"], ChannelSparsity ), "'sparsity' parameter must be a ChannelSparsity object!" params_to_save["sparsity"] = params_to_save["sparsity"].to_dict() + + if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() extension_folder.mkdir(exist_ok=True, parents=True) param_file = extension_folder / "params.json" param_file.write_text(json.dumps(check_json(params_to_save), indent=4), encoding="utf8") elif self.format == "zarr": - self.extension_group.attrs["params"] = check_json(params_to_save) + extension_group = self._get_zarr_extension_group(mode="r+") + extension_group.attrs["params"] = check_json(params_to_save) diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 799ed58676..c7d471b266 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -53,8 +53,7 @@ def test_SortingResult_zarr(): shutil.rmtree(folder) sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) - # sortres = load_sorting_result(folder, format="auto") - + sortres = load_sorting_result(folder, format="auto") # _check_sorting_results(sortres, sorting) @@ -85,11 +84,15 @@ def _check_sorting_results(sortres, original_sorting): ext = sortres.get_extension("dummy") assert ext is None + assert sortres.has_recording() # save to several format - for format in ("memory", "binary_folder", ): # "zarr" + for format in ("memory", "binary_folder", "zarr"): if format != "memory": - folder = cache_folder / f"test_SortingResult_save_as_{format}" + if format == "zarr": + folder = cache_folder / f"test_SortingResult_save_as_{format}.zarr" + else: + folder = cache_folder / f"test_SortingResult_save_as_{format}" if folder.exists(): shutil.rmtree(folder) else: @@ -99,6 +102,7 @@ def _check_sorting_results(sortres, original_sorting): sortres.compute("dummy") sortres2 = sortres.save_as(format=format, folder=folder) + print(sortres2.recording) ext = sortres2.get_extension("dummy") assert ext is not None @@ -107,9 +111,12 @@ def _check_sorting_results(sortres, original_sorting): assert data["result_two"].size == original_sorting.to_spike_vector().size # select unit_ids to several format - for format in ("memory", "binary_folder", ): # "zarr" + for format in ("memory", "binary_folder", "zarr"): if format != "memory": - folder = cache_folder / f"test_SortingResult_select_units_with{format}" + if format == "zarr": + folder = cache_folder / f"test_SortingResult_select_units_with_{format}.zarr" + else: + folder = cache_folder / f"test_SortingResult_select_units_with_{format}" if folder.exists(): shutil.rmtree(folder) else: From ff45ab735cd51f796e0a0c7b31db6c616dc27537 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 12 Jan 2024 21:15:08 +0100 Subject: [PATCH 006/192] Start ComputeWaveforms with ResultExtension --- src/spikeinterface/core/__init__.py | 7 +- src/spikeinterface/core/result_core.py | 141 ++++++++++++++++++ src/spikeinterface/core/sortingresult.py | 69 +++++++-- .../core/tests/test_result_core.py | 59 ++++++++ .../core/tests/test_sortingresult.py | 19 ++- 5 files changed, 278 insertions(+), 17 deletions(-) create mode 100644 src/spikeinterface/core/result_core.py create mode 100644 src/spikeinterface/core/tests/test_result_core.py diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 25f5d6ef53..8af29239df 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -133,4 +133,9 @@ # channel sparsity from .sparsity import ChannelSparsity, compute_sparsity -from .sortingresult import SortingResult, start_sorting_result, load_sorting_result +# SortingResult and +from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result +from .result_core import ( + ComputeWaveforms, compute_waveforms, + ComputTemplates, compute_templates +) \ No newline at end of file diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py new file mode 100644 index 0000000000..bfd5954fd4 --- /dev/null +++ b/src/spikeinterface/core/result_core.py @@ -0,0 +1,141 @@ +""" +Implement ResultExtension that are essential and imported in core + * ComputeWaveforms + * ComputTemplates + +Theses two classes replace the WaveformExtractor + +""" + +import numpy as np + +from .sortingresult import ResultExtension, register_result_extension +from .waveform_tools import extract_waveforms_to_single_buffer + +class ComputeWaveforms(ResultExtension): + extension_name = "waveforms" + depend_on = [] + need_recording = True + use_nodepiepline = False + + def _run(self, **kwargs): + self.data.clear() + + recording = self.sorting_result.recording + sorting = self.sorting_result.sorting + # TODO handle sampling + spikes = sorting.to_spike_vector() + unit_ids = sorting.unit_ids + + nbefore = int(self.params["ms_before"] * sorting.sampling_frequency / 1000.0) + nafter = int(self.params["ms_after"] * sorting.sampling_frequency / 1000.0) + + + # TODO find a solution maybe using node pipeline : here the waveforms is directly written to memamap + # the save will delete the memmap write it again!! this will not work on window. + if self.format == "binary_folder": + # in that case waveforms are extacted directly in files + file_path = self._get_binary_extension_folder() / "waveforms" + mode = "memmap" + copy = False + else: + file_path = None + mode = "shared_memory" + copy = True + + if self.sparsity is None: + sparsity_mask = None + else: + sparsity_mask = self.sparsity.mask + + # TODO propagate some job_kwargs + job_kwargs = dict(n_jobs=-1) + + all_waveforms = extract_waveforms_to_single_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode=mode, + return_scaled=self.params["return_scaled"], + file_path=file_path, + dtype=self.params["dtype"], + sparsity_mask=sparsity_mask, + copy=copy, + job_name="compute_waveforms", + **job_kwargs, + ) + + self.data["all_waveforms"] = all_waveforms + + def _set_params(self, + ms_before: float = 1.0, + ms_after: float = 2.0, + max_spikes_per_unit: int = 500, + return_scaled: bool = False, + dtype=None, + ): + recording = self.sorting_result.recording + if dtype is None: + dtype = recording.get_dtype() + + if return_scaled: + # check if has scaled values: + if not recording.has_scaled(): + print("Setting 'return_scaled' to False") + return_scaled = False + + if np.issubdtype(dtype, np.integer) and return_scaled: + dtype = "float32" + + dtype = np.dtype(dtype) + + if max_spikes_per_unit is not None: + max_spikes_per_unit = int(max_spikes_per_unit) + + params = dict( + ms_before=float(ms_before), + ms_after=float(ms_after), + max_spikes_per_unit=max_spikes_per_unit, + return_scaled=return_scaled, + dtype=dtype.str, + ) + return params + + def _select_extension_data(self, unit_ids): + # must be implemented in subclass + raise NotImplementedError + + # keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + # spikes = self.sorting_result.sorting.to_spike_vector() + # keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + + + + +compute_waveforms = ComputeWaveforms.function_factory() +register_result_extension(ComputeWaveforms) + +class ComputTemplates(ResultExtension): + extension_name = "templates" + depend_on = ["waveforms"] + need_recording = False + use_nodepiepline = False + + def _run(self, **kwargs): + # must be implemented in subclass + # must populate the self.data dictionary + raise NotImplementedError + + def _set_params(self, **params): + # must be implemented in subclass + # must return a cleaned version of params dict + raise NotImplementedError + + def _select_extension_data(self, unit_ids): + # must be implemented in subclass + raise NotImplementedError + +compute_templates = ComputTemplates.function_factory() +register_result_extension(ComputTemplates) \ No newline at end of file diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 8fd3df059c..1e19ca8ebe 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -713,6 +713,10 @@ def __repr__(self) -> str: txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}" if self.is_sparse(): txt += " - sparse" + if self.has_recording(): + txt += " - has recording" + ext_txt = f"Load extenstions [{len(self.extensions)}]: " + ", ".join(self.extensions.keys()) + txt += "\n" + ext_txt return txt ## extensions zone @@ -738,16 +742,27 @@ def compute(self, extension_name, **params): >>> unit_location = extension.get_data() """ - # TODO check extension dependency + extension_class = get_extension_class(extension_name) + + # check dependencies + if extension_class.need_recording: + assert self.has_recording(), f"Extension {extension_name} need the recording" + for dependency_name in extension_class.depend_on: + ext = self.get_extension(dependency_name) + assert ext is not None, f"Extension {extension_name} need {dependency_name} to be computed first" + extension_instance = extension_class(self) extension_instance.set_params(**params) extension_instance.run() self.extensions[extension_name] = extension_instance + # TODO : need discussion return extension_instance + # OR + return extension_instance.data def get_saved_extension_names(self): """ @@ -924,16 +939,29 @@ class ResultExtension: Possible extension can be register on the fly at import time with register_result_extension() mechanism. It also enables any custum computation on top on SortingResult to be implemented by the user. - An extension needs to inherit from this class and implement some abstract methods: - * _set_params - * _run - * _select_extension_data + An extension needs to inherit from this class and implement some attributes and abstract methods: + * extension_name + * depend_on + * need_recording + * use_nodepiepline + * _set_params() + * _run() + * _select_extension_data() The subclass must also set an `extension_name` class attribute which is not None by default. - The subclass must also hanle an attribute `__data` which is a dict contain the results after the `run()`. + The subclass must also hanle an attribute `data` which is a dict contain the results after the `run()`. + + All ResultExtension will have a function associate for instance (this use the function_factory): + comptute_unit_location(sorting_result, ...) will be equivalent to sorting_result.compute("unit_location", ...) + + """ + extension_name = None + depend_on = [] + need_recording = False + use_nodepiepline = False def __init__(self, sorting_result): self._sorting_result = weakref.ref(sorting_result) @@ -960,6 +988,20 @@ def _select_extension_data(self, unit_ids): # ####### + @classmethod + def function_factory(cls): + # make equivalent + # comptute_unit_location(sorting_result, ...) <> sorting_result.compute("unit_location", ...) + class FuncWrapper: + def __init__(self, extension_name): + self.extension_name = extension_name + def __call__(self, sorting_result, *args, **kwargs): + return sorting_result.compute(self.extension_name, *args, **kwargs) + func = FuncWrapper(cls.extension_name) + # TODO : make docstring from class docstring + # TODO: add load_if_exists + return func + @property def sorting_result(self): # Important : to avoid the SortingResult referencing a ResultExtension @@ -1057,6 +1099,8 @@ def load_data(self): ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" ).to_pandas() ext_data.index.rename("", inplace=True) + elif "object" in ext_data_.attrs: + ext_data = ext_data_[0] else: ext_data = ext_data_ self.data[ext_data_name] = ext_data @@ -1095,7 +1139,6 @@ def _save(self, **kwargs): import pandas as pd extension_folder = self._get_binary_extension_folder() - for ext_data_name, ext_data in self.data.items(): if isinstance(ext_data, dict): with (extension_folder / f"{ext_data_name}.json").open("w") as f: @@ -1126,9 +1169,8 @@ def _save(self, **kwargs): del extension_group[ext_data_name] if isinstance(ext_data, dict): extension_group.create_dataset( - name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() + name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() ) - extension_group[ext_data_name].attrs["dict"] = True elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif isinstance(ext_data, pd.DataFrame): @@ -1139,12 +1181,14 @@ def _save(self, **kwargs): ) extension_group[ext_data_name].attrs["dataframe"] = True else: + # any object try: extension_group.create_dataset( - name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() + name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle() ) except: raise Exception(f"Could not save {ext_data_name} as extension data") + extension_group[ext_data_name].attrs["object"] = True def _reset_extension_folder(self): """ @@ -1206,3 +1250,8 @@ def _save_params(self): extension_group = self._get_zarr_extension_group(mode="r+") extension_group.attrs["params"] = check_json(params_to_save) + + + + + diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py new file mode 100644 index 0000000000..b09c40e545 --- /dev/null +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -0,0 +1,59 @@ +import pytest +from pathlib import Path + +import shutil + +from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.core import start_sorting_result + +import numpy as np + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "core" +else: + cache_folder = Path("cache_folder") / "core" + + +def get_dataset(): + recording, sorting = generate_ground_truth_recording( + durations=[30.0], sampling_frequency=16000.0, num_channels=10, num_units=5, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + return recording, sorting + + + + +def test_ComputeWaveforms(format="memory"): + + if format == "memory": + folder = None + elif format == "binary_folder": + folder = cache_folder / f"test_ComputeWaveforms_{format}" + elif format == "zarr": + folder = cache_folder / f"test_ComputeWaveforms.zarr" + if folder and folder.exists(): + shutil.rmtree(folder) + + + + recording, sorting = get_dataset() + sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=None) + print(sortres) + + ext = sortres.compute("waveforms") + wfs = ext.data["all_waveforms"] + + print(wfs.shape) + + +def test_ComputTemplates(): + pass + +if __name__ == '__main__': + test_ComputeWaveforms(format="memory") + # test_ComputeWaveforms(format="binary_folder") + # test_ComputeWaveforms(format="zarr") + # test_ComputTemplates() \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index c7d471b266..243cd2758f 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -54,7 +54,7 @@ def test_SortingResult_zarr(): sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) sortres = load_sorting_result(folder, format="auto") - # _check_sorting_results(sortres, sorting) + _check_sorting_results(sortres, sorting) @@ -71,12 +71,15 @@ def _check_sorting_results(sortres, original_sorting): probe = sortres.get_probe() sparsity = sortres.sparsity - + # compute sortres.compute("dummy", param1=5.5) + # equivalent + compute_dummy(sortres, param1=5.5) ext = sortres.get_extension("dummy") assert ext is not None assert ext.params["param1"] == 5.5 + print(sortres) # recompute sortres.compute("dummy", param1=5.5) # and delete @@ -85,7 +88,7 @@ def _check_sorting_results(sortres, original_sorting): assert ext is None assert sortres.has_recording() - + # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -102,7 +105,6 @@ def _check_sorting_results(sortres, original_sorting): sortres.compute("dummy") sortres2 = sortres.save_as(format=format, folder=folder) - print(sortres2.recording) ext = sortres2.get_extension("dummy") assert ext is not None @@ -131,11 +133,14 @@ def _check_sorting_results(sortres, original_sorting): data = sortres2.get_extension("dummy").data assert data["result_one"] == sortres.get_extension("dummy").data["result_one"] # unit 1, 3, ... should be removed - assert np.all(~np.isin(data["result_two"], [-1, -3])) + assert np.all(~np.isin(data["result_two"], [1, 3])) class DummyResultExtension(ResultExtension): extension_name = "dummy" + depend_on = [] + need_recording = False + use_nodepiepline = False def _set_params(self, param0="yep", param1=1.2, param2=[1,2, 3.]): params = dict(param0=param0, param1=param1, param2=param2) @@ -148,7 +153,7 @@ def _run(self, **kwargs): # the result two has the same size of the spike vector!! # and represent nothing (the trick is to use unit_index for testing slice) spikes = self.sorting_result.sorting.to_spike_vector() - self.data["result_two"] = spikes["unit_index"] * -1 + self.data["result_two"] = spikes["unit_index"].copy() def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) @@ -163,6 +168,8 @@ def _select_extension_data(self, unit_ids): return new_data +compute_dummy = DummyResultExtension.function_factory() + class DummyResultExtension2(ResultExtension): extension_name = "dummy" From 4cb07f7dcd1454a01bdd3e829652c0714ba6cd72 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sun, 14 Jan 2024 20:27:04 +0100 Subject: [PATCH 007/192] Handle random spikes. --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/result_core.py | 9 +- src/spikeinterface/core/sortingresult.py | 100 +++++++++++++++++- .../core/tests/test_result_core.py | 3 +- .../core/tests/test_sortingresult.py | 40 ++++++- 5 files changed, 142 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 8af29239df..ed28a62a82 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -134,7 +134,7 @@ from .sparsity import ChannelSparsity, compute_sparsity # SortingResult and -from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result +from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result, random_spikes_selection from .result_core import ( ComputeWaveforms, compute_waveforms, ComputTemplates, compute_templates diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index bfd5954fd4..46e041956f 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -21,6 +21,11 @@ class ComputeWaveforms(ResultExtension): def _run(self, **kwargs): self.data.clear() + if self.sorting_result.random_spikes_indices is None: + raise ValueError("compute_waveforms need SortingResult.select_random_spikes() need to be run first") + + + recording = self.sorting_result.recording sorting = self.sorting_result.sorting # TODO handle sampling @@ -51,9 +56,11 @@ def _run(self, **kwargs): # TODO propagate some job_kwargs job_kwargs = dict(n_jobs=-1) + some_spikes = spikes[self.sorting_result.random_spikes_indices] + all_waveforms = extract_waveforms_to_single_buffer( recording, - spikes, + some_spikes, unit_ids, nbefore, nafter, diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 1e19ca8ebe..f9d844bb00 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -182,7 +182,7 @@ class SortingResult: the SortingResult object can be reload even if references to the original sorting and/or to the original recording are lost. """ - def __init__(self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None): + def __init__(self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None, random_spikes_indices=None): # very fast init because checks are done in load and create self.sorting = sorting # self.recorsding will be a property @@ -190,6 +190,7 @@ def __init__(self, sorting=None, recording=None, rec_attributes=None, format=Non self.rec_attributes = rec_attributes self.format = format self.sparsity = sparsity + self.random_spikes_indices = random_spikes_indices # extensions are not loaded at init self.extensions = dict() @@ -367,12 +368,19 @@ def load_from_binary_folder(cls, folder, recording=None): else: sparsity = None + selected_spike_file = folder / "random_spikes_indices.npy" + if sparsity_file.is_file(): + random_spikes_indices = np.load(selected_spike_file) + else: + random_spikes_indices = None + sortres = SortingResult( sorting=sorting, recording=recording, rec_attributes=rec_attributes, format="binary_folder", - sparsity=sparsity) + sparsity=sparsity, + random_spikes_indices=random_spikes_indices) return sortres @@ -499,12 +507,18 @@ def load_from_zarr(cls, folder, recording=None): else: sparsity = None + if "random_spikes_indices" in zarr_root.keys(): + random_spikes_indices = zarr_root["random_spikes_indices"] + else: + random_spikes_indices = None + sortres = SortingResult( sorting=sorting, recording=recording, rec_attributes=rec_attributes, format="zarr", - sparsity=sparsity) + sparsity=sparsity, + random_spikes_indices=random_spikes_indices) return sortres @@ -875,6 +889,19 @@ def has_extension(self, extension_name: str) -> bool: else: return False + ## random_spikes_selection zone + def select_random_spikes(self, **random_kwargs): + + assert self.random_spikes_indices is None, "select random spikes is already computed" + + self.random_spikes_indices = random_spikes_selection(self.sorting, self.rec_attributes["num_samples"], **random_kwargs) + + if self.format == "binary_folder": + np.save(self.folder / "random_spikes_indices.npy", self.random_spikes_indices) + elif self.format == "zarr": + zarr_root = self._get_zarr_root() + zarr_root.create_dataset("random_spikes_indices", data=self.random_spikes_indices) + global _possible_extensions _possible_extensions = [] @@ -1252,6 +1279,73 @@ def _save_params(self): +# TODO implement other method like "maximum_rate", "by_percent", ... +def random_spikes_selection(sorting, num_samples, + method="uniform", max_spikes_per_unit=500, + margin_size=None, + seed=None): + """ + This replace `select_random_spikes_uniformly()`. + + Random spikes selection of spike across per units. + + Can optionaly avoid spikes on segment borders. + If nbefore and nafter + + Parameters + ---------- + recording + + sorting + + max_spikes_per_unit + + method: "uniform" + + margin_size + + seed=None + + Returns + ------- + random_spikes_indices + Selected spike indicies corespond to the sorting spike vector. + """ + + rng =np.random.default_rng(seed=seed) + spikes = sorting.to_spike_vector() + + random_spikes_indices = [] + for unit_index, unit_id in enumerate(sorting.unit_ids): + all_unit_indices = np.flatnonzero(unit_index == spikes["unit_index"]) + + if method == "uniform": + selected_unit_indices = rng.choice(all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), + replace=False, shuffle=False) + else: + raise ValueError(f"random_spikes_selection wring method {method}") + + if margin_size is not None: + margin_size = int(margin_size) + keep = np.ones(selected_unit_indices.size, dtype=bool) + # left margin + keep[selected_unit_indices < margin_size] = False + # right margin + for segment_index in range(sorting.get_num_segments()): + remove_mask = np.flatnonzero((spikes[selected_unit_indices]["segment_index"] == segment_index) + &(spikes[selected_unit_indices]["sample_index"] >= (num_samples[segment_index] - margin_size)) + ) + keep[remove_mask] = False + selected_unit_indices = selected_unit_indices[keep] + + random_spikes_indices.append(selected_unit_indices) + + random_spikes_indices = np.concatenate(random_spikes_indices) + random_spikes_indices = np.sort(random_spikes_indices) + + return random_spikes_indices + + diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index b09c40e545..cb7b020772 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -37,12 +37,11 @@ def test_ComputeWaveforms(format="memory"): if folder and folder.exists(): shutil.rmtree(folder) - - recording, sorting = get_dataset() sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=None) print(sortres) + sortres.select_random_spikes(max_spikes_per_unit=50, seed=2205) ext = sortres.compute("waveforms") wfs = ext.data["all_waveforms"] diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 243cd2758f..55163e0cec 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -5,7 +5,7 @@ from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import SortingResult, start_sorting_result, load_sorting_result -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension, random_spikes_selection import numpy as np @@ -41,7 +41,6 @@ def test_SortingResult_binary_folder(): sortres = start_sorting_result(sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None) sortres = load_sorting_result(folder, format="auto") - _check_sorting_results(sortres, sorting) @@ -87,8 +86,14 @@ def _check_sorting_results(sortres, original_sorting): ext = sortres.get_extension("dummy") assert ext is None - assert sortres.has_recording() + assert sortres.has_recording() + + if sortres.random_spikes_indices is None: + sortres.select_random_spikes(max_spikes_per_unit=10, seed=2205) + assert sortres.random_spikes_indices is not None + assert sortres.random_spikes_indices.size == 10 * sortres.sorting.unit_ids.size + # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -184,9 +189,34 @@ def test_extension(): with pytest.raises(AssertionError): register_result_extension(DummyResultExtension2) +def test_random_spikes_selection(): + recording, sorting = get_dataset() + + max_spikes_per_unit = 12 + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + + random_spikes_indices = random_spikes_selection(sorting, num_samples, + method="uniform", max_spikes_per_unit=max_spikes_per_unit, + margin_size=None, + seed=2205) + spikes = sorting.to_spike_vector() + some_spikes = spikes[random_spikes_indices] + for unit_index, unit_id in enumerate(sorting.unit_ids): + spike_slected_unit = some_spikes[some_spikes["unit_index"] == unit_index] + assert spike_slected_unit.size == max_spikes_per_unit + + # with margin + random_spikes_indices = random_spikes_selection(sorting, num_samples, + method="uniform", max_spikes_per_unit=max_spikes_per_unit, + margin_size=25, + seed=2205) + # in that case the number is not garanty so it can be a bit less + assert random_spikes_indices.size >= (0.9 * sorting.unit_ids.size * max_spikes_per_unit) + if __name__ == "__main__": - # test_SortingResult_memory() - # test_SortingResult_binary_folder() + test_SortingResult_memory() + test_SortingResult_binary_folder() test_SortingResult_zarr() # test_extension() + # test_random_spikes_selection() From 2b49c830c1c7c52c6c507ef6ea0e0558d87e05de Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 15 Jan 2024 08:56:02 +0100 Subject: [PATCH 008/192] Avoid saveing twice memmap --- src/spikeinterface/core/result_core.py | 4 +- src/spikeinterface/core/sortingresult.py | 42 ++++++++++++------- .../core/tests/test_result_core.py | 7 ++-- .../core/tests/test_sortingresult.py | 4 +- 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 46e041956f..f10eed5f3b 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -40,7 +40,7 @@ def _run(self, **kwargs): # the save will delete the memmap write it again!! this will not work on window. if self.format == "binary_folder": # in that case waveforms are extacted directly in files - file_path = self._get_binary_extension_folder() / "waveforms" + file_path = self._get_binary_extension_folder() / "waveforms.npy" mode = "memmap" copy = False else: @@ -74,7 +74,7 @@ def _run(self, **kwargs): **job_kwargs, ) - self.data["all_waveforms"] = all_waveforms + self.data["waveforms"] = all_waveforms def _set_params(self, ms_before: float = 1.0, diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index f9d844bb00..f7e0cd9b32 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -1140,28 +1140,30 @@ def copy(self, new_sorting_result, unit_ids=None): new_extension.data = self.data else: new_extension.data = self._select_extension_data(unit_ids) - new_extension._save() + new_extension.save() return new_extension def run(self, **kwargs): + if not self.sorting_result.is_read_only(): + # this also reset the folder or zarr group + self._save_params() + self._run(**kwargs) + if not self.sorting_result.is_read_only(): - self._save(**kwargs) + self._save_data(**kwargs) def save(self, **kwargs): - self._save(**kwargs) + self._save_params() + self._save_data(**kwargs) - def _save(self, **kwargs): + def _save_data(self, **kwargs): if self.format == "memory": return if self.sorting_result.is_read_only(): raise ValueError(f"The SortingResult is read only save extension {self.extension_name} is not possible") - # delete already saved - self._reset_extension_folder() - self._save_params() - if self.format == "binary_folder": import pandas as pd @@ -1171,7 +1173,14 @@ def _save(self, **kwargs): with (extension_folder / f"{ext_data_name}.json").open("w") as f: json.dump(ext_data, f) elif isinstance(ext_data, np.ndarray): - np.save(extension_folder / f"{ext_data_name}.npy", ext_data) + # important some SortingResult like ComputeWaveforms already run the computation with memmap + # so no need to save theses array + data_file = extension_folder / f"{ext_data_name}.npy" + if isinstance(ext_data, np.memmap) and data_file.exists(): + pass + print("no save") + else: + np.save(data_file, ext_data) elif isinstance(ext_data, pd.DataFrame): ext_data.to_csv(extension_folder / f"{ext_data_name}.csv", index=True) else: @@ -1257,15 +1266,20 @@ def set_params(self, **params): if self.sorting_result.is_read_only(): return + self._save_params() def _save_params(self): params_to_save = self.params.copy() - if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: - assert isinstance( - params_to_save["sparsity"], ChannelSparsity - ), "'sparsity' parameter must be a ChannelSparsity object!" - params_to_save["sparsity"] = params_to_save["sparsity"].to_dict() + + self._reset_extension_folder() + + # TODO make sparsity local Result specific + # if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: + # assert isinstance( + # params_to_save["sparsity"], ChannelSparsity + # ), "'sparsity' parameter must be a ChannelSparsity object!" + # params_to_save["sparsity"] = params_to_save["sparsity"].to_dict() if self.format == "binary_folder": diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index cb7b020772..d84174dd3a 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -43,16 +43,17 @@ def test_ComputeWaveforms(format="memory"): sortres.select_random_spikes(max_spikes_per_unit=50, seed=2205) ext = sortres.compute("waveforms") - wfs = ext.data["all_waveforms"] + wfs = ext.data["waveforms"] print(wfs.shape) + print(sortres) def test_ComputTemplates(): pass if __name__ == '__main__': - test_ComputeWaveforms(format="memory") - # test_ComputeWaveforms(format="binary_folder") + # test_ComputeWaveforms(format="memory") + test_ComputeWaveforms(format="binary_folder") # test_ComputeWaveforms(format="zarr") # test_ComputTemplates() \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 55163e0cec..18f910f3c2 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -215,8 +215,8 @@ def test_random_spikes_selection(): if __name__ == "__main__": - test_SortingResult_memory() + # test_SortingResult_memory() test_SortingResult_binary_folder() - test_SortingResult_zarr() + # test_SortingResult_zarr() # test_extension() # test_random_spikes_selection() From 789ec911e8bd9f74a0e733581a87b2b62865ecf4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 15 Jan 2024 17:11:21 +0100 Subject: [PATCH 009/192] implementation of ComputeTemplates --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/result_core.py | 118 +++++++++++++----- src/spikeinterface/core/sortingresult.py | 9 +- .../core/tests/test_result_core.py | 89 ++++++++++--- 4 files changed, 165 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index ed28a62a82..1a06563df3 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -137,5 +137,5 @@ from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result, random_spikes_selection from .result_core import ( ComputeWaveforms, compute_waveforms, - ComputTemplates, compute_templates + ComputeTemplates, compute_templates ) \ No newline at end of file diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index f10eed5f3b..198eb9717f 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -1,7 +1,7 @@ """ Implement ResultExtension that are essential and imported in core * ComputeWaveforms - * ComputTemplates + * ComputeTemplates Theses two classes replace the WaveformExtractor @@ -24,20 +24,17 @@ def _run(self, **kwargs): if self.sorting_result.random_spikes_indices is None: raise ValueError("compute_waveforms need SortingResult.select_random_spikes() need to be run first") - - recording = self.sorting_result.recording sorting = self.sorting_result.sorting - # TODO handle sampling - spikes = sorting.to_spike_vector() unit_ids = sorting.unit_ids + # retrieve spike vector and the sampling + spikes = sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + nbefore = int(self.params["ms_before"] * sorting.sampling_frequency / 1000.0) nafter = int(self.params["ms_after"] * sorting.sampling_frequency / 1000.0) - - # TODO find a solution maybe using node pipeline : here the waveforms is directly written to memamap - # the save will delete the memmap write it again!! this will not work on window. if self.format == "binary_folder": # in that case waveforms are extacted directly in files file_path = self._get_binary_extension_folder() / "waveforms.npy" @@ -56,8 +53,6 @@ def _run(self, **kwargs): # TODO propagate some job_kwargs job_kwargs = dict(n_jobs=-1) - some_spikes = spikes[self.sorting_result.random_spikes_indices] - all_waveforms = extract_waveforms_to_single_buffer( recording, some_spikes, @@ -111,12 +106,15 @@ def _set_params(self, return params def _select_extension_data(self, unit_ids): - # must be implemented in subclass - raise NotImplementedError + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + spikes = self.sorting_result.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices) - # keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) - # spikes = self.sorting_result.sorting.to_spike_vector() - # keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + new_data = dict() + new_data["waveforms"] = self.data["waveforms"][keep_spike_mask, :, :] + + return new_data @@ -124,25 +122,89 @@ def _select_extension_data(self, unit_ids): compute_waveforms = ComputeWaveforms.function_factory() register_result_extension(ComputeWaveforms) -class ComputTemplates(ResultExtension): +class ComputeTemplates(ResultExtension): extension_name = "templates" depend_on = ["waveforms"] need_recording = False use_nodepiepline = False def _run(self, **kwargs): - # must be implemented in subclass - # must populate the self.data dictionary - raise NotImplementedError - - def _set_params(self, **params): - # must be implemented in subclass - # must return a cleaned version of params dict - raise NotImplementedError + + + if self.sparsity is not None: + # TODO handle sparsity + raise NotImplementedError + + unit_ids = self.sorting_result.sorting.unit_ids + waveforms_extension = self.sorting_result.get_extension("waveforms") + waveforms = waveforms_extension.data["waveforms"] + + num_samples = waveforms.shape[1] + # channel can be sparse or not + num_channels = waveforms.shape[2] + + + for operator in self.params["operators"]: + if isinstance(operator, str) and operator in ("average", "std", "median"): + key = operator + elif isinstance(operator, (list, tuple)): + operator, percentile = operator + assert operator == "percentile" + key = f"pencentile_{percentile}" + else: + raise ValueError(f"ComputeTemplates: wrong operator {operator}") + self.data[key] = np.zeros((unit_ids.size, num_samples, num_channels)) + + spikes = self.sorting_result.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + for unit_index, unit_id in enumerate(unit_ids): + spike_mask = some_spikes["unit_index"] == unit_index + wfs = waveforms[spike_mask, :, :] + if wfs.shape[0] == 0: + continue + + for operator in self.params["operators"]: + if operator == "average": + arr = np.average(wfs, axis=0) + key = operator + elif operator == "std": + arr = np.std(wfs, axis=0) + key = operator + elif operator == "median": + arr = np.median(wfs, axis=0) + key = operator + elif isinstance(operator, (list, tuple)): + operator, percentile = operator + arr = np.percentile(wfs, percentile, axis=0) + key = f"pencentile_{percentile}" + + self.data[key][unit_index, :, :] = arr + + + def _set_params(self, operators = ["average", "std"]): + assert isinstance(operators, list) + for operator in operators: + if isinstance(operator, str): + assert operator in ("average", "std", "median", "mad") + else: + assert isinstance(operator, (list, tuple)) + assert len(operator) == 2 + assert operator[0] == "percentile" + + params = dict(operators=operators) + return params def _select_extension_data(self, unit_ids): - # must be implemented in subclass - raise NotImplementedError + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + + new_data = dict() + for key, arr in self.data.items(): + new_data[key] = arr[keep_unit_indices, :, :] + + return new_data + + + -compute_templates = ComputTemplates.function_factory() -register_result_extension(ComputTemplates) \ No newline at end of file +compute_templates = ComputeTemplates.function_factory() +register_result_extension(ComputeTemplates) \ No newline at end of file diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index f7e0cd9b32..c97dc485f8 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -569,7 +569,7 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> return new_sortres - def save_as(self, format="binary_folder", folder=None) -> "SortingResult": + def save_as(self, format="memory", folder=None) -> "SortingResult": """ Save SortingResult object into another format. Uselfull for memory to zarr or memory to binray. @@ -588,7 +588,7 @@ def save_as(self, format="binary_folder", folder=None) -> "SortingResult": return self._save_or_select(format=format, folder=folder, unit_ids=None) - def select_units(self, unit_ids, folder=None, format="binary_folder") -> "SortingResult": + def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult": """ This method is equivalent to `save_as()`but with a subset of units. Filters units by creating a new waveform extractor object in a new folder. @@ -1173,12 +1173,11 @@ def _save_data(self, **kwargs): with (extension_folder / f"{ext_data_name}.json").open("w") as f: json.dump(ext_data, f) elif isinstance(ext_data, np.ndarray): - # important some SortingResult like ComputeWaveforms already run the computation with memmap - # so no need to save theses array data_file = extension_folder / f"{ext_data_name}.npy" if isinstance(ext_data, np.memmap) and data_file.exists(): + # important some SortingResult like ComputeWaveforms already run the computation with memmap + # so no need to save theses array pass - print("no save") else: np.save(data_file, ext_data) elif isinstance(ext_data, pd.DataFrame): diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index d84174dd3a..8bb02b9248 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -14,20 +14,13 @@ cache_folder = Path("cache_folder") / "core" -def get_dataset(): +def get_sorting_result(format="memory"): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, num_channels=10, num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), seed=2205, ) - return recording, sorting - - - - -def test_ComputeWaveforms(format="memory"): - if format == "memory": folder = None elif format == "binary_folder": @@ -36,24 +29,82 @@ def test_ComputeWaveforms(format="memory"): folder = cache_folder / f"test_ComputeWaveforms.zarr" if folder and folder.exists(): shutil.rmtree(folder) - - recording, sorting = get_dataset() - sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=None) - print(sortres) + + sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=None) + + return sortres + + +def _check_result_extension(sortres, extension_name): + + + # select unit_ids to several format + # for format in ("memory", "binary_folder", "zarr"): + for format in ("memory", ): + if format != "memory": + if format == "zarr": + folder = cache_folder / f"test_SortingResult_{extension_name}_select_units_with_{format}.zarr" + else: + folder = cache_folder / f"test_SortingResult_{extension_name}_select_units_with_{format}" + if folder.exists(): + shutil.rmtree(folder) + else: + folder = None + + # check unit slice + keep_unit_ids = sortres.sorting.unit_ids[::2] + sortres2 = sortres.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) + + data = sortres2.get_extension(extension_name).data + # for k, arr in data.items(): + # print(k, arr.shape) + + +def test_ComputeWaveforms(format="memory"): + sortres = get_sorting_result(format=format) sortres.select_random_spikes(max_spikes_per_unit=50, seed=2205) ext = sortres.compute("waveforms") wfs = ext.data["waveforms"] + _check_result_extension(sortres, "waveforms") + + +def test_ComputeTemplates(format="memory"): + sortres = get_sorting_result(format=format) + + sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) + + with pytest.raises(AssertionError): + # This require "waveforms first and should trig an error + sortres.compute("templates") + + sortres.compute("waveforms") + waveforms = sortres.get_extension("waveforms").data["waveforms"] + sortres.compute("templates", operators=["average", "std", "median", ("percentile", 5.), ("percentile", 95.),]) + + + data = sortres.get_extension("templates").data + for k in ['average', 'std', 'median', 'pencentile_5.0', 'pencentile_95.0']: + assert k in data.keys() + assert data[k].shape[0] == sortres.sorting.unit_ids.size + assert data[k].shape[1] == waveforms.shape[1] - print(wfs.shape) - print(sortres) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # unit_index = 2 + # for k in data.keys(): + # wf0 = data[k][unit_index, :, :] + # ax.plot(wf0.T.flatten(), label=k) + # ax.legend() + # plt.show() -def test_ComputTemplates(): - pass + _check_result_extension(sortres, "templates") if __name__ == '__main__': # test_ComputeWaveforms(format="memory") - test_ComputeWaveforms(format="binary_folder") - # test_ComputeWaveforms(format="zarr") - # test_ComputTemplates() \ No newline at end of file + # test_ComputeWaveforms(format="binary_folder") + # test_ComputeWaveforms(format="zarr") + test_ComputeTemplates(format="memory") + # test_ComputeTemplates(format="binary_folder") + # test_ComputeTemplates(format="zarr") From 70f11403fefe04106cf1bc679ab1f5fe3d806517 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 15 Jan 2024 17:26:24 +0100 Subject: [PATCH 010/192] oups --- src/spikeinterface/core/sortingresult.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index c97dc485f8..f8135f3fc2 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -27,7 +27,6 @@ # TODO # * make info.json that contain some version info of spikeinterface # * same for zarr -# * sample spikes and propagate in compute with option @@ -533,6 +532,11 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> else: recording = None + if self.sparsity is not None: + # handle sparsity propagation and slicing + raise NotImplementedError + sparsity = None + # Note that the sorting is a copy we need to go back to the orginal sorting (if available) sorting_provenance = self.get_sorting_provenance() if sorting_provenance is None: @@ -545,18 +549,18 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> if format == "memory": # This make a copy of actual SortingResult - new_sortres = SortingResult.create_memory(sorting_provenance, recording, self.sparsity, self.rec_attributes) + new_sortres = SortingResult.create_memory(sorting_provenance, recording, sparsity, self.rec_attributes) elif format == "binary_folder": # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" - SortingResult.create_binary_folder(folder, sorting_provenance, recording, self.sparsity, self.rec_attributes) + SortingResult.create_binary_folder(folder, sorting_provenance, recording, sparsity, self.rec_attributes) new_sortres = SortingResult.load_from_binary_folder(folder, recording=recording) new_sortres.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - SortingResult.create_zarr(folder, sorting_provenance, recording, self.sparsity, self.rec_attributes) + SortingResult.create_zarr(folder, sorting_provenance, recording, sparsity, self.rec_attributes) new_sortres = SortingResult.load_from_zarr(folder, recording=recording) new_sortres.folder = folder else: @@ -614,7 +618,7 @@ def copy(self): """ Create a a copy of SortingResult with format "memory". """ - return self._save_or_select(format="binary_folder", folder=None, unit_ids=None) + return self._save_or_select(format="memory", folder=None, unit_ids=None) def is_read_only(self) -> bool: if self.format == "memory": From 1d13b818ca3350e11ce55f8fe50fb78287c0a39c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 16 Jan 2024 08:40:11 +0100 Subject: [PATCH 011/192] SortingResult save mask instead of dict --- src/spikeinterface/core/sortingresult.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index f8135f3fc2..c1d32e79bd 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -318,8 +318,9 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attribut probeinterface.write_probeinterface(probegroup_file, probegroup) if sparsity is not None: - with open(folder / "sparsity.json", mode="w") as f: - json.dump(check_json(sparsity.to_dict()), f) + np.save(folder / "sparsity_mask.npy", sparsity.mask) + # with open(folder / "sparsity.json", mode="w") as f: + # json.dump(check_json(sparsity.to_dict()), f) @classmethod def load_from_binary_folder(cls, folder, recording=None): @@ -360,10 +361,13 @@ def load_from_binary_folder(cls, folder, recording=None): rec_attributes["probegroup"] = None # sparsity - sparsity_file = folder / "sparsity.json" + # sparsity_file = folder / "sparsity.json" + sparsity_file = folder / "sparsity_mask.npy" if sparsity_file.is_file(): - with open(sparsity_file, mode="r") as f: - sparsity = ChannelSparsity.from_dict(json.load(f)) + sparsity_mask = np.load(sparsity_file) + # with open(sparsity_file, mode="r") as f: + # sparsity = ChannelSparsity.from_dict(json.load(f)) + sparsity = ChannelSparsity(zarr_root["sparsity_mask"], rec_attributes["unit_ids"], rec_attributes["channel_ids"]) else: sparsity = None @@ -451,8 +455,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # recording_info.create_dataset("probegroup", data=check_json(probegroup.to_dict()), object_codec=numcodecs.JSON()) if sparsity is not None: - zarr_root.attrs["sparsity"] = check_json(sparsity.to_dict()) + # zarr_root.attrs["sparsity"] = check_json(sparsity.to_dict()) # zarr_root.create_dataset("sparsity", data=check_json(sparsity.to_dict()), object_codec=numcodecs.JSON()) + zarr_root.create_dataset("sparsity_mask", data=sparsity.mask) # write sorting copy from .zarrextractors import add_sorting_to_zarr_group @@ -500,9 +505,9 @@ def load_from_zarr(cls, folder, recording=None): rec_attributes["probegroup"] = None # sparsity - if "sparsity" in zarr_root.attrs: + if "sparsity_mask" in zarr_root.attrs: # sparsity = zarr_root.attrs["sparsity"] - sparsity = zarr_root["sparsity"] + sparsity = ChannelSparsity(zarr_root["sparsity_mask"], rec_attributes["unit_ids"], rec_attributes["channel_ids"]) else: sparsity = None From 6bec81fb4cbba98e5df3809df32c77c9eab9bcfa Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jan 2024 16:37:12 +0100 Subject: [PATCH 012/192] Handle sparsity in SortingResult and ResultExtension. --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/result_core.py | 20 +-- src/spikeinterface/core/sortingresult.py | 140 +++--------------- .../core/tests/test_result_core.py | 70 +++++---- .../core/tests/test_sorting_tools.py | 2 +- .../core/tests/test_sortingresult.py | 36 +---- 6 files changed, 82 insertions(+), 188 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index e1ad6ccf00..4ffdc505bb 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -142,7 +142,7 @@ from .template import Templates # SortingResult and ResultExtension -from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result, random_spikes_selection +from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result from .result_core import ( ComputeWaveforms, compute_waveforms, ComputeTemplates, compute_templates diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 198eb9717f..b242d4bdc9 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -130,19 +130,12 @@ class ComputeTemplates(ResultExtension): def _run(self, **kwargs): - - if self.sparsity is not None: - # TODO handle sparsity - raise NotImplementedError - - unit_ids = self.sorting_result.sorting.unit_ids + unit_ids = self.sorting_result.unit_ids + channel_ids = self.sorting_result.channel_ids waveforms_extension = self.sorting_result.get_extension("waveforms") waveforms = waveforms_extension.data["waveforms"] num_samples = waveforms.shape[1] - # channel can be sparse or not - num_channels = waveforms.shape[2] - for operator in self.params["operators"]: if isinstance(operator, str) and operator in ("average", "std", "median"): @@ -153,7 +146,7 @@ def _run(self, **kwargs): key = f"pencentile_{percentile}" else: raise ValueError(f"ComputeTemplates: wrong operator {operator}") - self.data[key] = np.zeros((unit_ids.size, num_samples, num_channels)) + self.data[key] = np.zeros((unit_ids.size, num_samples, channel_ids.size)) spikes = self.sorting_result.sorting.to_spike_vector() some_spikes = spikes[self.sorting_result.random_spikes_indices] @@ -177,9 +170,12 @@ def _run(self, **kwargs): operator, percentile = operator arr = np.percentile(wfs, percentile, axis=0) key = f"pencentile_{percentile}" - - self.data[key][unit_index, :, :] = arr + if self.sparsity is None: + self.data[key][unit_index, :, :] = arr + else: + channel_indices = self.sparsity.unit_id_to_channel_indices[unit_id] + self.data[key][unit_index, :, :][:, channel_indices] = arr[:, :channel_indices.size] def _set_params(self, operators = ["average", "std"]): assert isinstance(operators, list) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index c1d32e79bd..bf92c87942 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -18,12 +18,14 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes +from .sorting_tools import random_spikes_selection from .core_tools import check_json from .numpyextractors import SharedMemorySorting -from .sparsity import ChannelSparsity +from .sparsity import ChannelSparsity, estimate_sparsity from .sortingfolder import NumpyFolderSorting from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor + # TODO # * make info.json that contain some version info of spikeinterface # * same for zarr @@ -32,8 +34,7 @@ # high level function def start_sorting_result(sorting, recording, format="memory", folder=None, - sparse=True, sparsity=None, - # **kwargs + sparse=True, sparsity=None, **sparsity_kwargs ): """ Create a SortingResult by pairing a Sorting and the corresponding Recording. @@ -57,28 +58,12 @@ def start_sorting_result(sorting, recording, format="memory", folder=None, The "folder" argument must be specified in case of mode "folder". If "memory" is used, the waveforms are stored in RAM. Use this option carefully! sparse: bool, default: True - If True, then a sparsity mask is computed usingthe `precompute_sparsity()` function is run using + If True, then a sparsity mask is computed usingthe `estimate_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the sparsity will be propagated to all ResultExtention that handle sparsity (like wavforms, pca, ...) + You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs) sparsity: ChannelSparsity or None, default: None The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Default None. - sparsity_temp_folder: str or Path or None, default: None - If sparse is True, this is the temporary folder where the dense waveforms are temporarily saved. - If None, dense waveforms are extracted in memory in batches (which can be controlled by the `unit_batch_size` - parameter. With a large number of units (e.g., > 400), it is advisable to use a temporary folder. - num_spikes_for_sparsity: int, default: 100 - The number of spikes to use to estimate sparsity (if sparse=True). - unit_batch_size: int, default: 200 - The number of units to process at once when extracting dense waveforms (if sparse=True and sparsity_temp_folder - is None). - - sparsity kwargs: - {} - - - job kwargs: - {} - Returns ------- @@ -109,28 +94,12 @@ def start_sorting_result(sorting, recording, format="memory", folder=None, # handle sparsity if sparsity is not None: + # some checks assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object" - unit_id_to_channel_ids = sparsity.unit_id_to_channel_ids - assert all(u in sorting.unit_ids for u in unit_id_to_channel_ids), "Invalid unit ids in sparsity" - for channels in unit_id_to_channel_ids.values(): - assert all(ch in recording.channel_ids for ch in channels), "Invalid channel ids in sparsity" + assert np.array_equal(sorting.unit_ids, sparsity.unit_ids), "start_sorting_result(): if external sparsity is given unit_ids must correspond" + assert np.array_equal(recording.channel_ids, recording.channel_ids), "start_sorting_result(): if external sparsity is given unit_ids must correspond" elif sparse: - # TODO - # raise NotImplementedError() - sparsity = None - # estimate_kwargs, job_kwargs = split_job_kwargs(kwargs) - # sparsity = precompute_sparsity( - # recording, - # sorting, - # ms_before=ms_before, - # ms_after=ms_after, - # num_spikes_for_sparsity=num_spikes_for_sparsity, - # unit_batch_size=unit_batch_size, - # temp_folder=sparsity_temp_folder, - # allow_unfiltered=allow_unfiltered, - # **estimate_kwargs, - # **job_kwargs, - # ) + sparsity = estimate_sparsity( recording, sorting, **sparsity_kwargs) else: sparsity = None @@ -367,12 +336,12 @@ def load_from_binary_folder(cls, folder, recording=None): sparsity_mask = np.load(sparsity_file) # with open(sparsity_file, mode="r") as f: # sparsity = ChannelSparsity.from_dict(json.load(f)) - sparsity = ChannelSparsity(zarr_root["sparsity_mask"], rec_attributes["unit_ids"], rec_attributes["channel_ids"]) + sparsity = ChannelSparsity(sparsity_mask, sorting.unit_ids, rec_attributes["channel_ids"]) else: sparsity = None selected_spike_file = folder / "random_spikes_indices.npy" - if sparsity_file.is_file(): + if selected_spike_file.is_file(): random_spikes_indices = np.load(selected_spike_file) else: random_spikes_indices = None @@ -507,7 +476,7 @@ def load_from_zarr(cls, folder, recording=None): # sparsity if "sparsity_mask" in zarr_root.attrs: # sparsity = zarr_root.attrs["sparsity"] - sparsity = ChannelSparsity(zarr_root["sparsity_mask"], rec_attributes["unit_ids"], rec_attributes["channel_ids"]) + sparsity = ChannelSparsity(zarr_root["sparsity_mask"], self.unit_ids, rec_attributes["channel_ids"]) else: sparsity = None @@ -537,11 +506,14 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> else: recording = None - if self.sparsity is not None: - # handle sparsity propagation and slicing - raise NotImplementedError - sparsity = None - + if self.sparsity is not None and unit_ids is None: + sparsity = self.sparsity + elif self.sparsity is not None and unit_ids is not None: + sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] + sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) + else: + sparsity = None + # Note that the sorting is a copy we need to go back to the orginal sorting (if available) sorting_provenance = self.get_sorting_provenance() if sorting_provenance is None: @@ -1301,73 +1273,3 @@ def _save_params(self): -# TODO implement other method like "maximum_rate", "by_percent", ... -def random_spikes_selection(sorting, num_samples, - method="uniform", max_spikes_per_unit=500, - margin_size=None, - seed=None): - """ - This replace `select_random_spikes_uniformly()`. - - Random spikes selection of spike across per units. - - Can optionaly avoid spikes on segment borders. - If nbefore and nafter - - Parameters - ---------- - recording - - sorting - - max_spikes_per_unit - - method: "uniform" - - margin_size - - seed=None - - Returns - ------- - random_spikes_indices - Selected spike indicies corespond to the sorting spike vector. - """ - - rng =np.random.default_rng(seed=seed) - spikes = sorting.to_spike_vector() - - random_spikes_indices = [] - for unit_index, unit_id in enumerate(sorting.unit_ids): - all_unit_indices = np.flatnonzero(unit_index == spikes["unit_index"]) - - if method == "uniform": - selected_unit_indices = rng.choice(all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), - replace=False, shuffle=False) - else: - raise ValueError(f"random_spikes_selection wring method {method}") - - if margin_size is not None: - margin_size = int(margin_size) - keep = np.ones(selected_unit_indices.size, dtype=bool) - # left margin - keep[selected_unit_indices < margin_size] = False - # right margin - for segment_index in range(sorting.get_num_segments()): - remove_mask = np.flatnonzero((spikes[selected_unit_indices]["segment_index"] == segment_index) - &(spikes[selected_unit_indices]["sample_index"] >= (num_samples[segment_index] - margin_size)) - ) - keep[remove_mask] = False - selected_unit_indices = selected_unit_indices[keep] - - random_spikes_indices.append(selected_unit_indices) - - random_spikes_indices = np.concatenate(random_spikes_indices) - random_spikes_indices = np.sort(random_spikes_indices) - - return random_spikes_indices - - - - - diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index 8bb02b9248..7ae6c40d86 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -14,12 +14,22 @@ cache_folder = Path("cache_folder") / "core" -def get_sorting_result(format="memory"): +def get_sorting_result(format="memory", sparse=True): recording, sorting = generate_ground_truth_recording( - durations=[30.0], sampling_frequency=16000.0, num_channels=10, num_units=5, + durations=[30.0], sampling_frequency=16000.0, num_channels=20, num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params_range=dict( + alpha=(9_000.0, 12_000.0), + ) + ), noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), - seed=2205, + seed=2406, ) if format == "memory": folder = None @@ -30,7 +40,7 @@ def get_sorting_result(format="memory"): if folder and folder.exists(): shutil.rmtree(folder) - sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=None) + sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=None) return sortres @@ -39,8 +49,8 @@ def _check_result_extension(sortres, extension_name): # select unit_ids to several format - # for format in ("memory", "binary_folder", "zarr"): - for format in ("memory", ): + for format in ("memory", "binary_folder", "zarr"): + # for format in ("memory", ): if format != "memory": if format == "zarr": folder = cache_folder / f"test_SortingResult_{extension_name}_select_units_with_{format}.zarr" @@ -60,8 +70,10 @@ def _check_result_extension(sortres, extension_name): # print(k, arr.shape) -def test_ComputeWaveforms(format="memory"): - sortres = get_sorting_result(format=format) +@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) +@pytest.mark.parametrize("sparse", [True, False]) +def test_ComputeWaveforms(format, sparse): + sortres = get_sorting_result(format=format, sparse=sparse) sortres.select_random_spikes(max_spikes_per_unit=50, seed=2205) ext = sortres.compute("waveforms") @@ -69,8 +81,10 @@ def test_ComputeWaveforms(format="memory"): _check_result_extension(sortres, "waveforms") -def test_ComputeTemplates(format="memory"): - sortres = get_sorting_result(format=format) +@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) +@pytest.mark.parametrize("sparse", [True, False]) +def test_ComputeTemplates(format, sparse): + sortres = get_sorting_result(format=format, sparse=sparse) sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) @@ -79,32 +93,36 @@ def test_ComputeTemplates(format="memory"): sortres.compute("templates") sortres.compute("waveforms") - waveforms = sortres.get_extension("waveforms").data["waveforms"] sortres.compute("templates", operators=["average", "std", "median", ("percentile", 5.), ("percentile", 95.),]) data = sortres.get_extension("templates").data for k in ['average', 'std', 'median', 'pencentile_5.0', 'pencentile_95.0']: assert k in data.keys() - assert data[k].shape[0] == sortres.sorting.unit_ids.size - assert data[k].shape[1] == waveforms.shape[1] - + assert data[k].shape[0] == sortres.unit_ids.size + assert data[k].shape[2] == sortres.channel_ids.size + assert np.any(data[k] > 0) # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # unit_index = 2 - # for k in data.keys(): - # wf0 = data[k][unit_index, :, :] - # ax.plot(wf0.T.flatten(), label=k) - # ax.legend() + # for unit_index, unit_id in enumerate(sortres.unit_ids): + # fig, ax = plt.subplots() + # for k in data.keys(): + # wf0 = data[k][unit_index, :, :] + # ax.plot(wf0.T.flatten(), label=k) + # ax.legend() # plt.show() _check_result_extension(sortres, "templates") if __name__ == '__main__': - # test_ComputeWaveforms(format="memory") - # test_ComputeWaveforms(format="binary_folder") - # test_ComputeWaveforms(format="zarr") - test_ComputeTemplates(format="memory") - # test_ComputeTemplates(format="binary_folder") - # test_ComputeTemplates(format="zarr") + # test_ComputeWaveforms(format="memory", sparse=True) + # test_ComputeWaveforms(format="memory", sparse=False) + # test_ComputeWaveforms(format="binary_folder", sparse=True) + # test_ComputeWaveforms(format="binary_folder", sparse=False) + # test_ComputeWaveforms(format="zarr", sparse=True) + # test_ComputeWaveforms(format="zarr", sparse=False) + + # test_ComputeTemplates(format="memory", sparse=True) + # test_ComputeTemplates(format="memory", sparse=False) + test_ComputeTemplates(format="binary_folder", sparse=True) + # test_ComputeTemplates(format="zarr", sparse=True) diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index ee2532d806..13570ace9d 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -52,5 +52,5 @@ def test_random_spikes_selection(): if __name__ == "__main__": - # test_spike_vector_to_spike_trains() + test_spike_vector_to_spike_trains() test_random_spikes_selection() diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 18f910f3c2..afcebc2383 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -5,7 +5,7 @@ from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import SortingResult, start_sorting_result, load_sorting_result -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension, random_spikes_selection +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension import numpy as np @@ -25,12 +25,15 @@ def get_dataset(): return recording, sorting - def test_SortingResult_memory(): recording, sorting = get_dataset() sortres = start_sorting_result(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_results(sortres, sorting) + sortres = start_sorting_result(sorting, recording, format="memory", sparse=True, sparsity=None) + _check_sorting_results(sortres, sorting) + + def test_SortingResult_binary_folder(): recording, sorting = get_dataset() @@ -189,34 +192,9 @@ def test_extension(): with pytest.raises(AssertionError): register_result_extension(DummyResultExtension2) -def test_random_spikes_selection(): - recording, sorting = get_dataset() - - max_spikes_per_unit = 12 - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - - random_spikes_indices = random_spikes_selection(sorting, num_samples, - method="uniform", max_spikes_per_unit=max_spikes_per_unit, - margin_size=None, - seed=2205) - spikes = sorting.to_spike_vector() - some_spikes = spikes[random_spikes_indices] - for unit_index, unit_id in enumerate(sorting.unit_ids): - spike_slected_unit = some_spikes[some_spikes["unit_index"] == unit_index] - assert spike_slected_unit.size == max_spikes_per_unit - - # with margin - random_spikes_indices = random_spikes_selection(sorting, num_samples, - method="uniform", max_spikes_per_unit=max_spikes_per_unit, - margin_size=25, - seed=2205) - # in that case the number is not garanty so it can be a bit less - assert random_spikes_indices.size >= (0.9 * sorting.unit_ids.size * max_spikes_per_unit) - if __name__ == "__main__": - # test_SortingResult_memory() - test_SortingResult_binary_folder() + test_SortingResult_memory() + # test_SortingResult_binary_folder() # test_SortingResult_zarr() # test_extension() - # test_random_spikes_selection() From 9b7ed2ade349fa21d4c8bbc3e8d76f31148a3494 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jan 2024 18:48:00 +0100 Subject: [PATCH 013/192] Implement ComputeFastTemplates --- src/spikeinterface/core/__init__.py | 3 +- src/spikeinterface/core/result_core.py | 81 ++++++++++++++++--- .../core/tests/test_result_core.py | 54 ++++++++++--- 3 files changed, 119 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 4ffdc505bb..1ffa52e43f 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -145,7 +145,8 @@ from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result from .result_core import ( ComputeWaveforms, compute_waveforms, - ComputeTemplates, compute_templates + ComputeTemplates, compute_templates, + ComputeFastTemplates, compute_fast_templates, ) diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index b242d4bdc9..f4f6e57db4 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -10,9 +10,14 @@ import numpy as np from .sortingresult import ResultExtension, register_result_extension -from .waveform_tools import extract_waveforms_to_single_buffer +from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates class ComputeWaveforms(ResultExtension): + """ + ResultExtension that extract some waveforms of each units. + + The sparsity is controlled by the SortingResult sparsity. + """ extension_name = "waveforms" depend_on = [] need_recording = True @@ -74,8 +79,7 @@ def _run(self, **kwargs): def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, - max_spikes_per_unit: int = 500, - return_scaled: bool = False, + return_scaled: bool = True, dtype=None, ): recording = self.sorting_result.recording @@ -93,13 +97,9 @@ def _set_params(self, dtype = np.dtype(dtype) - if max_spikes_per_unit is not None: - max_spikes_per_unit = int(max_spikes_per_unit) - params = dict( ms_before=float(ms_before), ms_after=float(ms_after), - max_spikes_per_unit=max_spikes_per_unit, return_scaled=return_scaled, dtype=dtype.str, ) @@ -122,7 +122,15 @@ def _select_extension_data(self, unit_ids): compute_waveforms = ComputeWaveforms.function_factory() register_result_extension(ComputeWaveforms) + class ComputeTemplates(ResultExtension): + """ + ResultExtension that compute templates (average, str, median, percentile, ...) + + This must be run after "waveforms" extension (`SortingResult.compute("waveforms")`) + + Note that when "waveforms" is already done, then the recording is not needed anymore for this extension. + """ extension_name = "templates" depend_on = ["waveforms"] need_recording = False @@ -199,8 +207,63 @@ def _select_extension_data(self, unit_ids): return new_data +compute_templates = ComputeTemplates.function_factory() +register_result_extension(ComputeTemplates) + + +class ComputeFastTemplates(ResultExtension): + """ + ResultExtension which is similar to the extension "templates" (ComputeTemplates) **but only for average**. + This is way faster because it do not need "waveforms" to be computed first. + """ + extension_name = "fast_templates" + depend_on = [] + need_recording = True + use_nodepiepline = False + + def _run(self, **kwargs): + self.data.clear() + + if self.sorting_result.random_spikes_indices is None: + raise ValueError("compute_waveforms need SortingResult.select_random_spikes() need to be run first") + + recording = self.sorting_result.recording + sorting = self.sorting_result.sorting + unit_ids = sorting.unit_ids + + # retrieve spike vector and the sampling + spikes = sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + + nbefore = int(self.params["ms_before"] * sorting.sampling_frequency / 1000.0) + nafter = int(self.params["ms_after"] * sorting.sampling_frequency / 1000.0) + + return_scaled = self.params["return_scaled"] + + # TODO jobw_kwargs + self.data["average"] = estimate_templates(recording, some_spikes, unit_ids, nbefore, nafter, return_scaled=return_scaled) + + def _set_params(self, + ms_before: float = 1.0, + ms_after: float = 2.0, + return_scaled: bool = True, + ): + params = dict( + ms_before=float(ms_before), + ms_after=float(ms_after), + return_scaled=return_scaled, + ) + return params + + def _select_extension_data(self, unit_ids): + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) -compute_templates = ComputeTemplates.function_factory() -register_result_extension(ComputeTemplates) \ No newline at end of file + new_data = dict() + new_data["average"] = self.data["average"][keep_unit_indices, :, :] + + return new_data + +compute_fast_templates = ComputeFastTemplates.function_factory() +register_result_extension(ComputeFastTemplates) diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index 7ae6c40d86..949663294f 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -46,8 +46,6 @@ def get_sorting_result(format="memory", sparse=True): def _check_result_extension(sortres, extension_name): - - # select unit_ids to several format for format in ("memory", "binary_folder", "zarr"): # for format in ("memory", ): @@ -103,16 +101,52 @@ def test_ComputeTemplates(format, sparse): assert data[k].shape[2] == sortres.channel_ids.size assert np.any(data[k] > 0) + import matplotlib.pyplot as plt + for unit_index, unit_id in enumerate(sortres.unit_ids): + fig, ax = plt.subplots() + for k in data.keys(): + wf0 = data[k][unit_index, :, :] + ax.plot(wf0.T.flatten(), label=k) + ax.legend() + # plt.show() + + _check_result_extension(sortres, "templates") + + +@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) +@pytest.mark.parametrize("sparse", [True, False]) +def test_ComputeFastTemplates(format, sparse): + sortres = get_sorting_result(format=format, sparse=sparse) + + ms_before=1.0 + ms_after=2.5 + + sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) + sortres.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True) + + _check_result_extension(sortres, "fast_templates") + + # compare ComputeTemplates with dense and ComputeFastTemplates: should give the same on "average" + other_sortres = get_sorting_result(format=format, sparse=False) + other_sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) + other_sortres.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True) + other_sortres.compute("templates", operators=["average",]) + + templates0 = sortres.get_extension("fast_templates").data["average"] + templates1 = other_sortres.get_extension("templates").data["average"] + np.testing.assert_almost_equal(templates0, templates1) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() # for unit_index, unit_id in enumerate(sortres.unit_ids): - # fig, ax = plt.subplots() - # for k in data.keys(): - # wf0 = data[k][unit_index, :, :] - # ax.plot(wf0.T.flatten(), label=k) - # ax.legend() + # wf0 = templates0[unit_index, :, :] + # ax.plot(wf0.T.flatten(), label=f"{unit_id}") + # wf1 = templates1[unit_index, :, :] + # ax.plot(wf1.T.flatten(), ls='--', color='k') + # ax.legend() # plt.show() - _check_result_extension(sortres, "templates") + if __name__ == '__main__': # test_ComputeWaveforms(format="memory", sparse=True) @@ -124,5 +158,7 @@ def test_ComputeTemplates(format, sparse): # test_ComputeTemplates(format="memory", sparse=True) # test_ComputeTemplates(format="memory", sparse=False) - test_ComputeTemplates(format="binary_folder", sparse=True) + # test_ComputeTemplates(format="binary_folder", sparse=True) # test_ComputeTemplates(format="zarr", sparse=True) + + test_ComputeFastTemplates(format="memory", sparse=True) From 013a80553516396b137fa55c06aeee035c2ca773 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jan 2024 19:27:20 +0100 Subject: [PATCH 014/192] Add spikeinterface_info in SortingResult --- src/spikeinterface/core/result_core.py | 2 +- src/spikeinterface/core/sortingresult.py | 17 +++++++++++++++++ .../core/tests/test_sortingresult.py | 6 +++--- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index f4f6e57db4..20ca61d937 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -2,9 +2,9 @@ Implement ResultExtension that are essential and imported in core * ComputeWaveforms * ComputeTemplates - Theses two classes replace the WaveformExtractor +It also implement ComputeFastTemplates which is equivalent but without extacting waveforms. """ import numpy as np diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index bf92c87942..23317e11df 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -13,6 +13,8 @@ import probeinterface +import spikeinterface + from .baserecording import BaseRecording from .basesorting import BaseSorting @@ -254,6 +256,15 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attribut raise ValueError(f"Folder already exists {folder}") folder.mkdir(parents=True) + + info_file = folder / f"spikeinterface_info.json" + info = dict( + version=spikeinterface.__version__, + dev_mode=spikeinterface.DEV_MODE, + ) + with open(info_file, mode="w") as f: + json.dump(check_json(info), f, indent=4) + # save a copy of the sorting NumpyFolderSorting.write_sorting(sorting, folder / "sorting") @@ -378,6 +389,12 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): raise ValueError(f"Folder already exists {folder}") zarr_root = zarr.open(folder, mode="w") + + info = dict( + version=spikeinterface.__version__, + dev_mode=spikeinterface.DEV_MODE, + ) + zarr_root.attrs["spikeinterface_info"] = check_json(info) # the recording rec_dict = recording.to_dict(relative_to=folder, recursive=True) diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index afcebc2383..e341e02874 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -194,7 +194,7 @@ def test_extension(): if __name__ == "__main__": - test_SortingResult_memory() - # test_SortingResult_binary_folder() - # test_SortingResult_zarr() + # test_SortingResult_memory() + test_SortingResult_binary_folder() + test_SortingResult_zarr() # test_extension() From b9a66d7f49dd074b9aa68ef34793b28b5b75caa0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jan 2024 21:58:48 +0100 Subject: [PATCH 015/192] yep --- src/spikeinterface/core/sortingresult.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 23317e11df..dceabeac77 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -28,10 +28,6 @@ from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor -# TODO -# * make info.json that contain some version info of spikeinterface -# * same for zarr - # high level function @@ -247,8 +243,6 @@ def create_memory(cls, sorting, recording, sparsity, rec_attributes): def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attributes): # used by create and save_as - # TODO add a spikeinterface_info.json folder with SI version and object type - assert recording is not None, "To create a SortingResult you need recording not None" folder = Path(folder) @@ -261,6 +255,7 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attribut info = dict( version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, + object="SortingResult", ) with open(info_file, mode="w") as f: json.dump(check_json(info), f, indent=4) @@ -378,8 +373,6 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): import zarr import numcodecs - # TODO add an attribute with SI version and object type - folder = Path(folder) # force zarr sufix if folder.suffix != ".zarr": @@ -393,6 +386,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): info = dict( version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, + object="SortingResult" ) zarr_root.attrs["spikeinterface_info"] = check_json(info) @@ -727,7 +721,7 @@ def __repr__(self) -> str: txt += " - sparse" if self.has_recording(): txt += " - has recording" - ext_txt = f"Load extenstions [{len(self.extensions)}]: " + ", ".join(self.extensions.keys()) + ext_txt = f"Loaded {len(self.extensions)} extenstions: " + ", ".join(self.extensions.keys()) txt += "\n" + ext_txt return txt From 4a8b54e1c02f6284516b7f5dff7cbde9b085cbfc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 18 Jan 2024 16:43:05 +0100 Subject: [PATCH 016/192] Start MockWaveformExtractor for backawrds compatibility --- src/spikeinterface/core/generate.py | 2 +- ...forms_extractor_backwards_compatibility.py | 63 +++++ ...forms_extractor_backwards_compatibility.py | 265 ++++++++++++++++++ 3 files changed, 329 insertions(+), 1 deletion(-) create mode 100644 src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py create mode 100644 src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 76bd0f5a20..f7170361fb 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1156,7 +1156,7 @@ def generate_single_fake_waveform( bins = np.arange(-n, n + 1) smooth_kernel = np.exp(-(bins**2) / (2 * smooth_size**2)) smooth_kernel /= np.sum(smooth_kernel) - smooth_kernel = smooth_kernel[4:] + # smooth_kernel = smooth_kernel[4:] wf = np.convolve(wf, smooth_kernel, mode="same") # ensure the the peak to be extatly at nbefore (smooth can modify this) diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py new file mode 100644 index 0000000000..c50b988543 --- /dev/null +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -0,0 +1,63 @@ +import pytest +from pathlib import Path + +import shutil + + +from spikeinterface.core import generate_ground_truth_recording + +from spikeinterface.core.waveforms_extractor_backwards_compatibility import extract_waveforms as mock_extract_waveforms + +# remove this when WaveformsExtractor will be removed +from spikeinterface.core import extract_waveforms as old_extract_waveforms + + + + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "core" +else: + cache_folder = Path("cache_folder") / "core" + + +def get_dataset(): + recording, sorting = generate_ground_truth_recording( + durations=[3600.0], sampling_frequency=16000.0, num_channels=128, num_units=100, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params_range=dict( + alpha=(9_000.0, 12_000.0), + ) + ), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2406, + ) + return recording, sorting + + +def test_extract_waveforms(): + recording, sorting = get_dataset() + print(recording) + + folder = cache_folder / "mock_waveforms_extractor" + if folder.exists(): + shutil.rmtree(folder) + + we = mock_extract_waveforms(recording, sorting, folder=folder, sparse=True) + print(we) + + folder = cache_folder / "old_waveforms_extractor" + if folder.exists(): + shutil.rmtree(folder) + + we = old_extract_waveforms(recording, sorting, folder=folder, sparse=True) + print(we) + + +if __name__ == "__main__": + test_extract_waveforms() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py new file mode 100644 index 0000000000..14aed355ce --- /dev/null +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -0,0 +1,265 @@ +""" +This backwards compatibility module aims to: + * load old WaveformsExtractor saved with folder or zarr (version <=0.100) into the SortingResult (version>0.100) + * mock the function extract_waveforms() and the class SortingResult() but based SortingResult +""" +from __future__ import annotations + +from typing import Literal, Optional + +from pathlib import Path + +import numpy as np + + +from .baserecording import BaseRecording +from .basesorting import BaseSorting +from .sortingresult import start_sorting_result +from .job_tools import split_job_kwargs + + +_backwards_compatibility_msg = """#### +# extract_waveforms() and WaveformExtractor() have been replace by SortingResult since version 0.101 +# You should use start_sorting_result() instead. +# extract_waveforms() is now mocking the old behavior for backwards compatibility only and will be removed after 0.103 +####""" + + +def extract_waveforms( + recording, + sorting, + folder=None, + mode="folder", + precompute_template=("average",), + ms_before=1.0, + ms_after=2.0, + max_spikes_per_unit=500, + overwrite=None, + return_scaled=True, + dtype=None, + sparse=True, + sparsity=None, + sparsity_temp_folder=None, + num_spikes_for_sparsity=100, + unit_batch_size=None, + allow_unfiltered=None, + use_relative_path=None, + seed=None, + load_if_exists=None, + **kwargs, +): + """ + This mock the extract_waveforms() in version <= 0.100 to not break old codes but using + the SortingResult (version >0.100) internally. + + This return a MockWaveformExtractor object that mock the old WaveformExtractor + """ + print(_backwards_compatibility_msg) + + assert load_if_exists is None, "load_if_exists=True/False is not supported anymore. use load_if_exists=None" + assert overwrite is None, "overwrite=True/False is not supported anymore. use overwrite=None" + + other_kwargs, job_kwargs = split_job_kwargs(kwargs) + + if mode == "folder": + assert folder is not None + folder = Path(folder) + format = "binary_folder" + else: + mode = "memory" + + assert sparsity_temp_folder is None, "sparsity_temp_folder must be None" + assert unit_batch_size is None, "unit_batch_size must be None" + + if use_relative_path is not None: + print("use_relative_path is ignored") + + if allow_unfiltered is not None: + print("allow_unfiltered is ignored") + + sparsity_kwargs = dict( + num_spikes_for_sparsity=num_spikes_for_sparsity, + ms_before=ms_before, + ms_after=ms_after, + **other_kwargs, + **job_kwargs + ) + sorting_result = start_sorting_result(sorting, recording, format=format, folder=folder, + sparse=sparse, sparsity=sparsity, **sparsity_kwargs + ) + + # TODO propagate job_kwargs + + sorting_result.select_random_spikes(max_spikes_per_unit=max_spikes_per_unit, seed=seed) + + waveforms_params = dict(ms_before=ms_before, ms_after=ms_after, return_scaled=return_scaled, dtype=dtype) + sorting_result.compute("waveforms", **waveforms_params) + + templates_params = dict(operators=list(precompute_template)) + sorting_result.compute("templates", **templates_params) + + we = MockWaveformExtractor(sorting_result) + + return we + + + +class MockWaveformExtractor: + def __init__(self, sorting_result): + self.sorting_result = sorting_result + + def __repr__(self): + txt = "MockWaveformExtractor: mock the old WaveformExtractor with " + txt += self.sorting_result.__repr__() + return txt + + def is_sparse(self) -> bool: + return self.sorting_result.is_sparse() + + def has_waveforms(self) -> bool: + + raise NotImplementedError + + def delete_waveforms(self) -> None: + raise NotImplementedError + + @property + def recording(self) -> BaseRecording: + raise NotImplementedError + + @property + def channel_ids(self) -> np.ndarray: + raise NotImplementedError + + @property + def sampling_frequency(self) -> float: + raise NotImplementedError + + @property + def unit_ids(self) -> np.ndarray: + raise NotImplementedError + + @property + def nbefore(self) -> int: + raise NotImplementedError + + @property + def nafter(self) -> int: + raise NotImplementedError + + @property + def nsamples(self) -> int: + return self.nbefore + self.nafter + + @property + def return_scaled(self) -> bool: + raise NotImplementedError + + @property + def dtype(self): + raise NotImplementedError + + def is_read_only(self) -> bool: + raise NotImplementedError + + def has_recording(self) -> bool: + raise NotImplementedError + + def get_num_samples(self, segment_index: Optional[int] = None) -> int: + raise NotImplementedError + + def get_total_samples(self) -> int: + s = 0 + for segment_index in range(self.get_num_segments()): + s += self.get_num_samples(segment_index) + return s + + def get_total_duration(self) -> float: + duration = self.get_total_samples() / self.sampling_frequency + return duration + + def get_num_channels(self) -> int: + raise NotImplementedError + # if self.has_recording(): + # return self.recording.get_num_channels() + # else: + # return self._rec_attributes["num_channels"] + + def get_num_segments(self) -> int: + return self.sorting_result.sorting.get_num_segments() + + def get_probegroup(self): + raise NotImplementedError + # if self.has_recording(): + # return self.recording.get_probegroup() + # else: + # return self._rec_attributes["probegroup"] + + # def is_filtered(self) -> bool: + # if self.has_recording(): + # return self.recording.is_filtered() + # else: + # return self._rec_attributes["is_filtered"] + + def get_probe(self): + probegroup = self.get_probegroup() + assert len(probegroup.probes) == 1, "There are several probes. Use `get_probegroup()`" + return probegroup.probes[0] + + def get_channel_locations(self) -> np.ndarray: + raise NotImplementedError + + def channel_ids_to_indices(self, channel_ids) -> np.ndarray: + raise NotImplementedError + + def get_recording_property(self, key) -> np.ndarray: + raise NotImplementedError + + def get_sorting_property(self, key) -> np.ndarray: + return self.sorting.get_property(key) + + # def has_extension(self, extension_name: str) -> bool: + # raise NotImplementedError + + def get_waveforms( + self, + unit_id, + with_index: bool = False, + cache: bool = False, + lazy: bool = True, + sparsity=None, + force_dense: bool = False, + ): + raise NotImplementedError + + def get_sampled_indices(self, unit_id): + raise NotImplementedError + + + def get_all_templates( + self, unit_ids: list | np.array | tuple | None = None, mode="average", percentile: float | None = None + ): + raise NotImplementedError + + def get_template( + self, unit_id, mode="average", sparsity=None, force_dense: bool = False, percentile: float | None = None + ): + raise NotImplementedError + + + +def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="SortingResult"): + """ + This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingResult or MockWaveformExtractor. + + """ + + raise NotImplementedError + + # This will be something like this create a SortingResult in memory and copy/translate all data into the new structure. + # sorting_result = ... + + # if output == "SortingResult": + # return sorting_result + # elif output in ("WaveformExtractor", "MockWaveformExtractor"): + # return MockWaveformExtractor(sorting_result) From 2668188de2e44757c67826ed55dceedfa633666b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 19 Jan 2024 12:51:01 +0100 Subject: [PATCH 017/192] continue MockWaveformExtractor implementation --- src/spikeinterface/core/__init__.py | 7 +- src/spikeinterface/core/result_core.py | 6 +- src/spikeinterface/core/sortingresult.py | 50 ++++-- .../core/tests/test_sortingresult.py | 2 +- ...forms_extractor_backwards_compatibility.py | 33 +++- src/spikeinterface/core/waveform_tools.py | 2 +- ...forms_extractor_backwards_compatibility.py | 154 +++++++++++------- 7 files changed, 166 insertions(+), 88 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 1ffa52e43f..36fccad8fa 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -106,6 +106,8 @@ from .snippets_tools import snippets_from_sorting # waveform extractor +# Important not for compatibility!! +# This wil be commented after 0.100 relase but the module will not be removed. from .waveform_extractor import ( WaveformExtractor, BaseWaveformExtractorExtension, @@ -149,6 +151,7 @@ ComputeFastTemplates, compute_fast_templates, ) - - +# Important not for compatibility!! +# This wil be uncommented after 0.100 +# from .waveforms_extractor_backwards_compatibility import extract_waveforms, load_waveforms diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 20ca61d937..17174c84ac 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -21,7 +21,7 @@ class ComputeWaveforms(ResultExtension): extension_name = "waveforms" depend_on = [] need_recording = True - use_nodepiepline = False + use_nodepipeline = False def _run(self, **kwargs): self.data.clear() @@ -134,7 +134,7 @@ class ComputeTemplates(ResultExtension): extension_name = "templates" depend_on = ["waveforms"] need_recording = False - use_nodepiepline = False + use_nodepipeline = False def _run(self, **kwargs): @@ -219,7 +219,7 @@ class ComputeFastTemplates(ResultExtension): extension_name = "fast_templates" depend_on = [] need_recording = True - use_nodepiepline = False + use_nodepipeline = False def _run(self, **kwargs): self.data.clear() diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index dceabeac77..252e41667e 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -161,6 +161,20 @@ def __init__(self, sorting=None, recording=None, rec_attributes=None, format=Non # extensions are not loaded at init self.extensions = dict() + def __repr__(self) -> str: + clsname = self.__class__.__name__ + nseg = self.get_num_segments() + nchan = self.get_num_channels() + nunits = self.sorting.get_num_units() + txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}" + if self.is_sparse(): + txt += " - sparse" + if self.has_recording(): + txt += " - has recording" + ext_txt = f"Loaded {len(self.extensions)} extenstions: " + ", ".join(self.extensions.keys()) + txt += "\n" + ext_txt + return txt + ## create and load zone @classmethod @@ -711,19 +725,12 @@ def channel_ids_to_indices(self, channel_ids) -> np.ndarray: indices = np.array([all_channel_ids.index(id) for id in channel_ids], dtype=int) return indices - def __repr__(self) -> str: - clsname = self.__class__.__name__ - nseg = self.get_num_segments() - nchan = self.get_num_channels() - nunits = self.sorting.get_num_units() - txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}" - if self.is_sparse(): - txt += " - sparse" - if self.has_recording(): - txt += " - has recording" - ext_txt = f"Loaded {len(self.extensions)} extenstions: " + ", ".join(self.extensions.keys()) - txt += "\n" + ext_txt - return txt + def get_recording_property(self, key) -> np.ndarray: + values = np.array(self.rec_attributes["properties"].get(key, None)) + return values + + def get_sorting_property(self, key) -> np.ndarray: + return self.sorting.get_property(key) ## extensions zone def compute(self, extension_name, **params): @@ -883,7 +890,7 @@ def has_extension(self, extension_name: str) -> bool: ## random_spikes_selection zone def select_random_spikes(self, **random_kwargs): - + # random_spikes_indices is a vector that refer to the spike vector of the sorting in absolut index assert self.random_spikes_indices is None, "select random spikes is already computed" self.random_spikes_indices = random_spikes_selection(self.sorting, self.rec_attributes["num_samples"], **random_kwargs) @@ -894,6 +901,17 @@ def select_random_spikes(self, **random_kwargs): zarr_root = self._get_zarr_root() zarr_root.create_dataset("random_spikes_indices", data=self.random_spikes_indices) + def get_selected_indices_in_spike_train(self, unit_id, segment_index): + # usefull for Waveforms extractor backwars compatibility + # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain + assert self.random_spikes_indices is not None, "random spikes selection is not computeds" + unit_index = self.sorting.id_to_index(unit_id) + spikes = self.sorting.to_spike_vector() + spike_indices_in_seg = np.flatnonzero((spikes["segment_index"] == segment_index) & (spikes["unit_index"] == unit_index)) + common_element, inds_left, inds_right = np.intersect1d(spike_indices_in_seg, self.random_spikes_indices, return_indices=True) + selected_spikes_in_spike_train = inds_left + return selected_spikes_in_spike_train + global _possible_extensions _possible_extensions = [] @@ -962,7 +980,7 @@ class ResultExtension: * extension_name * depend_on * need_recording - * use_nodepiepline + * use_nodepipeline * _set_params() * _run() * _select_extension_data() @@ -980,7 +998,7 @@ class ResultExtension: extension_name = None depend_on = [] need_recording = False - use_nodepiepline = False + use_nodepipeline = False def __init__(self, sorting_result): self._sorting_result = weakref.ref(sorting_result) diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index e341e02874..111ccd6cd7 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -148,7 +148,7 @@ class DummyResultExtension(ResultExtension): extension_name = "dummy" depend_on = [] need_recording = False - use_nodepiepline = False + use_nodepipeline = False def _set_params(self, param0="yep", param1=1.2, param2=[1,2, 3.]): params = dict(param0=param0, param1=param1, param2=param2) diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index c50b988543..1ddbaa525c 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -22,7 +22,7 @@ def get_dataset(): recording, sorting = generate_ground_truth_recording( - durations=[3600.0], sampling_frequency=16000.0, num_channels=128, num_units=100, + durations=[30.0, 20.], sampling_frequency=16000.0, num_channels=4, num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), generate_unit_locations_kwargs=dict( margin_um=5.0, @@ -42,21 +42,38 @@ def get_dataset(): def test_extract_waveforms(): recording, sorting = get_dataset() - print(recording) - folder = cache_folder / "mock_waveforms_extractor" + folder = cache_folder / "old_waveforms_extractor" if folder.exists(): shutil.rmtree(folder) - we = mock_extract_waveforms(recording, sorting, folder=folder, sparse=True) - print(we) + we_kwargs = dict(sparse=True, max_spikes_per_unit=30) - folder = cache_folder / "old_waveforms_extractor" + we_old = old_extract_waveforms(recording, sorting, folder=folder, **we_kwargs) + print(we_old) + + + folder = cache_folder / "mock_waveforms_extractor" if folder.exists(): shutil.rmtree(folder) - we = old_extract_waveforms(recording, sorting, folder=folder, sparse=True) - print(we) + we_mock = mock_extract_waveforms(recording, sorting, folder=folder, **we_kwargs) + print(we_mock) + + for we in (we_old, we_mock): + + selected_spikes = we.get_sampled_indices(unit_id=sorting.unit_ids[0]) + # print(selected_spikes.size, selected_spikes.dtype) + + wfs = we.get_waveforms(sorting.unit_ids[0]) + # print(wfs.shape) + + wfs = we.get_waveforms(sorting.unit_ids[0], force_dense=True) + # print(wfs.shape) + + templates = we.get_all_templates() + # print(templates.shape) + if __name__ == "__main__": diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index b67066e02f..1fee9a44a1 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -767,7 +767,7 @@ def estimate_templates( ) - processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="estimate_sparsity", **job_kwargs) + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="estimate_templates", **job_kwargs) processor.run() # average diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 14aed355ce..59867f5433 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -117,35 +117,40 @@ def is_sparse(self) -> bool: return self.sorting_result.is_sparse() def has_waveforms(self) -> bool: - - raise NotImplementedError + return self.sorting_result.get_extension("waveforms") is not None def delete_waveforms(self) -> None: - raise NotImplementedError + self.sorting_result.delete_extension("waveforms") @property def recording(self) -> BaseRecording: - raise NotImplementedError + return self.sorting_result.recording + + @property + def sorting(self) -> BaseSorting: + return self.sorting_result.sorting @property def channel_ids(self) -> np.ndarray: - raise NotImplementedError + return self.sorting_result.channel_ids @property def sampling_frequency(self) -> float: - raise NotImplementedError + return self.sorting_result.sampling_frequency @property def unit_ids(self) -> np.ndarray: - raise NotImplementedError + return self.sorting_result.unit_ids @property def nbefore(self) -> int: - raise NotImplementedError + ms_before = self.sorting_result.get_extension("waveforms").params["ms_before"] + return int(ms_before * self.sampling_frequency / 1000.0) @property def nafter(self) -> int: - raise NotImplementedError + ms_after = self.sorting_result.get_extension("waveforms").params["ms_after"] + return int(ms_after * self.sampling_frequency / 1000.0) @property def nsamples(self) -> int: @@ -153,74 +158,68 @@ def nsamples(self) -> int: @property def return_scaled(self) -> bool: - raise NotImplementedError + return self.sorting_result.get_extension("waveforms").params["return_scaled"] @property def dtype(self): - raise NotImplementedError + return self.sorting_result.get_extension("waveforms").params["dtype"] def is_read_only(self) -> bool: - raise NotImplementedError + return self.sorting_result.is_read_only() def has_recording(self) -> bool: - raise NotImplementedError + return self.sorting_result._recording is not None def get_num_samples(self, segment_index: Optional[int] = None) -> int: - raise NotImplementedError + return self.sorting_result.get_num_samples(segment_index) def get_total_samples(self) -> int: - s = 0 - for segment_index in range(self.get_num_segments()): - s += self.get_num_samples(segment_index) - return s + return self.sorting_result.get_total_samples() def get_total_duration(self) -> float: - duration = self.get_total_samples() / self.sampling_frequency - return duration + return self.sorting_result.get_total_duration() def get_num_channels(self) -> int: - raise NotImplementedError - # if self.has_recording(): - # return self.recording.get_num_channels() - # else: - # return self._rec_attributes["num_channels"] + return self.sorting_result.get_num_channels() def get_num_segments(self) -> int: - return self.sorting_result.sorting.get_num_segments() + return self.sorting_result.get_num_segments() def get_probegroup(self): - raise NotImplementedError - # if self.has_recording(): - # return self.recording.get_probegroup() - # else: - # return self._rec_attributes["probegroup"] - - # def is_filtered(self) -> bool: - # if self.has_recording(): - # return self.recording.is_filtered() - # else: - # return self._rec_attributes["is_filtered"] + return self.sorting_result.get_probegroup() def get_probe(self): - probegroup = self.get_probegroup() - assert len(probegroup.probes) == 1, "There are several probes. Use `get_probegroup()`" - return probegroup.probes[0] + return self.sorting_result.get_probe() + + def is_filtered(self) -> bool: + return self.sorting_result.rec_attributes["is_filtered"] def get_channel_locations(self) -> np.ndarray: - raise NotImplementedError + return self.sorting_result.get_channel_locations() def channel_ids_to_indices(self, channel_ids) -> np.ndarray: - raise NotImplementedError + return self.sorting_result.channel_ids_to_indices(channel_ids) def get_recording_property(self, key) -> np.ndarray: - raise NotImplementedError + return self.sorting_result.get_recording_property(key) def get_sorting_property(self, key) -> np.ndarray: - return self.sorting.get_property(key) + return self.sorting_result.get_sorting_property(key) + + def has_extension(self, extension_name: str) -> bool: + return self.sorting_result.has_extension(extension_name) + + def get_sampled_indices(self, unit_id): + # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) with a complex dtype as follow + selected_spikes = [] + for segment_index in range(self.get_num_segments()): + inds = self.sorting_result.get_selected_indices_in_spike_train(unit_id, segment_index) + sampled_index = np.zeros(inds.size, dtype=[("spike_index", "int64"), ("segment_index", "int64")]) + sampled_index["spike_index"] = inds + sampled_index["segment_index"][:] = segment_index + selected_spikes.append(sampled_index) + return np.concatenate(selected_spikes) - # def has_extension(self, extension_name: str) -> bool: - # raise NotImplementedError - def get_waveforms( self, unit_id, @@ -229,26 +228,67 @@ def get_waveforms( lazy: bool = True, sparsity=None, force_dense: bool = False, - ): - raise NotImplementedError - - def get_sampled_indices(self, unit_id): - raise NotImplementedError - + ): + # lazy and cache are ingnored + ext = self.sorting_result.get_extension("waveforms") + unit_index = self.sorting.id_to_index(unit_id) + spikes = self.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + spike_mask = some_spikes["unit_index"] == unit_index + wfs = ext.data["waveforms"][spike_mask, :, :] + + if sparsity is not None: + assert self.sorting_result.sparsity is None, "Waveforms are alreayd sparse! Cannot apply an additional sparsity." + wfs = wfs[:, :, sparsity.mask[self.sorting.id_to_index(unit_id)]] + + if force_dense: + assert sparsity is None + if self.sorting_result.sparsity is None: + # nothing to do + pass + else: + num_channels = self.get_num_channels() + dense_wfs = np.zeros((wfs.shape[0], wfs.shape[1], num_channels), dtype=np.float32) + unit_sparsity = self.sorting_result.sparsity.mask[unit_index] + dense_wfs[:, :, unit_sparsity] = wfs + wfs = dense_wfs + + if with_index: + sampled_index = self.get_sampled_indices(unit_id) + return wfs, sampled_index + else: + return wfs def get_all_templates( self, unit_ids: list | np.array | tuple | None = None, mode="average", percentile: float | None = None ): - raise NotImplementedError + ext = self.sorting_result.get_extension("templates") + + if mode == "percentile": + key = f"pencentile_{percentile}" + else: + key = mode + + templates = ext.data.get(key) + if templates is None: + raise ValueError(f"{mode} is not computed") + + if unit_ids is not None: + unit_indices = self.sorting.ids_to_indices(unit_ids) + templates = templates[unit_indices, :, :] + + return templates + def get_template( self, unit_id, mode="average", sparsity=None, force_dense: bool = False, percentile: float | None = None ): - raise NotImplementedError - + # force_dense and sparsity are ignored + templates = self.get_all_templates(unit_ids=[unit_id], mode=mode, percentile=percentile) + return templates[0] -def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="SortingResult"): +def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="SortingResult", ): """ This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingResult or MockWaveformExtractor. From 2ecfe6bd7099c23501a4e70671584212ed39a681 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 19 Jan 2024 14:16:53 +0100 Subject: [PATCH 018/192] SortingResult : small and important details --- src/spikeinterface/core/sortingresult.py | 54 ++++++++++++++++-------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 252e41667e..1b5b33bd74 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -733,15 +733,22 @@ def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) ## extensions zone - def compute(self, extension_name, **params): + def compute(self, extension_name, save=True, **params): """ Compute one extension Parameters ---------- - extension_name + extension_name: str + The name of the extension. + For instance "waveforms", "templates", ... + save: bool, default True + It the extension can be saved then it is saved. + If not then the extension will only live in memory as long as the object is deleted. + save=False is convinient to try some parameters without changing an already saved extension. - **params + **params: + All other kwargs are transimited to extension.set_params() Returns ------- @@ -751,9 +758,9 @@ def compute(self, extension_name, **params): Examples -------- - >>> extension = sortres.compute("unit_location", **some_params) - >>> unit_location = extension.get_data() - + >>> extension = sortres.compute("waveforms", **some_params) + >>> wfs = extension.data["waveforms"] + """ @@ -767,8 +774,8 @@ def compute(self, extension_name, **params): assert ext is not None, f"Extension {extension_name} need {dependency_name} to be computed first" extension_instance = extension_class(self) - extension_instance.set_params(**params) - extension_instance.run() + extension_instance.set_params(save=save, **params) + extension_instance.run(save=save) self.extensions[extension_name] = extension_instance @@ -1032,10 +1039,23 @@ def function_factory(cls): class FuncWrapper: def __init__(self, extension_name): self.extension_name = extension_name - def __call__(self, sorting_result, *args, **kwargs): - return sorting_result.compute(self.extension_name, *args, **kwargs) + def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): + # backward compatibility with "load_if_exists" + if load_if_exists is not None: + warnings.warn(f"compute_{cls.extension_name}(..., load_if_exists=True/False) is kept for backward compatibility but should not be used anymore") + assert isinstance(load_if_exists, bool) + if load_if_exists: + ext = sorting_result.get_extension(self.extension_name) + return ext + + ext = sorting_result.compute(cls.extension_name, *args, **kwargs) + # TODO be discussed + return ext + # return ext.data + # return ext.get_data() + func = FuncWrapper(cls.extension_name) - # TODO : make docstring from class docstring + func.__doc__ = cls.__doc__ # TODO: add load_if_exists return func @@ -1153,14 +1173,14 @@ def copy(self, new_sorting_result, unit_ids=None): new_extension.save() return new_extension - def run(self, **kwargs): - if not self.sorting_result.is_read_only(): + def run(self, save=True, **kwargs): + if save and not self.sorting_result.is_read_only(): # this also reset the folder or zarr group self._save_params() self._run(**kwargs) - if not self.sorting_result.is_read_only(): + if save and not self.sorting_result.is_read_only(): self._save_data(**kwargs) def save(self, **kwargs): @@ -1260,7 +1280,7 @@ def reset(self): self.data = dict() - def set_params(self, **params): + def set_params(self, save=True, **params): """ Set parameters for the extension and make it persistent in json. @@ -1275,8 +1295,8 @@ def set_params(self, **params): if self.sorting_result.is_read_only(): return - - self._save_params() + if save: + self._save_params() def _save_params(self): params_to_save = self.params.copy() From 24acdbb2a20d926a9741fd5f08a4a1a958b4aaff Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 22 Jan 2024 12:53:19 +0100 Subject: [PATCH 019/192] Implement load_waveforms that mimic the old version for backwards compatibility. --- ...forms_extractor_backwards_compatibility.py | 12 +- ...forms_extractor_backwards_compatibility.py | 192 +++++++++++++++++- 2 files changed, 195 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index 1ddbaa525c..443d75c08a 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -7,6 +7,7 @@ from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core.waveforms_extractor_backwards_compatibility import extract_waveforms as mock_extract_waveforms +from spikeinterface.core.waveforms_extractor_backwards_compatibility import load_waveforms as load_waveforms_backwards # remove this when WaveformsExtractor will be removed from spikeinterface.core import extract_waveforms as old_extract_waveforms @@ -49,10 +50,10 @@ def test_extract_waveforms(): we_kwargs = dict(sparse=True, max_spikes_per_unit=30) + we_old = old_extract_waveforms(recording, sorting, folder=folder, **we_kwargs) print(we_old) - folder = cache_folder / "mock_waveforms_extractor" if folder.exists(): shutil.rmtree(folder) @@ -76,5 +77,14 @@ def test_extract_waveforms(): + # test reading old WaveformsExtractor folder + folder = cache_folder / "old_waveforms_extractor" + sorting_result_from_we = load_waveforms_backwards(folder, output="SortingResult") + print(sorting_result_from_we) + mock_loaded_we_old = load_waveforms_backwards(folder, output="MockWaveformExtractor") + print(mock_loaded_we_old) + + + if __name__ == "__main__": test_extract_waveforms() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 59867f5433..46555eee54 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -9,14 +9,20 @@ from pathlib import Path +import json + import numpy as np +import probeinterface from .baserecording import BaseRecording from .basesorting import BaseSorting from .sortingresult import start_sorting_result from .job_tools import split_job_kwargs - +from .sparsity import ChannelSparsity +from .sortingresult import SortingResult +from .base import load_extractor +from .result_core import ComputeWaveforms, ComputeTemplates _backwards_compatibility_msg = """#### # extract_waveforms() and WaveformExtractor() have been replace by SortingResult since version 0.101 @@ -288,18 +294,188 @@ def get_template( return templates[0] + def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="SortingResult", ): """ This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingResult or MockWaveformExtractor. """ - raise NotImplementedError + folder = Path(folder) + assert folder.is_dir(), "Waveform folder does not exists" + if folder.suffix == ".zarr": + raise NotImplementedError + # Alessio this is for you + else: + sorting_result = _read_old_waveforms_extractor_binary(folder) + + if output == "SortingResult": + return sorting_result + elif output in ("WaveformExtractor", "MockWaveformExtractor"): + return MockWaveformExtractor(sorting_result) + + + +def _read_old_waveforms_extractor_binary(folder): + params_file = folder / "params.json" + if not params_file.exists(): + raise ValueError(f"This folder is not a WaveformsExtractor folder {folder}") + with open(params_file, "r") as f: + params = json.load(f) + + sparsity_file = folder / "sparsity.json" + if params_file.exists(): + with open(sparsity_file, "r") as f: + sparsity_dict = json.load(f) + sparsity = ChannelSparsity.from_dict(sparsity_dict) + else: + sparsity = None + + # recording attributes + rec_attributes_file = folder / "recording_info" / "recording_attributes.json" + with open(rec_attributes_file, "r") as f: + rec_attributes = json.load(f) + probegroup_file = folder / "recording_info" / "probegroup.json" + if probegroup_file.is_file(): + rec_attributes["probegroup"] = probeinterface.read_probeinterface(probegroup_file) + else: + rec_attributes["probegroup"] = None + + # recording + recording = None + if (folder / "recording.json").exists(): + try: + recording = load_extractor(folder / "recording.json", base_folder=folder) + except: + pass + elif (folder / "recording.pickle").exists(): + try: + recording = load_extractor(folder / "recording.pickle", base_folder=folder) + except: + pass + + # sorting + if (folder / "sorting.json").exists(): + sorting = load_extractor(folder / "sorting.json", base_folder=folder) + elif (folder / "sorting.pickle").exists(): + sorting = load_extractor(folder / "sorting.pickle", base_folder=folder) + + sorting_result = SortingResult.create_memory(sorting, recording, sparsity, rec_attributes=rec_attributes) + + # waveforms + # need to concatenate all waveforms in one unique buffer + # need to concatenate sampled_index and order it + waveform_folder = folder / "waveforms" + if waveform_folder.exists(): + + spikes = sorting.to_spike_vector() + random_spike_mask = np.zeros(spikes.size, dtype="bool") + + all_sampled_indices = [] + # first readd all sampled_index to get the correct ordering + for unit_index, unit_id in enumerate(sorting.unit_ids): + # unit_indices has dtype=[("spike_index", "int64"), ("segment_index", "int64")] + unit_indices = np.load(waveform_folder / f"sampled_index_{unit_id}.npy") + for segment_index in range(sorting.get_num_segments()): + in_seg_selected = unit_indices[unit_indices["segment_index"] == segment_index]["spike_index"] + spikes_indices = np.flatnonzero((spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index)) + random_spike_mask[spikes_indices[in_seg_selected]] = True + random_spikes_indices = np.flatnonzero(random_spike_mask) + + num_spikes = random_spikes_indices.size + if sparsity is None: + max_num_channel = len(rec_attributes["channel_ids"]) + else: + max_num_channel = np.max(np.sum(sparsity.mask, axis=1)) + + nbefore = int(params["ms_before"] * sorting.sampling_frequency / 1000.0) + nafter = int(params["ms_after"] * sorting.sampling_frequency / 1000.0) + + waveforms = np.zeros((num_spikes, nbefore + nafter, max_num_channel), dtype=params["dtype"]) + # then read waveforms per units + some_spikes = spikes[random_spikes_indices] + for unit_index, unit_id in enumerate(sorting.unit_ids): + wfs = np.load(waveform_folder / f"waveforms_{unit_id}.npy") + mask = some_spikes["unit_index"] == unit_index + waveforms[:, :, :wfs.shape[1]][mask, :, :] = wfs + + sorting_result.random_spikes_indices = random_spikes_indices + + ext = ComputeWaveforms(sorting_result) + ext.params = params + ext.data["waveforms"] = waveforms + sorting_result.extensions["waveforms"] = ext + + # templates saved dense + # load cached templates + templates = {} + for mode in ("average", "std", "median", "percentile"): + template_file = folder / f"templates_{mode}.npy" + if template_file.is_file(): + templates [mode] = np.load(template_file) + if len(templates) > 0: + ext = ComputeTemplates(sorting_result) + ext.params = dict(operators=list(templates.keys())) + for mode, arr in templates.items(): + ext.data[mode] = arr + sorting_result.extensions["templates"] = ext + + + # TODO : implement this when extension will be prted in the new API + # old_extension_to_new_class : { + # old extensions with same names and equvalent data + # "spike_amplitudes": , + # "spike_locations": , + # "amplitude_scalings": , + # "template_metrics" : , + # "similarity": , + # "unit_locations": , + # "correlograms" : , + # isi_histograms: , + # "noise_levels": , + # "quality_metrics": , + # "principal_components" : , + # } + # for ext_name, new_class in old_extension_to_new_class.items(): + # ext_folder = folder / ext_name + # ext = new_class(sorting_result) + # with open(ext_folder / "params.json", "r") as f: + # params = json.load(f) + # ext.params = params + # if ext_name == "spike_amplitudes": + # amplitudes = [] + # for segment_index in range(sorting.get_num_segments()): + # amplitudes.append(np.load(ext_folder / f"amplitude_segment_{segment_index}.npy")) + # amplitudes = np.concatenate(amplitudes) + # ext.data["amplitudes"] = amplitudes + # elif ext_name == "spike_locations": + # ext.data["spike_locations"] = np.load(ext_folder / "spike_locations.npy") + # elif ext_name == "amplitude_scalings": + # ext.data["amplitude_scalings"] = np.load(ext_folder / "amplitude_scalings.npy") + # elif ext_name == "template_metrics": + # import pandas as pd + # ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) + # elif ext_name == "similarity": + # ext.data["similarity"] = np.load(ext_folder / "similarity.npy") + # elif ext_name == "unit_locations": + # ext.data["unit_locations"] = np.load(ext_folder / "unit_locations.npy") + # elif ext_name == "correlograms": + # ext.data["ccgs"] = np.load(ext_folder / "ccgs.npy") + # ext.data["bins"] = np.load(ext_folder / "bins.npy") + # elif ext_name == "isi_histograms": + # ext.data["isi_histograms"] = np.load(ext_folder / "isi_histograms.npy") + # ext.data["bins"] = np.load(ext_folder / "bins.npy") + # elif ext_name == "noise_levels": + # ext.data["noise_levels"] = np.load(ext_folder / "noise_levels.npy") + # elif ext_name == "quality_metrics": + # import pandas as pd + # ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) + # elif ext_name == "principal_components": + # # TODO: this is for you + # pass + + + + return sorting_result - # This will be something like this create a SortingResult in memory and copy/translate all data into the new structure. - # sorting_result = ... - # if output == "SortingResult": - # return sorting_result - # elif output in ("WaveformExtractor", "MockWaveformExtractor"): - # return MockWaveformExtractor(sorting_result) From 87daecfadf3e5db20383f45e2a66b97c4ac9ef07 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 23 Jan 2024 10:20:00 +0100 Subject: [PATCH 020/192] oups --- .../core/waveforms_extractor_backwards_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 46555eee54..93c27cded2 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -397,7 +397,7 @@ def _read_old_waveforms_extractor_binary(folder): for unit_index, unit_id in enumerate(sorting.unit_ids): wfs = np.load(waveform_folder / f"waveforms_{unit_id}.npy") mask = some_spikes["unit_index"] == unit_index - waveforms[:, :, :wfs.shape[1]][mask, :, :] = wfs + waveforms[:, :, :wfs.shape[2]][mask, :, :] = wfs sorting_result.random_spikes_indices = random_spikes_indices From c641275c41ccb99c35d7e80de59a9719df58d329 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 23 Jan 2024 19:01:03 +0100 Subject: [PATCH 021/192] Port spike_amplitudes to SortingResult --- src/spikeinterface/core/result_core.py | 40 ++- src/spikeinterface/core/sortingresult.py | 8 +- src/spikeinterface/core/template_tools.py | 50 +++- .../core/tests/test_result_core.py | 10 +- src/spikeinterface/postprocessing/__init__.py | 2 +- .../postprocessing/spike_amplitudes.py | 244 +++++++++--------- .../tests/common_extension_tests.py | 100 ++++++- .../tests/test_spike_amplitudes.py | 66 ++--- 8 files changed, 339 insertions(+), 181 deletions(-) diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 17174c84ac..6cb93b21dc 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -23,6 +23,14 @@ class ComputeWaveforms(ResultExtension): need_recording = True use_nodepipeline = False + @property + def nbefore(self): + return int(self.params["ms_before"] * self.sorting_result.sampling_frequency / 1000.0) + + @property + def nafter(self): + return int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) + def _run(self, **kwargs): self.data.clear() @@ -37,9 +45,6 @@ def _run(self, **kwargs): spikes = sorting.to_spike_vector() some_spikes = spikes[self.sorting_result.random_spikes_indices] - nbefore = int(self.params["ms_before"] * sorting.sampling_frequency / 1000.0) - nafter = int(self.params["ms_after"] * sorting.sampling_frequency / 1000.0) - if self.format == "binary_folder": # in that case waveforms are extacted directly in files file_path = self._get_binary_extension_folder() / "waveforms.npy" @@ -62,8 +67,8 @@ def _run(self, **kwargs): recording, some_spikes, unit_ids, - nbefore, - nafter, + self.nbefore, + self.nafter, mode=mode, return_scaled=self.params["return_scaled"], file_path=file_path, @@ -195,9 +200,19 @@ def _set_params(self, operators = ["average", "std"]): assert len(operator) == 2 assert operator[0] == "percentile" - params = dict(operators=operators) + waveforms_extension = self.sorting_result.get_extension("waveforms") + + params = dict(operators=operators, nbefore=waveforms_extension.nbefore, nafter=waveforms_extension.nafter) return params + @property + def nbefore(self): + return self.params["nbefore"] + + @property + def nafter(self): + return self.params["nafter"] + def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) @@ -221,6 +236,14 @@ class ComputeFastTemplates(ResultExtension): need_recording = True use_nodepipeline = False + @property + def nbefore(self): + return int(self.params["ms_before"] * self.sorting_result.sampling_frequency / 1000.0) + + @property + def nafter(self): + return int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) + def _run(self, **kwargs): self.data.clear() @@ -235,13 +258,10 @@ def _run(self, **kwargs): spikes = sorting.to_spike_vector() some_spikes = spikes[self.sorting_result.random_spikes_indices] - nbefore = int(self.params["ms_before"] * sorting.sampling_frequency / 1000.0) - nafter = int(self.params["ms_after"] * sorting.sampling_frequency / 1000.0) - return_scaled = self.params["return_scaled"] # TODO jobw_kwargs - self.data["average"] = estimate_templates(recording, some_spikes, unit_ids, nbefore, nafter, return_scaled=return_scaled) + self.data["average"] = estimate_templates(recording, some_spikes, unit_ids, self.nbefore, self.nafter, return_scaled=return_scaled) def _set_params(self, ms_before: float = 1.0, diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 1b5b33bd74..0911b5fbf6 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -770,8 +770,12 @@ def compute(self, extension_name, save=True, **params): if extension_class.need_recording: assert self.has_recording(), f"Extension {extension_name} need the recording" for dependency_name in extension_class.depend_on: - ext = self.get_extension(dependency_name) - assert ext is not None, f"Extension {extension_name} need {dependency_name} to be computed first" + if "|" in dependency_name: + # at least one extension must be done : usefull for "templates|fast_templates" for instance + ok = any(self.get_extension(name) is not None for name in dependency_name.split("|")) + else: + ok = self.get_extension(dependency_name) is not None + assert ok, f"Extension {extension_name} need {dependency_name} to be computed first" extension_instance = extension_class(self) extension_instance.set_params(save=save, **params) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 792d1f8b63..09022d0b34 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -6,16 +6,46 @@ from .waveform_extractor import WaveformExtractor from .sparsity import compute_sparsity, _sparsity_doc from .recording_tools import get_channel_distances, get_noise_levels +from .sortingresult import SortingResult + + +def _get_dense_templates_array(one_object): + if isinstance(one_object, Templates): + templates_array = one_object.get_dense_templates() + elif isinstance(one_object, WaveformExtractor): + templates_array = one_object.get_all_templates(mode="average") + elif isinstance(one_object, SortingResult): + ext = one_object.get_extension("templates") + if ext is not None: + templates_array = ext.data["average"] + else: + ext = one_object.get_extension("fast_templates") + if ext is not None: + templates_array = ext.data["average"] + else: + raise ValueError("SortingResult need extension 'templates' or 'fast_templates' to be computed") + else: + raise ValueError("Input should be Templates or WaveformExtractor or SortingResult") + return templates_array -def _get_dense_templates_array(templates_or_waveform_extractor): - if isinstance(templates_or_waveform_extractor, Templates): - templates_array = templates_or_waveform_extractor.get_dense_templates() - elif isinstance(templates_or_waveform_extractor, WaveformExtractor): - templates_array = templates_or_waveform_extractor.get_all_templates(mode="average") +def _get_nbefore(one_object): + if isinstance(one_object, Templates): + return one_object.nbefore + elif isinstance(one_object, WaveformExtractor): + return one_object.nbefore + elif isinstance(one_object, SortingResult): + ext = one_object.get_extension("templates") + if ext is not None: + return ext.nbefore + ext = one_object.get_extension("fast_templates") + if ext is not None: + return ext.nbefore + raise ValueError("SortingResult need extension 'templates' or 'fast_templates' to be computed") else: - raise ValueError("templates_or_waveform_extractor should be Templates or WaveformExtractor") - return templates_array + raise ValueError("Input should be Templates or WaveformExtractor or SortingResult") + + def get_template_amplitudes( @@ -44,7 +74,7 @@ def get_template_amplitudes( unit_ids = templates_or_waveform_extractor.unit_ids - before = templates_or_waveform_extractor.nbefore + before = _get_nbefore(templates_or_waveform_extractor) peak_values = {} @@ -201,7 +231,7 @@ def get_template_extremum_channel_peak_shift(templates_or_waveform_extractor, pe """ unit_ids = templates_or_waveform_extractor.unit_ids channel_ids = templates_or_waveform_extractor.channel_ids - nbefore = templates_or_waveform_extractor.nbefore + nbefore = _get_nbefore(templates_or_waveform_extractor) extremum_channels_ids = get_template_extremum_channel(templates_or_waveform_extractor, peak_sign=peak_sign) @@ -254,8 +284,6 @@ def get_template_extremum_amplitude( unit_ids = templates_or_waveform_extractor.unit_ids channel_ids = templates_or_waveform_extractor.channel_ids - before = templates_or_waveform_extractor.nbefore - extremum_channels_ids = get_template_extremum_channel(templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode) extremum_amplitudes = get_template_amplitudes(templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode) diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index 949663294f..855099ce0e 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -156,9 +156,9 @@ def test_ComputeFastTemplates(format, sparse): # test_ComputeWaveforms(format="zarr", sparse=True) # test_ComputeWaveforms(format="zarr", sparse=False) - # test_ComputeTemplates(format="memory", sparse=True) - # test_ComputeTemplates(format="memory", sparse=False) - # test_ComputeTemplates(format="binary_folder", sparse=True) - # test_ComputeTemplates(format="zarr", sparse=True) + test_ComputeTemplates(format="memory", sparse=True) + test_ComputeTemplates(format="memory", sparse=False) + test_ComputeTemplates(format="binary_folder", sparse=True) + test_ComputeTemplates(format="zarr", sparse=True) - test_ComputeFastTemplates(format="memory", sparse=True) + # test_ComputeFastTemplates(format="memory", sparse=True) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 3aebd13797..762de382b5 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -15,7 +15,7 @@ compute_principal_components, ) -from .spike_amplitudes import compute_spike_amplitudes, SpikeAmplitudesCalculator +from .spike_amplitudes import compute_spike_amplitudes, ComputeSpikeAmplitudes from .correlograms import ( CorrelogramsCalculator, diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 795b3cae7d..58a6390ef7 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -5,18 +5,45 @@ from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift -from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +# from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension -class SpikeAmplitudesCalculator(BaseWaveformExtractorExtension): - """ - Computes spike amplitudes from WaveformExtractor. + +class ComputeSpikeAmplitudes(ResultExtension): """ + ResultExtension + Computes the spike amplitudes. + + Need "templates" or "fast_templates" to be computed first. + + 1. Determine the max channel per unit. + 2. Then a "peak_shift" is estimated because for some sorters the spike index is not always at the + peak. + 3. Amplitudes are extracted in chunks (parallel or not) + Parameters + ---------- + sorting_result: SortingResult + The waveform extractor object + load_if_exists : bool, default: False + Whether to load precomputed spike amplitudes, if they already exist. + peak_sign: "neg" | "pos" | "both", default: "neg + The sign to compute maximum channel + return_scaled: bool + If True and recording has gain_to_uV/offset_to_uV properties, amplitudes are converted to uV. + outputs: "concatenated" | "by_unit", default: "concatenated" + How the output should be returned + {} + """ extension_name = "spike_amplitudes" + depend_on = ["fast_templates|templates", ] + need_recording = True + # TODO: implement this as a pipeline + use_nodepipeline = False - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) self._all_spikes = None @@ -25,37 +52,34 @@ def _set_params(self, peak_sign="neg", return_scaled=True): return params def _select_extension_data(self, unit_ids): - # load filter and save amplitude files - sorting = self.waveform_extractor.sorting - spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids)) - - new_extension_data = dict() - for seg_index in range(sorting.get_num_segments()): - amp_data_name = f"amplitude_segment_{seg_index}" - amps = self._extension_data[amp_data_name] - filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices) - new_extension_data[amp_data_name] = amps[filtered_idxs] - return new_extension_data + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + + spikes = self.sorting_result.sorting.to_spike_vector() + keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + + new_data = dict() + new_data["amplitudes"] = self.data["amplitudes"][keep_spike_mask] + + return new_data def _run(self, **job_kwargs): - if not self.waveform_extractor.has_recording(): - self.waveform_extractor.delete_extension(SpikeAmplitudesCalculator.extension_name) - raise ValueError("compute_spike_amplitudes() cannot run with a WaveformExtractor in recordless mode.") + if not self.sorting_result.has_recording(): + self.sorting_result.delete_extension(ComputeSpikeAmplitudes.extension_name) + raise ValueError("compute_spike_amplitudes() cannot run with a SortingResult in recordless mode.") job_kwargs = fix_job_kwargs(job_kwargs) - we = self.waveform_extractor - recording = we.recording - sorting = we.sorting + sorting_result = self.sorting_result + recording = sorting_result.recording + sorting = sorting_result.sorting all_spikes = sorting.to_spike_vector() self._all_spikes = all_spikes - peak_sign = self._params["peak_sign"] - return_scaled = self._params["return_scaled"] + peak_sign = self.params["peak_sign"] + return_scaled = self.params["return_scaled"] - extremum_channels_index = get_template_extremum_channel(we, peak_sign=peak_sign, outputs="index") - peak_shifts = get_template_extremum_channel_peak_shift(we, peak_sign=peak_sign) + extremum_channels_index = get_template_extremum_channel(sorting_result, peak_sign=peak_sign, outputs="index") + peak_shifts = get_template_extremum_channel_peak_shift(sorting_result, peak_sign=peak_sign) # put extremum_channels_index and peak_shifts in vector way extremum_channels_index = np.array( @@ -77,104 +101,81 @@ def _run(self, **job_kwargs): processor = ChunkRecordingExecutor( recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs ) - out = processor.run() - amps, segments = zip(*out) + # out = processor.run() + # amps, segments = zip(*out) + # amps = np.concatenate(amps) + # segments = np.concatenate(segments) + + # for segment_index in range(recording.get_num_segments()): + # mask = segments == segment_index + # amps_seg = amps[mask] + # self._extension_data[f"amplitude_segment_{segment_index}"] = amps_seg + amps = processor.run() amps = np.concatenate(amps) - segments = np.concatenate(segments) - - for segment_index in range(recording.get_num_segments()): - mask = segments == segment_index - amps_seg = amps[mask] - self._extension_data[f"amplitude_segment_{segment_index}"] = amps_seg - - def get_data(self, outputs="concatenated"): - """ - Get computed spike amplitudes. - - Parameters - ---------- - outputs : "concatenated" | "by_unit", default: "concatenated" - The output format - - Returns - ------- - spike_amplitudes : np.array or dict - The spike amplitudes as an array (outputs="concatenated") or - as a dict with units as key and spike amplitudes as values. - """ - we = self.waveform_extractor - sorting = we.sorting - - if outputs == "concatenated": - amplitudes = [] - for segment_index in range(we.get_num_segments()): - amplitudes.append(self._extension_data[f"amplitude_segment_{segment_index}"]) - return amplitudes - elif outputs == "by_unit": - all_spikes = sorting.to_spike_vector(concatenated=False) - - amplitudes_by_unit = [] - for segment_index in range(we.get_num_segments()): - amplitudes_by_unit.append({}) - for unit_index, unit_id in enumerate(sorting.unit_ids): - spike_labels = all_spikes[segment_index]["unit_index"] - mask = spike_labels == unit_index - amps = self._extension_data[f"amplitude_segment_{segment_index}"][mask] - amplitudes_by_unit[segment_index][unit_id] = amps - return amplitudes_by_unit - - @staticmethod - def get_extension_function(): - return compute_spike_amplitudes - - -WaveformExtractor.register_extension(SpikeAmplitudesCalculator) - - -def compute_spike_amplitudes( - waveform_extractor, load_if_exists=False, peak_sign="neg", return_scaled=True, outputs="concatenated", **job_kwargs -): - """ - Computes the spike amplitudes from a WaveformExtractor. + self.data["amplitudes"] = amps - 1. The waveform extractor is used to determine the max channel per unit. - 2. Then a "peak_shift" is estimated because for some sorters the spike index is not always at the - peak. - 3. Amplitudes are extracted in chunks (parallel or not) + # def get_data(self, outputs="concatenated"): + # """ + # Get computed spike amplitudes. - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor object - load_if_exists : bool, default: False - Whether to load precomputed spike amplitudes, if they already exist. - peak_sign: "neg" | "pos" | "both", default: "neg - The sign to compute maximum channel - return_scaled: bool - If True and recording has gain_to_uV/offset_to_uV properties, amplitudes are converted to uV. - outputs: "concatenated" | "by_unit", default: "concatenated" - How the output should be returned - {} + # Parameters + # ---------- + # outputs : "concatenated" | "by_unit", default: "concatenated" + # The output format - Returns - ------- - amplitudes: np.array or list of dict - The spike amplitudes. - - If "concatenated" all amplitudes for all spikes and all units are concatenated - - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) - """ - if load_if_exists and waveform_extractor.has_extension(SpikeAmplitudesCalculator.extension_name): - sac = waveform_extractor.load_extension(SpikeAmplitudesCalculator.extension_name) - else: - sac = SpikeAmplitudesCalculator(waveform_extractor) - sac.set_params(peak_sign=peak_sign, return_scaled=return_scaled) - sac.run(**job_kwargs) + # Returns + # ------- + # spike_amplitudes : np.array or dict + # The spike amplitudes as an array (outputs="concatenated") or + # as a dict with units as key and spike amplitudes as values. + # """ + # sorting_result = self.sorting_result + # sorting = sorting_result.sorting + + # if outputs == "concatenated": + # amplitudes = [] + # for segment_index in range(sorting_result.get_num_segments()): + # amplitudes.append(self._extension_data[f"amplitude_segment_{segment_index}"]) + # return amplitudes + # elif outputs == "by_unit": + # all_spikes = sorting.to_spike_vector(concatenated=False) + + # amplitudes_by_unit = [] + # for segment_index in range(sorting_result.get_num_segments()): + # amplitudes_by_unit.append({}) + # for unit_index, unit_id in enumerate(sorting.unit_ids): + # spike_labels = all_spikes[segment_index]["unit_index"] + # mask = spike_labels == unit_index + # amps = self._extension_data[f"amplitude_segment_{segment_index}"][mask] + # amplitudes_by_unit[segment_index][unit_id] = amps + # return amplitudes_by_unit + + # @staticmethod + # def get_extension_function(): + # return compute_spike_amplitudes + + +# WaveformExtractor.register_extension(SpikeAmplitudesCalculator) +register_result_extension(ComputeSpikeAmplitudes) + +compute_spike_amplitudes = ComputeSpikeAmplitudes.function_factory() + +# def compute_spike_amplitudes( +# sorting_result, load_if_exists=False, peak_sign="neg", return_scaled=True, outputs="concatenated", **job_kwargs +# ): + +# if load_if_exists and sorting_result.has_extension(SpikeAmplitudesCalculator.extension_name): +# sac = sorting_result.load_extension(SpikeAmplitudesCalculator.extension_name) +# else: +# sac = SpikeAmplitudesCalculator(sorting_result) +# sac.set_params(peak_sign=peak_sign, return_scaled=return_scaled) +# sac.run(**job_kwargs) - amps = sac.get_data(outputs=outputs) - return amps +# amps = sac.get_data(outputs=outputs) +# return amps -compute_spike_amplitudes.__doc__.format(_shared_job_kwargs_doc) +# compute_spike_amplitudes.__doc__.format(_shared_job_kwargs_doc) def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index, peak_shifts, return_scaled): @@ -241,6 +242,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): # and get amplitudes amplitudes = traces[sample_inds, chan_inds] - segments = np.zeros(amplitudes.size, dtype="int64") + segment_index + # segments = np.zeros(amplitudes.size, dtype="int64") + segment_index - return amplitudes, segments + # return amplitudes, segments + return amplitudes diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 729aff3a4c..e6bd0e722c 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -5,14 +5,110 @@ import platform from pathlib import Path -from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity -from spikeinterface.core.generate import generate_ground_truth_recording +# from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity +# from spikeinterface.core.generate import generate_ground_truth_recording + +from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.core import start_sorting_result +from spikeinterface.core import estimate_sparsity + if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "postprocessing" else: cache_folder = Path("cache_folder") / "postprocessing" +def get_dataset(): + recording, sorting = generate_ground_truth_recording( + durations=[30.0, 20.0], sampling_frequency=24000.0, num_channels=10, num_units=5, + generate_sorting_kwargs=dict(firing_rates=3.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params_range=dict( + alpha=(9_000.0, 12_000.0), + ) + ), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + return recording, sorting + +def get_sorting_result(recording, sorting, format="memory", sparsity=None, name=""): + sparse = sparsity is not None + if format == "memory": + folder = None + elif format == "binary_folder": + folder = cache_folder / f"test_{name}_sparse{sparse}_{format}" + elif format == "zarr": + folder = cache_folder / f"test_{name}_sparse{sparse}_{format}.zarr" + if folder and folder.exists(): + shutil.rmtree(folder) + + sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity) + + return sortres + +class ResultExtensionCommonTestSuite: + """ + Common tests with class approach to compute extension on several cases (3 format x 2 sparsity) + + This automatically precompute extension dependencies with default params before running computation. + + This also test the select_units() ability. + """ + extension_class = None + extension_function_kwargs_list = None + def setUp(self): + + recording, sorting = get_dataset() + # sparsity is computed once for all cases to save processing + sparsity = estimate_sparsity(recording, sorting) + + self.sorting_results = {} + for sparse in (True, False): + for format in ("memory", "binary_folder", "zarr"): + sparsity_ = sparsity if sparse else None + sorting_result = get_sorting_result(recording, sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name) + key = f"spare{sparse}_{format}" + self.sorting_results[key] = sorting_result + + @property + def extension_name(self): + return self.extension_class.extension_name + + def _check_one(self, sorting_result): + sorting_result.select_random_spikes(max_spikes_per_unit=50, seed=2205) + + for dependency_name in self.extension_class.depend_on: + if "|" in dependency_name: + dependency_name = dependency_name.split("|")[0] + sorting_result.compute(dependency_name) + + + for kwargs in self.extension_function_kwargs_list: + sorting_result.compute(self.extension_name, **kwargs) + ext = sorting_result.get_extension(self.extension_name) + assert ext is not None + assert len(ext.data) > 0 + + some_unit_ids = sorting_result.unit_ids[::2] + sliced = sorting_result.select_units(some_unit_ids, format="memory") + assert np.array_equal(sliced.unit_ids, sorting_result.unit_ids[::2]) + # print(sliced) + + + def test_extension(self): + + for key, sorting_result in self.sorting_results.items(): + print() + print(key) + self._check_one(sorting_result) + + class WaveformExtensionCommonTestSuite: """ diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index d96598691e..bf03ec00ce 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -1,44 +1,52 @@ import unittest import numpy as np -from spikeinterface.postprocessing import SpikeAmplitudesCalculator +from spikeinterface.postprocessing import ComputeSpikeAmplitudes -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +# from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import get_sorting_result, ResultExtensionCommonTestSuite -class SpikeAmplitudesExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = SpikeAmplitudesCalculator - extension_data_names = ["amplitude_segment_0"] + +class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeSpikeAmplitudes extension_function_kwargs_list = [ - dict(peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1), - dict(peak_sign="neg", outputs="by_unit", chunk_size=10000, n_jobs=1), + dict(), ] - def test_scaled(self): - amplitudes_scaled = self.extension_class.get_extension_function()( - self.we1, peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1, return_scaled=True - ) - amplitudes_unscaled = self.extension_class.get_extension_function()( - self.we1, peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1, return_scaled=False - ) - gain = self.we1.recording.get_channel_gains()[0] - - assert np.allclose(amplitudes_scaled[0], amplitudes_unscaled[0] * gain) - - def test_parallel(self): - amplitudes1 = self.extension_class.get_extension_function()( - self.we1, peak_sign="neg", load_if_exists=False, outputs="concatenated", chunk_size=10000, n_jobs=1 - ) - # TODO : fix multi processing for spike amplitudes!!!!!!! - amplitudes2 = self.extension_class.get_extension_function()( - self.we1, peak_sign="neg", load_if_exists=False, outputs="concatenated", chunk_size=10000, n_jobs=2 - ) - - assert np.array_equal(amplitudes1[0], amplitudes2[0]) +# class SpikeAmplitudesExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): +# extension_class = SpikeAmplitudesCalculator +# extension_data_names = ["amplitude_segment_0"] +# extension_function_kwargs_list = [ +# dict(peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1), +# dict(peak_sign="neg", outputs="by_unit", chunk_size=10000, n_jobs=1), +# ] + +# def test_scaled(self): +# amplitudes_scaled = self.extension_class.get_extension_function()( +# self.we1, peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1, return_scaled=True +# ) +# amplitudes_unscaled = self.extension_class.get_extension_function()( +# self.we1, peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1, return_scaled=False +# ) +# gain = self.we1.recording.get_channel_gains()[0] + +# assert np.allclose(amplitudes_scaled[0], amplitudes_unscaled[0] * gain) + +# def test_parallel(self): +# amplitudes1 = self.extension_class.get_extension_function()( +# self.we1, peak_sign="neg", load_if_exists=False, outputs="concatenated", chunk_size=10000, n_jobs=1 +# ) +# # TODO : fix multi processing for spike amplitudes!!!!!!! +# amplitudes2 = self.extension_class.get_extension_function()( +# self.we1, peak_sign="neg", load_if_exists=False, outputs="concatenated", chunk_size=10000, n_jobs=2 +# ) + +# assert np.array_equal(amplitudes1[0], amplitudes2[0]) if __name__ == "__main__": - test = SpikeAmplitudesExtensionTest() + test = ComputeSpikeAmplitudesTest() test.setUp() test.test_extension() # test.test_scaled() From c5f67f769fc8d7dd4fa3332967e3d5dbc42f05e7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 23 Jan 2024 20:24:13 +0100 Subject: [PATCH 022/192] Port "correlograms" to SortingResult --- src/spikeinterface/core/sortingresult.py | 2 + src/spikeinterface/postprocessing/__init__.py | 4 +- .../postprocessing/correlograms.py | 172 ++++++++++-------- .../tests/common_extension_tests.py | 3 +- .../postprocessing/tests/test_correlograms.py | 58 +++--- .../tests/test_spike_amplitudes.py | 2 +- 6 files changed, 125 insertions(+), 116 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 0911b5fbf6..94b605d663 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -547,6 +547,7 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> if unit_ids is not None: # when only some unit_ids then the sorting must be sliced + # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! sorting_provenance = sorting_provenance.select_units(unit_ids) if format == "memory": @@ -614,6 +615,7 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult we : WaveformExtractor The newly create waveform extractor with the selected units """ + # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! return self._save_or_select(format=format, folder=folder, unit_ids=unit_ids) def copy(self): diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 762de382b5..1269376ac6 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -18,10 +18,10 @@ from .spike_amplitudes import compute_spike_amplitudes, ComputeSpikeAmplitudes from .correlograms import ( - CorrelogramsCalculator, + ComputeCorrelograms, + compute_correlograms, compute_autocorrelogram_from_spiketrain, compute_crosscorrelogram_from_spiketrain, - compute_correlograms, correlogram_for_one_segment, compute_correlograms_numba, compute_correlograms_numpy, diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 1b82548c15..8b76f0e8b2 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -1,8 +1,7 @@ import math import warnings import numpy as np -from ..core import WaveformExtractor -from ..core.waveform_extractor import BaseWaveformExtractorExtension +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension try: import numba @@ -12,19 +11,39 @@ HAVE_NUMBA = False -class CorrelogramsCalculator(BaseWaveformExtractorExtension): - """Compute correlograms of spike trains. +class ComputeCorrelograms(ResultExtension): + """ + Compute auto and cross correlograms. Parameters ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object - """ + sorting_result: SortingResult + A SortingResult object + window_ms : float, default: 100.0 + The window in ms + bin_ms : float, default: 5 + The bin size in ms + method : "auto" | "numpy" | "numba", default: "auto" + If "auto" and numba is installed, numba is used, otherwise numpy is used + Returns + ------- + ccgs : np.array + Correlograms with shape (num_units, num_units, num_bins) + The diagonal of ccgs is the auto correlogram. + ccgs[A, B, :] is the symetrie of ccgs[B, A, :] + ccgs[A, B, :] have to be read as the histogram of spiketimesA - spiketimesB + bins : np.array + The bin edges in ms + + """ extension_name = "correlograms" + depend_on = [] + need_recording = False + use_nodepipeline = False - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) @@ -33,38 +52,35 @@ def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str def _select_extension_data(self, unit_ids): # filter metrics dataframe - unit_indices = self.waveform_extractor.sorting.ids_to_indices(unit_ids) - new_ccgs = self._extension_data["ccgs"][unit_indices][:, unit_indices] - new_bins = self._extension_data["bins"] - new_extension_data = dict(ccgs=new_ccgs, bins=new_bins) - return new_extension_data + unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + new_ccgs = self.data["ccgs"][unit_indices][:, unit_indices] + new_bins = self.data["bins"] + new_data = dict(ccgs=new_ccgs, bins=new_bins) + return new_data def _run(self): - ccgs, bins = _compute_correlograms(self.waveform_extractor.sorting, **self._params) - self._extension_data["ccgs"] = ccgs - self._extension_data["bins"] = bins + ccgs, bins = _compute_correlograms(self.sorting_result.sorting, **self.params) + self.data["ccgs"] = ccgs + self.data["bins"] = bins - def get_data(self): - """ - Get the computed ISI histograms. + # def get_data(self): + # """ + # Get the computed ISI histograms. - Returns - ------- - isi_histograms : np.array - 2D array with ISI histograms (num_units, num_bins) - bins : np.array - 1D array with bins in ms - """ - msg = "Crosscorrelograms are not computed. Use the 'run()' function." - assert self._extension_data["ccgs"] is not None and self._extension_data["bins"] is not None, msg - return self._extension_data["ccgs"], self._extension_data["bins"] + # Returns + # ------- + # isi_histograms : np.array + # 2D array with ISI histograms (num_units, num_bins) + # bins : np.array + # 1D array with bins in ms + # """ + # msg = "Crosscorrelograms are not computed. Use the 'run()' function." + # assert self.data["ccgs"] is not None and self.data["bins"] is not None, msg + # return self.data["ccgs"], self.data["bins"] - @staticmethod - def get_extension_function(): - return compute_correlograms - -WaveformExtractor.register_extension(CorrelogramsCalculator) +register_result_extension(ComputeCorrelograms) +compute_correlograms = ComputeCorrelograms.function_factory() def _make_bins(sorting, window_ms, bin_ms): @@ -134,49 +150,49 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ return _compute_crosscorr_numba(spike_times1.astype(np.int64), spike_times2.astype(np.int64), window_size, bin_size) -def compute_correlograms( - waveform_or_sorting_extractor, - load_if_exists=False, - window_ms: float = 50.0, - bin_ms: float = 1.0, - method: str = "auto", -): - """Compute auto and cross correlograms. - - Parameters - ---------- - waveform_or_sorting_extractor : WaveformExtractor or BaseSorting - If WaveformExtractor, the correlograms are saved as WaveformExtensions - load_if_exists : bool, default: False - Whether to load precomputed crosscorrelograms, if they already exist - window_ms : float, default: 100.0 - The window in ms - bin_ms : float, default: 5 - The bin size in ms - method : "auto" | "numpy" | "numba", default: "auto" - If "auto" and numba is installed, numba is used, otherwise numpy is used - - Returns - ------- - ccgs : np.array - Correlograms with shape (num_units, num_units, num_bins) - The diagonal of ccgs is the auto correlogram. - ccgs[A, B, :] is the symetrie of ccgs[B, A, :] - ccgs[A, B, :] have to be read as the histogram of spiketimesA - spiketimesB - bins : np.array - The bin edges in ms - """ - if isinstance(waveform_or_sorting_extractor, WaveformExtractor): - if load_if_exists and waveform_or_sorting_extractor.is_extension(CorrelogramsCalculator.extension_name): - ccc = waveform_or_sorting_extractor.load_extension(CorrelogramsCalculator.extension_name) - else: - ccc = CorrelogramsCalculator(waveform_or_sorting_extractor) - ccc.set_params(window_ms=window_ms, bin_ms=bin_ms, method=method) - ccc.run() - ccgs, bins = ccc.get_data() - return ccgs, bins - else: - return _compute_correlograms(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) +# def compute_correlograms( +# waveform_or_sorting_extractor, +# load_if_exists=False, +# window_ms: float = 50.0, +# bin_ms: float = 1.0, +# method: str = "auto", +# ): +# """Compute auto and cross correlograms. + +# Parameters +# ---------- +# waveform_or_sorting_extractor : WaveformExtractor or BaseSorting +# If WaveformExtractor, the correlograms are saved as WaveformExtensions +# load_if_exists : bool, default: False +# Whether to load precomputed crosscorrelograms, if they already exist +# window_ms : float, default: 100.0 +# The window in ms +# bin_ms : float, default: 5 +# The bin size in ms +# method : "auto" | "numpy" | "numba", default: "auto" +# If "auto" and numba is installed, numba is used, otherwise numpy is used + +# Returns +# ------- +# ccgs : np.array +# Correlograms with shape (num_units, num_units, num_bins) +# The diagonal of ccgs is the auto correlogram. +# ccgs[A, B, :] is the symetrie of ccgs[B, A, :] +# ccgs[A, B, :] have to be read as the histogram of spiketimesA - spiketimesB +# bins : np.array +# The bin edges in ms +# """ +# if isinstance(waveform_or_sorting_extractor, WaveformExtractor): +# if load_if_exists and waveform_or_sorting_extractor.is_extension(CorrelogramsCalculator.extension_name): +# ccc = waveform_or_sorting_extractor.load_extension(CorrelogramsCalculator.extension_name) +# else: +# ccc = CorrelogramsCalculator(waveform_or_sorting_extractor) +# ccc.set_params(window_ms=window_ms, bin_ms=bin_ms, method=method) +# ccc.run() +# ccgs, bins = ccc.get_data() +# return ccgs, bins +# else: +# return _compute_correlograms(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) def _compute_correlograms(sorting, window_ms, bin_ms, method="auto"): diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index e6bd0e722c..32f6c11017 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -90,6 +90,7 @@ def _check_one(self, sorting_result): for kwargs in self.extension_function_kwargs_list: + print(' kwargs', kwargs) sorting_result.compute(self.extension_name, **kwargs) ext = sorting_result.get_extension(self.extension_name) assert ext is not None @@ -105,7 +106,7 @@ def test_extension(self): for key, sorting_result in self.sorting_results.items(): print() - print(key) + print(self.extension_name, key) self._check_one(sorting_result) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 3d562ba5a0..1d6cb24826 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -2,39 +2,30 @@ import numpy as np from typing import List - -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite, cache_folder - -from spikeinterface import download_dataset, extract_waveforms, NumpySorting -import spikeinterface.extractors as se - -from spikeinterface.postprocessing import compute_correlograms, CorrelogramsCalculator -from spikeinterface.postprocessing.correlograms import _make_bins -from spikeinterface.core import generate_sorting - - try: import numba - HAVE_NUMBA = True except ModuleNotFoundError as err: HAVE_NUMBA = False -class CorrelogramsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = CorrelogramsCalculator - extension_data_names = ["ccgs", "bins"] - extension_function_kwargs_list = [dict(method="numpy")] +from spikeinterface import NumpySorting, generate_sorting +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputeCorrelograms +from spikeinterface.postprocessing.correlograms import _compute_correlograms, _make_bins + + - def test_compute_correlograms(self): - methods = ["numpy", "auto"] - if HAVE_NUMBA: - methods.append("numba") +class ComputeCorrelogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeCorrelograms + extension_function_kwargs_list = [ + dict(method="numpy"), + dict(method="auto"), + ] - sorting = self.we1.sorting +if HAVE_NUMBA: + ComputeCorrelogramsTest.extension_function_kwargs_list.append(dict(method="numba")) - _test_correlograms(sorting, window_ms=60.0, bin_ms=2.0, methods=methods) - _test_correlograms(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) def test_make_bins(): @@ -55,7 +46,7 @@ def test_make_bins(): def _test_correlograms(sorting, window_ms, bin_ms, methods): for method in methods: - correlograms, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + correlograms, bins = _compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) if method == "numpy": ref_correlograms = correlograms ref_bins = bins @@ -99,7 +90,7 @@ def test_flat_cross_correlogram(): # ~ fig, ax = plt.subplots() for method in methods: - correlograms, bins = compute_correlograms(sorting, window_ms=50.0, bin_ms=1.0, method=method) + correlograms, bins = _compute_correlograms(sorting, window_ms=50.0, bin_ms=1.0, method=method) cc = correlograms[0, 1, :].copy() m = np.mean(cc) assert np.all(cc > (m * 0.90)) @@ -131,7 +122,7 @@ def test_auto_equal_cross_correlograms(): sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=10000.0) for method in methods: - correlograms, bins = compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) + correlograms, bins = _compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) num_half_bins = correlograms.shape[2] // 2 @@ -181,7 +172,7 @@ def test_detect_injected_correlation(): sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=sampling_frequency) for method in methods: - correlograms, bins = compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) + correlograms, bins = _compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) cc_01 = correlograms[0, 1, :] cc_10 = correlograms[1, 0, :] @@ -204,13 +195,12 @@ def test_detect_injected_correlation(): if __name__ == "__main__": - test_make_bins() - test_equal_results_correlograms() - test_flat_cross_correlogram() - test_auto_equal_cross_correlograms() - test_detect_injected_correlation() + # test_make_bins() + # test_equal_results_correlograms() + # test_flat_cross_correlogram() + # test_auto_equal_cross_correlograms() + # test_detect_injected_correlation() - test = CorrelogramsExtensionTest() + test = ComputeCorrelogramsTest() test.setUp() - test.test_compute_correlograms() test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index bf03ec00ce..60416570c1 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -5,7 +5,7 @@ # from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite -from spikeinterface.postprocessing.tests.common_extension_tests import get_sorting_result, ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCase): From e53d1b29cef7865f5b6733a24369d7af52656119 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 24 Jan 2024 09:00:03 +0100 Subject: [PATCH 023/192] Port ISI and noise_level to SortingResult --- dev_pool.py | 74 ++++++++ src/spikeinterface/postprocessing/__init__.py | 4 +- src/spikeinterface/postprocessing/isi.py | 170 ++++++++++-------- .../postprocessing/noise_level.py | 87 ++++----- .../postprocessing/tests/test_correlograms.py | 5 +- .../postprocessing/tests/test_isi.py | 25 ++- .../postprocessing/tests/test_noise_levels.py | 16 +- 7 files changed, 229 insertions(+), 152 deletions(-) create mode 100644 dev_pool.py diff --git a/dev_pool.py b/dev_pool.py new file mode 100644 index 0000000000..52f5a7572a --- /dev/null +++ b/dev_pool.py @@ -0,0 +1,74 @@ +import multiprocessing +from concurrent.futures import ProcessPoolExecutor + +def f(x): + import os + # global _worker_num + p = multiprocessing.current_process() + print(p, type(p), p.name, p._identity[0], type(p._identity[0]), p.ident) + return x * x + + +def init_worker(lock, array_pid): + print(array_pid, len(array_pid)) + child_process = multiprocessing.current_process() + + lock.acquire() + num_worker = None + for i in range(len(array_pid)): + print(array_pid[i]) + if array_pid[i] == -1: + num_worker = i + array_pid[i] = child_process.ident + break + print(num_worker, child_process.ident) + lock.release() + +num_worker = 6 +lock = multiprocessing.Lock() +array_pid = multiprocessing.Array('i', num_worker) +for i in range(num_worker): + array_pid[i] = -1 + + +# with ProcessPoolExecutor( +# max_workers=4, +# ) as executor: +# print(executor._processes) +# results = executor.map(f, range(6)) + +with ProcessPoolExecutor( + max_workers=4, + initializer=init_worker, + initargs=(lock, array_pid) +) as executor: + print(executor._processes) + results = executor.map(f, range(6)) + +exit() +# global _worker_num +# def set_worker_index(i): +# global _worker_num +# _worker_num = i + +p = multiprocessing.Pool(processes=3) +# children = multiprocessing.active_children() +# for i, child in enumerate(children): +# child.submit(set_worker_index) + +# print(children) +print(p.map(f, range(6))) +p.close() +p.join() + +p = multiprocessing.Pool(processes=3) +print(p.map(f, range(6))) +print(p.map(f, range(6))) + +p.close() +p.join() + + +# print(multiprocessing.current_process()) +# p = multiprocessing.current_process() +# print(p._identity) \ No newline at end of file diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 1269376ac6..55044bb018 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -28,7 +28,7 @@ ) from .isi import ( - ISIHistogramsCalculator, + ComputeISIHistograms, compute_isi_histograms, compute_isi_histograms_numpy, compute_isi_histograms_numba, @@ -46,4 +46,4 @@ from .alignsorting import align_sorting, AlignSortingExtractor -from .noise_level import compute_noise_levels, NoiseLevelsCalculator +from .noise_level import compute_noise_levels, ComputeNoiseLevels diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 1185e179b1..a4e2c41818 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -1,6 +1,6 @@ import numpy as np -from ..core import WaveformExtractor -from ..core.waveform_extractor import BaseWaveformExtractorExtension + +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension try: import numba @@ -10,19 +10,35 @@ HAVE_NUMBA = False -class ISIHistogramsCalculator(BaseWaveformExtractorExtension): - """Compute ISI histograms of spike trains. +class ComputeISIHistograms(ResultExtension): + """Compute ISI histograms. Parameters ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object + sorting_result: SortingResult + A SortingResult object + window_ms : float, default: 50 + The window in ms + bin_ms : float, default: 1 + The bin size in ms + method : "auto" | "numpy" | "numba", default: "auto" + . If "auto" and numba is installed, numba is used, otherwise numpy is used + + Returns + ------- + isi_histograms : np.array + IDI_histograms with shape (num_units, num_bins) + bins : np.array + The bin edges in ms """ extension_name = "isi_histograms" + depend_on = [] + need_recording = False + use_nodepipeline = False - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) @@ -31,80 +47,78 @@ def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str def _select_extension_data(self, unit_ids): # filter metrics dataframe - unit_indices = self.waveform_extractor.sorting.ids_to_indices(unit_ids) - new_isi_hists = self._extension_data["isi_histograms"][unit_indices, :] - new_bins = self._extension_data["bins"] + unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + new_isi_hists = self.data["isi_histograms"][unit_indices, :] + new_bins = self.data["bins"] new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) return new_extension_data def _run(self): - isi_histograms, bins = _compute_isi_histograms(self.waveform_extractor.sorting, **self._params) - self._extension_data["isi_histograms"] = isi_histograms - self._extension_data["bins"] = bins - - def get_data(self): - """ - Get the computed ISI histograms. - - Returns - ------- - isi_histograms : np.array - 2D array with ISI histograms (num_units, num_bins) - bins : np.array - 1D array with bins in ms - """ - msg = "ISI histograms are not computed. Use the 'run()' function." - assert self._extension_data["isi_histograms"] is not None and self._extension_data["bins"] is not None, msg - return self._extension_data["isi_histograms"], self._extension_data["bins"] - - @staticmethod - def get_extension_function(): - return compute_isi_histograms - - -WaveformExtractor.register_extension(ISIHistogramsCalculator) - - -def compute_isi_histograms( - waveform_or_sorting_extractor, - load_if_exists=False, - window_ms: float = 50.0, - bin_ms: float = 1.0, - method: str = "auto", -): - """Compute ISI histograms. - - Parameters - ---------- - waveform_or_sorting_extractor : WaveformExtractor or BaseSorting - If WaveformExtractor, the ISI histograms are saved as WaveformExtensions - load_if_exists : bool, default: False - Whether to load precomputed crosscorrelograms, if they already exist - window_ms : float, default: 50 - The window in ms - bin_ms : float, default: 1 - The bin size in ms - method : "auto" | "numpy" | "numba", default: "auto" - . If "auto" and numba is installed, numba is used, otherwise numpy is used - - Returns - ------- - isi_histograms : np.array - IDI_histograms with shape (num_units, num_bins) - bins : np.array - The bin edges in ms - """ - if isinstance(waveform_or_sorting_extractor, WaveformExtractor): - if load_if_exists and waveform_or_sorting_extractor.is_extension(ISIHistogramsCalculator.extension_name): - isic = waveform_or_sorting_extractor.load_extension(ISIHistogramsCalculator.extension_name) - else: - isic = ISIHistogramsCalculator(waveform_or_sorting_extractor) - isic.set_params(window_ms=window_ms, bin_ms=bin_ms, method=method) - isic.run() - isi_histograms, bins = isic.get_data() - return isi_histograms, bins - else: - return _compute_isi_histograms(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) + isi_histograms, bins = _compute_isi_histograms(self.sorting_result.sorting, **self.params) + self.data["isi_histograms"] = isi_histograms + self.data["bins"] = bins + + # def get_data(self): + # """ + # Get the computed ISI histograms. + + # Returns + # ------- + # isi_histograms : np.array + # 2D array with ISI histograms (num_units, num_bins) + # bins : np.array + # 1D array with bins in ms + # """ + # msg = "ISI histograms are not computed. Use the 'run()' function." + # assert self.data["isi_histograms"] is not None and self.data["bins"] is not None, msg + # return self.data["isi_histograms"], self.data["bins"] + + +register_result_extension(ComputeISIHistograms) +compute_isi_histograms = ComputeISIHistograms.function_factory() + + + +# def compute_isi_histograms( +# waveform_or_sorting_extractor, +# load_if_exists=False, +# window_ms: float = 50.0, +# bin_ms: float = 1.0, +# method: str = "auto", +# ): +# """Compute ISI histograms. + +# Parameters +# ---------- +# waveform_or_sorting_extractor : WaveformExtractor or BaseSorting +# If WaveformExtractor, the ISI histograms are saved as WaveformExtensions +# load_if_exists : bool, default: False +# Whether to load precomputed crosscorrelograms, if they already exist +# window_ms : float, default: 50 +# The window in ms +# bin_ms : float, default: 1 +# The bin size in ms +# method : "auto" | "numpy" | "numba", default: "auto" +# . If "auto" and numba is installed, numba is used, otherwise numpy is used + +# Returns +# ------- +# isi_histograms : np.array +# IDI_histograms with shape (num_units, num_bins) +# bins : np.array +# The bin edges in ms +# """ +# if isinstance(waveform_or_sorting_extractor, WaveformExtractor): +# if load_if_exists and waveform_or_sorting_extractor.is_extension(ISIHistogramsCalculator.extension_name): +# isic = waveform_or_sorting_extractor.load_extension(ISIHistogramsCalculator.extension_name) +# else: +# isic = ISIHistogramsCalculator(waveform_or_sorting_extractor) +# isic.set_params(window_ms=window_ms, bin_ms=bin_ms, method=method) +# isic.run() +# isi_histograms, bins = isic.get_data() +# return isi_histograms, bins +# else: +# return _compute_isi_histograms(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) def _compute_isi_histograms(sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): diff --git a/src/spikeinterface/postprocessing/noise_level.py b/src/spikeinterface/postprocessing/noise_level.py index db93731977..4b56fc81a3 100644 --- a/src/spikeinterface/postprocessing/noise_level.py +++ b/src/spikeinterface/postprocessing/noise_level.py @@ -1,47 +1,9 @@ -from spikeinterface.core.waveform_extractor import BaseWaveformExtractorExtension, WaveformExtractor -from spikeinterface.core import get_noise_levels - - -class NoiseLevelsCalculator(BaseWaveformExtractorExtension): - extension_name = "noise_levels" - - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - - def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): - params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed) - return params - - def _select_extension_data(self, unit_ids): - # this do not depend on units - return self._extension_data - - def _run(self): - return_scaled = self.waveform_extractor.return_scaled - self._extension_data["noise_levels"] = get_noise_levels( - self.waveform_extractor.recording, return_scaled=return_scaled, **self._params - ) - - def get_data(self): - """ - Get computed noise levels. - - Returns - ------- - noise_levels : np.array - The noise levels associated to each channel. - """ - return self._extension_data["noise_levels"] - - @staticmethod - def get_extension_function(): - return compute_noise_levels - -WaveformExtractor.register_extension(NoiseLevelsCalculator) +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core import get_noise_levels -def compute_noise_levels(waveform_extractor, load_if_exists=False, **params): +class ComputeNoiseLevels(ResultExtension): """ Computes the noise level associated to each recording channel. @@ -55,23 +17,42 @@ def compute_noise_levels(waveform_extractor, load_if_exists=False, **params): Parameters ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object - load_if_exists: bool, default: False - If True, the noise levels are loaded if they already exist + sorting_result: SortingResult + A SortingResult object **params: dict with additional parameters - Returns ------- noise_levels: np.array noise level vector. """ - if load_if_exists and waveform_extractor.is_extension(NoiseLevelsCalculator.extension_name): - ext = waveform_extractor.load_extension(NoiseLevelsCalculator.extension_name) - else: - ext = NoiseLevelsCalculator(waveform_extractor) - ext.set_params(**params) - ext.run() + extension_name = "noise_levels" + + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) + + def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, return_scaled=True, seed=None): + params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, return_scaled=return_scaled, seed=seed) + return params + + def _select_extension_data(self, unit_ids): + # this do not depend on units + return self.data + + def _run(self): + self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) + + # def get_data(self): + # """ + # Get computed noise levels. + + # Returns + # ------- + # noise_levels : np.array + # The noise levels associated to each channel. + # """ + # return self._extension_data["noise_levels"] + - return ext.get_data() +register_result_extension(ComputeNoiseLevels) +compute_noise_levels = ComputeNoiseLevels.function_factory() diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 1d6cb24826..42e3421036 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -22,9 +22,8 @@ class ComputeCorrelogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase) dict(method="numpy"), dict(method="auto"), ] - -if HAVE_NUMBA: - ComputeCorrelogramsTest.extension_function_kwargs_list.append(dict(method="numba")) + if HAVE_NUMBA: + extension_function_kwargs_list.append(dict(method="numba")) diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 421c8f80cc..8867de08f3 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -2,9 +2,10 @@ import numpy as np from typing import List -from spikeinterface.postprocessing import compute_isi_histograms, ISIHistogramsCalculator -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing import compute_isi_histograms, ComputeISIHistograms +from spikeinterface.postprocessing.isi import _compute_isi_histograms try: @@ -15,16 +16,22 @@ HAVE_NUMBA = False -class ISIHistogramsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = ISIHistogramsCalculator - extension_data_names = ["isi_histograms", "bins"] +class ComputeISIHistogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeISIHistograms + extension_function_kwargs_list = [ + dict(method="numpy"), + dict(method="auto"), + ] + if HAVE_NUMBA: + extension_function_kwargs_list.append(dict(method="numba")) def test_compute_ISI(self): methods = ["numpy", "auto"] if HAVE_NUMBA: methods.append("numba") - sorting = self.we2.sorting + key0 = list(self.sorting_results.keys())[0] + sorting = self.sorting_results[key0].sorting _test_ISI(sorting, window_ms=60.0, bin_ms=1.0, methods=methods) _test_ISI(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) @@ -32,7 +39,7 @@ def test_compute_ISI(self): def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): for method in methods: - ISI, bins = compute_isi_histograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + ISI, bins = _compute_isi_histograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) if method == "numpy": ref_ISI = ISI @@ -43,6 +50,8 @@ def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): if __name__ == "__main__": - test = ISIHistogramsExtensionTest() + test = ComputeISIHistogramsTest() test.setUp() + test.test_extension() test.test_compute_ISI() + diff --git a/src/spikeinterface/postprocessing/tests/test_noise_levels.py b/src/spikeinterface/postprocessing/tests/test_noise_levels.py index 9e3a4fd45c..db9174aa38 100644 --- a/src/spikeinterface/postprocessing/tests/test_noise_levels.py +++ b/src/spikeinterface/postprocessing/tests/test_noise_levels.py @@ -1,17 +1,17 @@ import unittest -from spikeinterface.postprocessing import compute_noise_levels, NoiseLevelsCalculator -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing import compute_noise_levels, ComputeNoiseLevels -class NoiseLevelsCalculatorExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = NoiseLevelsCalculator - extension_data_names = ["noise_levels"] - - exact_same_content = False +class ComputeNoiseLevelsTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeNoiseLevels + extension_function_kwargs_list = [ + dict(), + ] if __name__ == "__main__": - test = NoiseLevelsCalculatorExtensionTest() + test = ComputeNoiseLevelsTest() test.setUp() test.test_extension() From 0d73ca6107e961dfdaeafed32c80342214d60912 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 24 Jan 2024 11:19:44 +0100 Subject: [PATCH 024/192] ComputeNoiseLevels extension in result_core refactor compute_sparsity with sorting_result --- src/spikeinterface/core/__init__.py | 1 + src/spikeinterface/core/numpyextractors.py | 4 +- src/spikeinterface/core/result_core.py | 85 ++++++++++- src/spikeinterface/core/sortingresult.py | 8 +- src/spikeinterface/core/sparsity.py | 140 +++++++++++------- src/spikeinterface/core/template_tools.py | 10 +- .../core/tests/test_result_core.py | 20 ++- .../core/tests/test_sparsity.py | 43 +++++- .../postprocessing/noise_level.py | 61 +------- .../tests/common_extension_tests.py | 8 + .../postprocessing/tests/test_noise_levels.py | 18 +-- 11 files changed, 248 insertions(+), 150 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 36fccad8fa..1249c76143 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -149,6 +149,7 @@ ComputeWaveforms, compute_waveforms, ComputeTemplates, compute_templates, ComputeFastTemplates, compute_fast_templates, + ComputeNoiseLevels, compute_noise_levels, ) # Important not for compatibility!! diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index e50e92e33d..5ce64fcc49 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -504,7 +504,7 @@ def __del__(self): self.shm.unlink() @staticmethod - def from_sorting(source_sorting): + def from_sorting(source_sorting, with_metadata=False): spikes = source_sorting.to_spike_vector() shm_spikes, shm = make_shared_array(spikes.shape, spikes.dtype) shm_spikes[:] = spikes @@ -517,6 +517,8 @@ def from_sorting(source_sorting): main_shm_owner=True, ) shm.close() + if with_metadata: + source_sorting.copy_metadata(sorting) return sorting diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 6cb93b21dc..77c3ac4d67 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -11,6 +11,7 @@ from .sortingresult import ResultExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates +from .recording_tools import get_noise_levels class ComputeWaveforms(ResultExtension): """ @@ -121,6 +122,25 @@ def _select_extension_data(self, unit_ids): return new_data + def get_waveforms_one_unit(self, unit_id, force_dense: bool = False,): + sorting = self.sorting_result.sorting + unit_index = sorting.id_to_index(unit_id) + spikes = sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + spike_mask = some_spikes["unit_index"] == unit_index + wfs = self.data["waveforms"][spike_mask, :, :] + + if force_dense: + if self.sorting_result.sparsity is not None: + num_channels = self.get_num_channels() + dense_wfs = np.zeros((wfs.shape[0], wfs.shape[1], num_channels), dtype=wfs.dtype) + unit_sparsity = self.sorting_result.sparsity.mask[unit_index] + dense_wfs[:, :, unit_sparsity] = wfs + wfs = dense_wfs + + return wfs + + @@ -202,7 +222,12 @@ def _set_params(self, operators = ["average", "std"]): waveforms_extension = self.sorting_result.get_extension("waveforms") - params = dict(operators=operators, nbefore=waveforms_extension.nbefore, nafter=waveforms_extension.nafter) + params = dict( + operators=operators, + nbefore=waveforms_extension.nbefore, + nafter=waveforms_extension.nafter, + return_scaled=waveforms_extension.params["return_scaled"], + ) return params @property @@ -222,6 +247,8 @@ def _select_extension_data(self, unit_ids): return new_data + + compute_templates = ComputeTemplates.function_factory() register_result_extension(ComputeTemplates) @@ -287,3 +314,59 @@ def _select_extension_data(self, unit_ids): compute_fast_templates = ComputeFastTemplates.function_factory() register_result_extension(ComputeFastTemplates) + + +class ComputeNoiseLevels(ResultExtension): + """ + Computes the noise level associated to each recording channel. + + This function will wraps the `get_noise_levels(recording)` to make the noise levels persistent + on disk (folder or zarr) as a `WaveformExtension`. + The noise levels do not depend on the unit list, only the recording, but it is a convenient way to + retrieve the noise levels directly ine the WaveformExtractor. + + Note that the noise levels can be scaled or not, depending on the `return_scaled` parameter + of the `WaveformExtractor`. + + Parameters + ---------- + sorting_result: SortingResult + A SortingResult object + **params: dict with additional parameters + + Returns + ------- + noise_levels: np.array + noise level vector. + """ + extension_name = "noise_levels" + + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) + + def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, return_scaled=True, seed=None): + params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, return_scaled=return_scaled, seed=seed) + return params + + def _select_extension_data(self, unit_ids): + # this do not depend on units + return self.data + + def _run(self): + self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) + + # def get_data(self): + # """ + # Get computed noise levels. + + # Returns + # ------- + # noise_levels : np.array + # The noise levels associated to each channel. + # """ + # return self._extension_data["noise_levels"] + + +register_result_extension(ComputeNoiseLevels) +compute_noise_levels = ComputeNoiseLevels.function_factory() + diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 94b605d663..5d96c8f845 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -248,7 +248,7 @@ def create_memory(cls, sorting, recording, sparsity, rec_attributes): rec_attributes = rec_attributes.copy() # a copy of sorting is created directly in shared memory format to avoid further duplication of spikes. - sorting_copy = SharedMemorySorting.from_sorting(sorting) + sorting_copy = SharedMemorySorting.from_sorting(sorting, with_metadata=True) sortres = SortingResult(sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, format="memory", sparsity=sparsity) return sortres @@ -317,7 +317,7 @@ def load_from_binary_folder(cls, folder, recording=None): assert folder.is_dir(), f"This folder does not exists {folder}" # load internal sorting copy and make it sharedmem - sorting = SharedMemorySorting.from_sorting(NumpyFolderSorting(folder / "sorting")) + sorting = SharedMemorySorting.from_sorting(NumpyFolderSorting(folder / "sorting"), with_metadata=True) # load recording if possible if recording is None: @@ -473,7 +473,7 @@ def load_from_zarr(cls, folder, recording=None): # load internal sorting and make it sharedmem # TODO propagate storage_options - sorting = SharedMemorySorting.from_sorting(ZarrSortingExtractor(folder, zarr_group="sorting")) + sorting = SharedMemorySorting.from_sorting(ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True) # load recording if possible if recording is None: @@ -970,7 +970,7 @@ def get_extension_class(extension_name: str): """ global _possible_extensions extensions_dict = {ext.extension_name: ext for ext in _possible_extensions} - assert extension_name in extensions_dict, "Extension is not registered, please import related module before" + assert extension_name in extensions_dict, f"Extension '{extension_name}' is not registered, please import related module before" ext_class = extensions_dict[extension_name] return ext_class diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 79816cd341..8d3fbb5e5a 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -266,118 +266,147 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, templates_or_we, num_channels, peak_sign="neg"): + def from_best_channels(cls, templates_or_sorting_result, num_channels, peak_sign="neg"): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. """ from .template_tools import get_template_amplitudes - mask = np.zeros((templates_or_we.unit_ids.size, templates_or_we.channel_ids.size), dtype="bool") - peak_values = get_template_amplitudes(templates_or_we, peak_sign=peak_sign) - for unit_ind, unit_id in enumerate(templates_or_we.unit_ids): + mask = np.zeros((templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool") + peak_values = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign) + for unit_ind, unit_id in enumerate(templates_or_sorting_result.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_we.unit_ids, templates_or_we.channel_ids) + return cls(mask, templates_or_sorting_result.unit_ids, templates_or_sorting_result.channel_ids) @classmethod - def from_radius(cls, templates_or_we, radius_um, peak_sign="neg"): + def from_radius(cls, templates_or_sorting_result, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. Use the "radius_um" argument to specify the radius in um """ from .template_tools import get_template_extremum_channel - mask = np.zeros((templates_or_we.unit_ids.size, templates_or_we.channel_ids.size), dtype="bool") - channel_locations = templates_or_we.get_channel_locations() + mask = np.zeros((templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool") + channel_locations = templates_or_sorting_result.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_we, peak_sign=peak_sign, outputs="index") - for unit_ind, unit_id in enumerate(templates_or_we.unit_ids): + best_chan = get_template_extremum_channel(templates_or_sorting_result, peak_sign=peak_sign, outputs="index") + for unit_ind, unit_id in enumerate(templates_or_sorting_result.unit_ids): chan_ind = best_chan[unit_id] (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_we.unit_ids, templates_or_we.channel_ids) + return cls(mask, templates_or_sorting_result.unit_ids, templates_or_sorting_result.channel_ids) @classmethod - def from_snr(cls, we, threshold, peak_sign="neg"): + def from_snr(cls, sorting_result, threshold, peak_sign="neg"): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. """ from .template_tools import get_template_amplitudes - mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") + assert sorting_result.sparsity is None, "To compute sparsity you need a dense SortingResult" - peak_values = get_template_amplitudes(we, peak_sign=peak_sign, mode="extremum") - noise = get_noise_levels(we.recording, return_scaled=we.return_scaled) - for unit_ind, unit_id in enumerate(we.unit_ids): + mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") + + peak_values = get_template_amplitudes(sorting_result, peak_sign=peak_sign, mode="extremum", return_scaled=True) + + ext = sorting_result.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" + assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" + noise = ext.data["noise_levels"] + + for unit_ind, unit_id in enumerate(sorting_result.unit_ids): chan_inds = np.nonzero((np.abs(peak_values[unit_id]) / noise) >= threshold) mask[unit_ind, chan_inds] = True - return cls(mask, we.unit_ids, we.channel_ids) + return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) @classmethod - def from_ptp(cls, we, threshold): + def from_ptp(cls, sorting_result, threshold): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the SNR threshold. """ - - mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") - templates_ptps = np.ptp(we.get_all_templates(), axis=1) - noise = get_noise_levels(we.recording, return_scaled=we.return_scaled) - for unit_ind, unit_id in enumerate(we.unit_ids): + assert sorting_result.sparsity is None, "To compute sparsity with ptp you need a dense SortingResult" + + from .template_tools import _get_dense_templates_array + mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") + templates_array = _get_dense_templates_array(sorting_result, return_scaled=True) + templates_ptps = np.ptp(templates_array, axis=1) + ext = sorting_result.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" + noise = ext.data["noise_levels"] + + for unit_ind, unit_id in enumerate(sorting_result.unit_ids): chan_inds = np.nonzero(templates_ptps[unit_ind] / noise >= threshold) mask[unit_ind, chan_inds] = True - return cls(mask, we.unit_ids, we.channel_ids) + return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) @classmethod - def from_energy(cls, we, threshold): + def from_energy(cls, sorting_result, threshold): """ Construct sparsity from a threshold based on per channel energy ratio. Use the "threshold" argument to specify the SNR threshold. """ - mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") - noise = np.sqrt(we.nsamples) * get_noise_levels(we.recording, return_scaled=we.return_scaled) - for unit_ind, unit_id in enumerate(we.unit_ids): - wfs = we.get_waveforms(unit_id) + assert sorting_result.sparsity is None, "To compute sparsity with energy you need a dense SortingResult" + + mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") + + # noise_levels + ext = sorting_result.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" + noise_levels = ext.data["noise_levels"] + + # waveforms + ext_waveforms = sorting_result.get_extension("waveforms") + assert ext_waveforms is not None, "To compute sparsity from energy you need to compute 'waveforms' first" + namples = ext_waveforms.nbefore + ext_waveforms.nafter + + noise = np.sqrt(namples) * noise_levels + + for unit_ind, unit_id in enumerate(sorting_result.unit_ids): + wfs = ext_waveforms.get_waveforms_one_unit(unit_id, force_dense=True) energies = np.linalg.norm(wfs, axis=(0, 1)) chan_inds = np.nonzero(energies / (noise * np.sqrt(len(wfs))) >= threshold) mask[unit_ind, chan_inds] = True - return cls(mask, we.unit_ids, we.channel_ids) + return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) @classmethod - def from_property(cls, we, by_property): + def from_property(cls, sorting_result, by_property): """ Construct sparsity witha property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. """ # check consistency - assert by_property in we.recording.get_property_keys(), f"Property {by_property} is not a recording property" - assert by_property in we.sorting.get_property_keys(), f"Property {by_property} is not a sorting property" + assert by_property in sorting_result.recording.get_property_keys(), f"Property {by_property} is not a recording property" + assert by_property in sorting_result.sorting.get_property_keys(), f"Property {by_property} is not a sorting property" - mask = np.zeros((we.unit_ids.size, we.channel_ids.size), dtype="bool") - rec_by = we.recording.split_by(by_property) - for unit_ind, unit_id in enumerate(we.unit_ids): - unit_property = we.sorting.get_property(by_property)[unit_ind] + mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") + rec_by = sorting_result.recording.split_by(by_property) + for unit_ind, unit_id in enumerate(sorting_result.unit_ids): + unit_property = sorting_result.sorting.get_property(by_property)[unit_ind] assert ( unit_property in rec_by.keys() ), f"Unit property {unit_property} cannot be found in the recording properties" - chan_inds = we.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) + chan_inds = sorting_result.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) mask[unit_ind, chan_inds] = True - return cls(mask, we.unit_ids, we.channel_ids) + return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) @classmethod - def create_dense(cls, we): + def create_dense(cls, sorting_result): """ Create a sparsity object with all selected channel for all units. """ - mask = np.ones((we.unit_ids.size, we.channel_ids.size), dtype="bool") - return cls(mask, we.unit_ids, we.channel_ids) + mask = np.ones((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") + return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) def compute_sparsity( - templates_or_waveform_extractor, + templates_or_sorting_result, method="radius", peak_sign="neg", num_channels=5, @@ -390,10 +419,10 @@ def compute_sparsity( Parameters ---------- - templates_or_waveform_extractor: Templates | WaveformExtractor - A Templates or a WaveformExtractor object. - Some method accept both objects ("best_channels", "radius", ) - Other method need WaveformExtractor because internally the recording is needed. + templates_or_sorting_result: Templates | SortingResult + A Templates or a SortingResult object. + Some methods accept both objects ("best_channels", "radius", ) + Other methods require only SortingResult because internally the recording is needed. {} @@ -406,30 +435,31 @@ def compute_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates from .waveform_extractor import WaveformExtractor + from .sortingresult import SortingResult if method in ("best_channels", "radius"): - assert isinstance(templates_or_waveform_extractor, (Templates, WaveformExtractor)), "compute_sparsity() need Templates or WaveformExtractor" + assert isinstance(templates_or_sorting_result, (Templates, WaveformExtractor, SortingResult)), "compute_sparsity() need Templates or WaveformExtractor or SortingResult" else: - assert isinstance(templates_or_waveform_extractor, WaveformExtractor), f"compute_sparsity(method='{method}') need WaveformExtractor" + assert isinstance(templates_or_sorting_result, (WaveformExtractor, SortingResult)), f"compute_sparsity(method='{method}') need WaveformExtractor or SortingResult" if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" - sparsity = ChannelSparsity.from_best_channels(templates_or_waveform_extractor, num_channels, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_result, num_channels, peak_sign=peak_sign) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(templates_or_waveform_extractor, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates_or_sorting_result, radius_um, peak_sign=peak_sign) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_snr(templates_or_waveform_extractor, threshold, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_snr(templates_or_sorting_result, threshold, peak_sign=peak_sign) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_energy(templates_or_waveform_extractor, threshold) + sparsity = ChannelSparsity.from_energy(templates_or_sorting_result, threshold) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp(templates_or_waveform_extractor, threshold) + sparsity = ChannelSparsity.from_ptp(templates_or_sorting_result, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(templates_or_waveform_extractor, by_property) + sparsity = ChannelSparsity.from_property(templates_or_sorting_result, by_property) else: raise ValueError(f"compute_sparsity() method={method} do not exists") diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 09022d0b34..50e62afa1c 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -9,7 +9,7 @@ from .sortingresult import SortingResult -def _get_dense_templates_array(one_object): +def _get_dense_templates_array(one_object, return_scaled=True): if isinstance(one_object, Templates): templates_array = one_object.get_dense_templates() elif isinstance(one_object, WaveformExtractor): @@ -18,8 +18,10 @@ def _get_dense_templates_array(one_object): ext = one_object.get_extension("templates") if ext is not None: templates_array = ext.data["average"] + assert return_scaled == ext.params["return_scaled"], f"templates have been extracted with return_scaled={not return_scaled} you cannot get then with return_scaled={return_scaled}" else: ext = one_object.get_extension("fast_templates") + assert return_scaled == ext.params["return_scaled"], f"fast_templates have been extracted with return_scaled={not return_scaled} you cannot get then with return_scaled={return_scaled}" if ext is not None: templates_array = ext.data["average"] else: @@ -49,7 +51,7 @@ def _get_nbefore(one_object): def get_template_amplitudes( - templates_or_waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum" + templates_or_waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", return_scaled: bool = True ): """ Get amplitude per channel for each unit. @@ -63,6 +65,8 @@ def get_template_amplitudes( mode: "extremum" | "at_index", default: "extremum" "extremum": max or min "at_index": take value at spike index + return_scaled: bool, default True + The amplitude is scaled or not. Returns ------- @@ -78,7 +82,7 @@ def get_template_amplitudes( peak_values = {} - templates_array = _get_dense_templates_array(templates_or_waveform_extractor) + templates_array = _get_dense_templates_array(templates_or_waveform_extractor, return_scaled=return_scaled) for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index 855099ce0e..fad5795fe8 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -146,6 +146,16 @@ def test_ComputeFastTemplates(format, sparse): # ax.legend() # plt.show() +@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) +# @pytest.mark.parametrize("sparse", [True, False]) +def test_ComputeNoiseLevels(format, sparse): + sortres = get_sorting_result(format=format, sparse=sparse) + + sortres.compute("noise_levels", return_scaled=True) + print(sortres) + + noise_levels = sortres.get_extension("noise_levels").data["noise_levels"] + assert noise_levels.shape[0] == sortres.channel_ids.size if __name__ == '__main__': @@ -156,9 +166,11 @@ def test_ComputeFastTemplates(format, sparse): # test_ComputeWaveforms(format="zarr", sparse=True) # test_ComputeWaveforms(format="zarr", sparse=False) - test_ComputeTemplates(format="memory", sparse=True) - test_ComputeTemplates(format="memory", sparse=False) - test_ComputeTemplates(format="binary_folder", sparse=True) - test_ComputeTemplates(format="zarr", sparse=True) + # test_ComputeTemplates(format="memory", sparse=True) + # test_ComputeTemplates(format="memory", sparse=False) + # test_ComputeTemplates(format="binary_folder", sparse=True) + # test_ComputeTemplates(format="zarr", sparse=True) # test_ComputeFastTemplates(format="memory", sparse=True) + + test_ComputeNoiseLevels(format="memory", sparse=False) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 481151b149..72aafde0d4 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -3,9 +3,10 @@ import numpy as np import json -from spikeinterface.core import ChannelSparsity, estimate_sparsity +from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity from spikeinterface.core.core_tools import check_json from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.core import start_sorting_result def test_ChannelSparsity(): for unit_ids in (["a", "b", "c", "d"], [4, 5, 6, 7]): @@ -143,15 +144,21 @@ def test_densify_waveforms(): assert np.array_equal(template_sparse, template_sparse2) - -def test_estimate_sparsity(): - num_units = 5 +def get_dataset(): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, num_channels=10, num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=1.0, strategy="tile_pregenerated"), seed=2205, ) + recording.set_property("group", ["a"] * 5 + ["b"] * 5) + sorting.set_property("group", ["a"] * 3 + ["b"] * 2) + return recording, sorting + + +def test_estimate_sparsity(): + recording, sorting = get_dataset() + num_units = sorting.unit_ids.size # small radius should give a very sparse = one channel per unit sparsity = estimate_sparsity( @@ -184,7 +191,29 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units)*3) +def test_compute_sparsity(): + recording, sorting = get_dataset() + + # using SortingResult + sorting_result = start_sorting_result(sorting=sorting, recording=recording, sparse=False) + sorting_result.select_random_spikes() + sorting_result.compute("fast_templates", return_scaled=True) + sorting_result.compute("noise_levels", return_scaled=True) + sorting_result.compute("waveforms", return_scaled=True) + print(sorting_result) + + sparsity = compute_sparsity(sorting_result, method="best_channels", num_channels=2, peak_sign="neg") + sparsity = compute_sparsity(sorting_result, method="radius", radius_um=50., peak_sign="neg") + sparsity = compute_sparsity(sorting_result, method="snr", threshold=5, peak_sign="neg") + sparsity = compute_sparsity(sorting_result, method="ptp", threshold=5) + sparsity = compute_sparsity(sorting_result, method="energy", threshold=5) + sparsity = compute_sparsity(sorting_result, method="by_property", by_property="group") + + # using Templates + # TODO later + -if __name__ == "__main__": - test_ChannelSparsity() - test_estimate_sparsity() +if __name__ == "__main__": # test_ChannelSparsity() + # test_ChannelSparsity() + # test_estimate_sparsity() + test_compute_sparsity() diff --git a/src/spikeinterface/postprocessing/noise_level.py b/src/spikeinterface/postprocessing/noise_level.py index 4b56fc81a3..abd47f574f 100644 --- a/src/spikeinterface/postprocessing/noise_level.py +++ b/src/spikeinterface/postprocessing/noise_level.py @@ -1,58 +1,3 @@ - -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension -from spikeinterface.core import get_noise_levels - - -class ComputeNoiseLevels(ResultExtension): - """ - Computes the noise level associated to each recording channel. - - This function will wraps the `get_noise_levels(recording)` to make the noise levels persistent - on disk (folder or zarr) as a `WaveformExtension`. - The noise levels do not depend on the unit list, only the recording, but it is a convenient way to - retrieve the noise levels directly ine the WaveformExtractor. - - Note that the noise levels can be scaled or not, depending on the `return_scaled` parameter - of the `WaveformExtractor`. - - Parameters - ---------- - sorting_result: SortingResult - A SortingResult object - **params: dict with additional parameters - - Returns - ------- - noise_levels: np.array - noise level vector. - """ - extension_name = "noise_levels" - - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) - - def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, return_scaled=True, seed=None): - params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, return_scaled=return_scaled, seed=seed) - return params - - def _select_extension_data(self, unit_ids): - # this do not depend on units - return self.data - - def _run(self): - self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) - - # def get_data(self): - # """ - # Get computed noise levels. - - # Returns - # ------- - # noise_levels : np.array - # The noise levels associated to each channel. - # """ - # return self._extension_data["noise_levels"] - - -register_result_extension(ComputeNoiseLevels) -compute_noise_levels = ComputeNoiseLevels.function_factory() +# "noise_levels" extensions is now in core +# this is kept name space compatibility but should be removed soon +from ..core.result_core import ComputeNoiseLevels, compute_noise_levels diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 32f6c11017..7a45a2f6af 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -76,6 +76,14 @@ def setUp(self): key = f"spare{sparse}_{format}" self.sorting_results[key] = sorting_result + def tearDown(self): + for k in list(self.sorting_results.keys()): + sorting_result = self.sorting_results.pop(k) + if sorting_result.format != "memory": + folder = sorting_result.folder + del sorting_result + shutil.rmtree(folder) + @property def extension_name(self): return self.extension_class.extension_name diff --git a/src/spikeinterface/postprocessing/tests/test_noise_levels.py b/src/spikeinterface/postprocessing/tests/test_noise_levels.py index db9174aa38..f334f92fa6 100644 --- a/src/spikeinterface/postprocessing/tests/test_noise_levels.py +++ b/src/spikeinterface/postprocessing/tests/test_noise_levels.py @@ -1,17 +1 @@ -import unittest - -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite -from spikeinterface.postprocessing import compute_noise_levels, ComputeNoiseLevels - - -class ComputeNoiseLevelsTest(ResultExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeNoiseLevels - extension_function_kwargs_list = [ - dict(), - ] - - -if __name__ == "__main__": - test = ComputeNoiseLevelsTest() - test.setUp() - test.test_extension() +# "noise_levels" extensions is now in core From e3e8e2baee2288e20352616c3e1e0769b1e0ff69 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 24 Jan 2024 11:20:07 +0100 Subject: [PATCH 025/192] Port unit_location into SortingResult --- src/spikeinterface/postprocessing/__init__.py | 2 +- .../tests/test_unit_localization.py | 22 +- .../postprocessing/unit_localization.py | 194 ++++++++---------- 3 files changed, 96 insertions(+), 122 deletions(-) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 55044bb018..3d4b6842c9 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -38,7 +38,7 @@ from .unit_localization import ( compute_unit_locations, - UnitLocationsCalculator, + ComputeUnitLocations, compute_center_of_mass, ) diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index b00609cd17..3ce078a195 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -1,22 +1,18 @@ import unittest +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputeUnitLocations -from spikeinterface.postprocessing import UnitLocationsCalculator -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite - -class UnitLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = UnitLocationsCalculator - extension_data_names = ["unit_locations"] +class UnitLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeUnitLocations extension_function_kwargs_list = [ dict(method="center_of_mass", radius_um=100), - dict(method="center_of_mass", radius_um=100, outputs="by_unit"), - dict(method="grid_convolution", radius_um=50, outputs="by_unit"), + dict(method="center_of_mass", radius_um=100), + dict(method="grid_convolution", radius_um=50), + dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150), - dict(method="monopolar_triangulation", radius_um=150, outputs="by_unit"), - dict( - method="monopolar_triangulation", radius_um=150, outputs="by_unit", optimizer="minimize_with_log_penality" - ), + dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), ] @@ -24,4 +20,4 @@ class UnitLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.Test test = UnitLocationsExtensionTest() test.setUp() test.test_extension() - test.tearDown() + # test.tearDown() diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 2ac841c148..c78e70cd3f 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -10,9 +10,9 @@ except ImportError: HAVE_NUMBA = False +from ..core.sortingresult import register_result_extension, ResultExtension from ..core import compute_sparsity -from ..core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension -from ..core.template_tools import get_template_extremum_channel +from ..core.template_tools import get_template_extremum_channel, _get_nbefore, _get_dense_templates_array dtype_localize_by_method = { @@ -25,109 +25,85 @@ possible_localization_methods = list(dtype_localize_by_method.keys()) -class UnitLocationsCalculator(BaseWaveformExtractorExtension): +class ComputeUnitLocations(ResultExtension): """ - Comput unit locations from WaveformExtractor. + Localize units in 2D or 3D with several methods given the template. Parameters ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object + sorting_result: SortingResult + A SortingResult object + method: "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + The method to use for localization + outputs: "numpy" | "by_unit", default: "numpy" + The output format + method_kwargs: + Other kwargs depending on the method + + Returns + ------- + unit_locations: np.array + unit location with shape (num_unit, 2) or (num_unit, 3) or (num_unit, 3) (with alpha) """ extension_name = "unit_locations" + depend_on = ["fast_templates|templates", ] + need_recording = True + use_nodepipeline = False - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) - def _set_params(self, method="center_of_mass", method_kwargs={}): + def _set_params(self, method="center_of_mass", **method_kwargs): params = dict(method=method, method_kwargs=method_kwargs) return params def _select_extension_data(self, unit_ids): - unit_inds = self.waveform_extractor.sorting.ids_to_indices(unit_ids) - new_unit_location = self._extension_data["unit_locations"][unit_inds] + unit_inds = self.sorting_result.sorting.ids_to_indices(unit_ids) + new_unit_location = self.data["unit_locations"][unit_inds] return dict(unit_locations=new_unit_location) def _run(self, **job_kwargs): - method = self._params["method"] - method_kwargs = self._params["method_kwargs"] + method = self.params["method"] + method_kwargs = self.params["method_kwargs"] assert method in possible_localization_methods if method == "center_of_mass": - unit_location = compute_center_of_mass(self.waveform_extractor, **method_kwargs) + unit_location = compute_center_of_mass(self.sorting_result, **method_kwargs) elif method == "grid_convolution": - unit_location = compute_grid_convolution(self.waveform_extractor, **method_kwargs) + unit_location = compute_grid_convolution(self.sorting_result, **method_kwargs) elif method == "monopolar_triangulation": - unit_location = compute_monopolar_triangulation(self.waveform_extractor, **method_kwargs) - self._extension_data["unit_locations"] = unit_location - - def get_data(self, outputs="numpy"): - """ - Get the computed unit locations. - - Parameters - ---------- - outputs : "numpy" | "by_unit", default: "numpy" - The output format - - Returns - ------- - unit_locations : np.array or dict - The unit locations as a Nd array (outputs="numpy") or - as a dict with units as key and locations as values. - """ - if outputs == "numpy": - return self._extension_data["unit_locations"] + unit_location = compute_monopolar_triangulation(self.sorting_result, **method_kwargs) + self.data["unit_locations"] = unit_location - elif outputs == "by_unit": - locations_by_unit = {} - for unit_ind, unit_id in enumerate(self.waveform_extractor.sorting.unit_ids): - locations_by_unit[unit_id] = self._extension_data["unit_locations"][unit_ind] - return locations_by_unit + # def get_data(self, outputs="numpy"): + # """ + # Get the computed unit locations. - @staticmethod - def get_extension_function(): - return compute_unit_locations + # Parameters + # ---------- + # outputs : "numpy" | "by_unit", default: "numpy" + # The output format + # Returns + # ------- + # unit_locations : np.array or dict + # The unit locations as a Nd array (outputs="numpy") or + # as a dict with units as key and locations as values. + # """ + # if outputs == "numpy": + # return self.data["unit_locations"] -WaveformExtractor.register_extension(UnitLocationsCalculator) - - -def compute_unit_locations( - waveform_extractor, load_if_exists=False, method="monopolar_triangulation", outputs="numpy", **method_kwargs -): - """ - Localize units in 2D or 3D with several methods given the template. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object - load_if_exists : bool, default: False - Whether to load precomputed unit locations, if they already exist - method: "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" - The method to use for localization - outputs: "numpy" | "by_unit", default: "numpy" - The output format - method_kwargs: - Other kwargs depending on the method + # elif outputs == "by_unit": + # locations_by_unit = {} + # for unit_ind, unit_id in enumerate(self.sorting_result.sorting.unit_ids): + # locations_by_unit[unit_id] = self.data["unit_locations"][unit_ind] + # return locations_by_unit - Returns - ------- - unit_locations: np.array - unit location with shape (num_unit, 2) or (num_unit, 3) or (num_unit, 3) (with alpha) - """ - if load_if_exists and waveform_extractor.is_extension(UnitLocationsCalculator.extension_name): - ulc = waveform_extractor.load_extension(UnitLocationsCalculator.extension_name) - else: - ulc = UnitLocationsCalculator(waveform_extractor) - ulc.set_params(method=method, method_kwargs=method_kwargs) - ulc.run() - unit_locations = ulc.get_data(outputs=outputs) - return unit_locations +register_result_extension(ComputeUnitLocations) +compute_unit_locations = ComputeUnitLocations.function_factory() def make_initial_guess_and_bounds(wf_data, local_contact_locations, max_distance_um, initial_z=20): @@ -218,7 +194,7 @@ def estimate_distance_error_with_log(vec, wf_data, local_contact_locations, max_ def compute_monopolar_triangulation( - waveform_extractor, + sorting_result, optimizer="minimize_with_log_penality", radius_um=75, max_distance_um=1000, @@ -245,8 +221,8 @@ def compute_monopolar_triangulation( Parameters ---------- - waveform_extractor:WaveformExtractor - A waveform extractor object + sorting_result: SortingResult + A SortingResult object method: "least_square" | "minimize_with_log_penality", default: "least_square" The optimizer to use radius_um: float, default: 75 @@ -272,13 +248,15 @@ def compute_monopolar_triangulation( assert optimizer in ("least_square", "minimize_with_log_penality") assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" - unit_ids = waveform_extractor.sorting.unit_ids + unit_ids = sorting_result.unit_ids + + contact_locations = sorting_result.get_channel_locations() - contact_locations = waveform_extractor.get_channel_locations() - nbefore = waveform_extractor.nbefore + + sparsity = compute_sparsity(sorting_result, method="radius", radius_um=radius_um) + templates = _get_dense_templates_array(sorting_result) + nbefore = _get_nbefore(sorting_result) - sparsity = compute_sparsity(waveform_extractor, method="radius", radius_um=radius_um) - templates = waveform_extractor.get_all_templates(mode="average") if enforce_decrease: neighbours_mask = np.zeros((templates.shape[0], templates.shape[2]), dtype=bool) @@ -286,7 +264,7 @@ def compute_monopolar_triangulation( chan_inds = sparsity.unit_id_to_channel_indices[unit_id] neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = get_template_extremum_channel(waveform_extractor, outputs="index") + best_channels = get_template_extremum_channel(sorting_result, outputs="index") unit_location = np.zeros((unit_ids.size, 4), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -315,14 +293,14 @@ def compute_monopolar_triangulation( return unit_location -def compute_center_of_mass(waveform_extractor, peak_sign="neg", radius_um=75, feature="ptp"): +def compute_center_of_mass(sorting_result, peak_sign="neg", radius_um=75, feature="ptp"): """ Computes the center of mass (COM) of a unit based on the template amplitudes. Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + sorting_result: SortingResult + A SortingResult object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um: float @@ -334,15 +312,15 @@ def compute_center_of_mass(waveform_extractor, peak_sign="neg", radius_um=75, fe ------- unit_location: np.array """ - unit_ids = waveform_extractor.sorting.unit_ids + unit_ids = sorting_result.unit_ids - recording = waveform_extractor.recording - contact_locations = recording.get_channel_locations() + contact_locations = sorting_result.get_channel_locations() assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity(waveform_extractor, peak_sign=peak_sign, method="radius", radius_um=radius_um) - templates = waveform_extractor.get_all_templates(mode="average") + sparsity = compute_sparsity(sorting_result, peak_sign=peak_sign, method="radius", radius_um=radius_um) + templates = _get_dense_templates_array(sorting_result) + nbefore = _get_nbefore(sorting_result) unit_location = np.zeros((unit_ids.size, 2), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -358,7 +336,7 @@ def compute_center_of_mass(waveform_extractor, peak_sign="neg", radius_um=75, fe elif feature == "energy": wf_data = np.linalg.norm(wf[:, chan_inds], axis=0) elif feature == "peak_voltage": - wf_data = wf[waveform_extractor.nbefore, chan_inds] + wf_data = wf[nbefore, chan_inds] # center of mass com = np.sum(wf_data[:, np.newaxis] * local_contact_locations, axis=0) / np.sum(wf_data) @@ -369,7 +347,7 @@ def compute_center_of_mass(waveform_extractor, peak_sign="neg", radius_um=75, fe @np.errstate(divide="ignore", invalid="ignore") def compute_grid_convolution( - waveform_extractor, + sorting_result, peak_sign="neg", radius_um=40.0, upsampling_um=5, @@ -385,8 +363,8 @@ def compute_grid_convolution( Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + sorting_result: SortingResult + A SortingResult object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um: float, default: 40.0 @@ -411,11 +389,15 @@ def compute_grid_convolution( unit_location: np.array """ - contact_locations = waveform_extractor.get_channel_locations() + contact_locations = sorting_result.get_channel_locations() + unit_ids = sorting_result.unit_ids - nbefore = waveform_extractor.nbefore - nafter = waveform_extractor.nafter - fs = waveform_extractor.sampling_frequency + templates = _get_dense_templates_array(sorting_result) + nbefore = _get_nbefore(sorting_result) + nafter = templates.shape[1] - nbefore + + + fs = sorting_result.sampling_frequency percentile = 100 - percentile assert 0 <= percentile <= 100, "Percentile should be in [0, 100]" assert 0 <= sparsity_threshold <= 1, "sparsity_threshold should be in [0, 1]" @@ -430,13 +412,9 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, sigma_um, margin_um ) - templates = waveform_extractor.get_all_templates(mode="average") - - peak_channels = get_template_extremum_channel(waveform_extractor, peak_sign, outputs="index") - unit_ids = waveform_extractor.sorting.unit_ids - + peak_channels = get_template_extremum_channel(sorting_result, peak_sign, outputs="index") + weights_sparsity_mask = weights > sparsity_threshold - unit_location = np.zeros((unit_ids.size, 2), dtype="float64") for i, unit_id in enumerate(unit_ids): main_chan = peak_channels[unit_id] From d1b4f8610b3de99227c1ed06cf49e31429c62620 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 24 Jan 2024 14:30:31 +0100 Subject: [PATCH 026/192] compute_template_metrics and compute_template_similarity ported to SortingResult --- src/spikeinterface/core/sortingresult.py | 2 +- src/spikeinterface/postprocessing/__init__.py | 4 +- .../postprocessing/spike_amplitudes.py | 2 +- .../postprocessing/template_metrics.py | 281 ++++++++---------- .../postprocessing/template_similarity.py | 196 ++++++------ .../tests/test_template_metrics.py | 35 +-- .../tests/test_template_similarity.py | 34 ++- 7 files changed, 279 insertions(+), 275 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 5d96c8f845..ce46938a10 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -723,7 +723,7 @@ def get_channel_locations(self) -> np.ndarray: return all_positions def channel_ids_to_indices(self, channel_ids) -> np.ndarray: - all_channel_ids = self.rec_attributes["channel_ids"] + all_channel_ids = list(self.rec_attributes["channel_ids"]) indices = np.array([all_channel_ids.index(id) for id in channel_ids], dtype=int) return indices diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 3d4b6842c9..6f4112095d 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -1,11 +1,11 @@ from .template_metrics import ( - TemplateMetricsCalculator, + ComputeTemplateMetrics, compute_template_metrics, get_template_metric_names, ) from .template_similarity import ( - TemplateSimilarityCalculator, + ComputeTemplateSimilarity, compute_template_similarity, check_equal_template_with_distribution_overlap, ) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 58a6390ef7..f9f15551a9 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -25,7 +25,7 @@ class ComputeSpikeAmplitudes(ResultExtension): Parameters ---------- sorting_result: SortingResult - The waveform extractor object + The SortingResult object load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. peak_sign: "neg" | "pos" | "both", default: "neg diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 879785e5a7..fae4273a10 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -10,13 +10,14 @@ from typing import Optional from copy import deepcopy -from ..core import WaveformExtractor, ChannelSparsity +from ..core.sortingresult import register_result_extension, ResultExtension +from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel -from ..core.waveform_extractor import BaseWaveformExtractorExtension - +from ..core.template_tools import _get_dense_templates_array # DEBUG = False +# TODO handle external sparsity def get_single_channel_template_metric_names(): return deepcopy(list(_single_channel_metric_name_to_func.keys())) @@ -30,20 +31,76 @@ def get_template_metric_names(): return get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() -class TemplateMetricsCalculator(BaseWaveformExtractorExtension): - """Class to compute template metrics of waveform shapes. +class ComputeTemplateMetrics(ResultExtension): + """ + Compute template metrics including: + * peak_to_valley + * peak_trough_ratio + * halfwidth + * repolarization_slope + * recovery_slope + * num_positive_peaks + * num_negative_peaks + + Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): + * velocity_above + * velocity_below + * exp_decay + * spread Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor object + sorting_result: SortingResult + The SortingResult object + metric_names : list or None, default: None + List of metrics to compute (see si.postprocessing.get_template_metric_names()) + peak_sign : {"neg", "pos"}, default: "neg" + Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. + upsampling_factor : int, default: 10 + The upsampling factor to upsample the templates + sparsity: ChannelSparsity or None, default: None + If None, template metrics are computed on the extremum channel only. + If sparsity is given, template metrics are computed on all sparse channels of each unit. + For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. + include_multi_channel_metrics: bool, default: False + Whether to compute multi-channel metrics + metrics_kwargs: dict + Additional arguments to pass to the metric functions. Including: + * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 + * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 + * peak_width_ms: the width in samples to detect peaks, default: 0.2 + * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) + * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 + * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 + * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 + * spread_threshold: the threshold to compute the spread, default: 0.2 + * spread_smooth_um: the smoothing in um to compute the spread, default: 20 + * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None + - If None, all channels all channels are considered + - If 0 or 1, only the "column" that includes the max channel is considered + - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + + Returns + ------- + template_metrics : pd.DataFrame + Dataframe with the computed template metrics. + If "sparsity" is None, the index is the unit_id. + If "sparsity" is given, the index is a multi-index (unit_id, channel_id) + + Notes + ----- + If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, + so that one metric value will be computed per unit. + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". """ extension_name = "template_metrics" - min_channels_for_multi_channel_warning = 10 + depend_on = ["fast_templates|templates", ] + need_recording = True + use_nodepipeline = False - def __init__(self, waveform_extractor: WaveformExtractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) + min_channels_for_multi_channel_warning = 10 def _set_params( self, @@ -54,43 +111,66 @@ def _set_params( metrics_kwargs=None, include_multi_channel_metrics=False, ): + + if sparsity is not None: + # TODO handle extra sparsity + raise NotImplementedError + + # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() + if include_multi_channel_metrics or ( + metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) + ): + assert sparsity is None, ( + "If multi-channel metrics are computed, sparsity must be None, " + "so that each unit will correspond to 1 row of the output dataframe." + ) + assert ( + self.sorting_result.get_channel_locations().shape[1] == 2 + ), "If multi-channel metrics are computed, channel locations must be 2D." + + if metric_names is None: metric_names = get_single_channel_template_metric_names() if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - metrics_kwargs = metrics_kwargs or dict() + + if metrics_kwargs is None: + metrics_kwargs_ = _default_function_kwargs.copy() + else: + metrics_kwargs_ = _default_function_kwargs.copy() + metrics_kwargs_.update(metrics_kwargs) + params = dict( metric_names=[str(name) for name in np.unique(metric_names)], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - metrics_kwargs=metrics_kwargs, + metrics_kwargs=metrics_kwargs_, ) return params def _select_extension_data(self, unit_ids): - # filter metrics dataframe - new_metrics = self._extension_data["metrics"].loc[np.array(unit_ids)] + new_metrics = self.data["metrics"].loc[np.array(unit_ids)] return dict(metrics=new_metrics) def _run(self): import pandas as pd from scipy.signal import resample_poly - metric_names = self._params["metric_names"] - sparsity = self._params["sparsity"] - peak_sign = self._params["peak_sign"] - upsampling_factor = self._params["upsampling_factor"] - unit_ids = self.waveform_extractor.sorting.unit_ids - sampling_frequency = self.waveform_extractor.sampling_frequency + metric_names = self.params["metric_names"] + sparsity = self.params["sparsity"] + peak_sign = self.params["peak_sign"] + upsampling_factor = self.params["upsampling_factor"] + unit_ids = self.sorting_result.unit_ids + sampling_frequency = self.sorting_result.sampling_frequency metrics_single_channel = [m for m in metric_names if m in get_single_channel_template_metric_names()] metrics_multi_channel = [m for m in metric_names if m in get_multi_channel_template_metric_names()] if sparsity is None: extremum_channels_ids = get_template_extremum_channel( - self.waveform_extractor, peak_sign=peak_sign, outputs="id" + self.sorting_result, peak_sign=peak_sign, outputs="id" ) template_metrics = pd.DataFrame(index=unit_ids, columns=metric_names) @@ -106,15 +186,16 @@ def _run(self): ) template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) - all_templates = self.waveform_extractor.get_all_templates() - channel_locations = self.waveform_extractor.get_channel_locations() + all_templates = _get_dense_templates_array(self.sorting_result, return_scaled=True) + + channel_locations = self.sorting_result.get_channel_locations() for unit_index, unit_id in enumerate(unit_ids): template_all_chans = all_templates[unit_index] chan_ids = np.array(extremum_channels_ids[unit_id]) if chan_ids.ndim == 0: chan_ids = [chan_ids] - chan_ind = self.waveform_extractor.channel_ids_to_indices(chan_ids) + chan_ind = self.sorting_result.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] # compute single_channel metrics @@ -140,22 +221,25 @@ def _run(self): sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self._params["metrics_kwargs"], + **self.params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value # compute metrics multi_channel for metric_name in metrics_multi_channel: # retrieve template (with sparsity if waveform extractor is sparse) - template = self.waveform_extractor.get_template(unit_id=unit_id) + template = all_templates[unit_index, :, :] + if self.sorting_result.is_sparse(): + mask = self.sorting_result.sparsity.mask[unit_index, :] + template = template[:, mask] if template.shape[1] < self.min_channels_for_multi_channel_warning: warnings.warn( f"With less than {self.min_channels_for_multi_channel_warning} channels, " "multi-channel metrics might not be reliable." ) - if self.waveform_extractor.is_sparse(): - channel_locations_sparse = channel_locations[self.waveform_extractor.sparsity.mask[unit_index]] + if self.sorting_result.is_sparse(): + channel_locations_sparse = channel_locations[self.sorting_result.sparsity.mask[unit_index]] else: channel_locations_sparse = channel_locations @@ -172,30 +256,27 @@ def _run(self): template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self._params["metrics_kwargs"], + **self.params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value - self._extension_data["metrics"] = template_metrics + self.data["metrics"] = template_metrics - def get_data(self): - """ - Get the computed metrics. + # def get_data(self): + # """ + # Get the computed metrics. - Returns - ------- - metrics : pd.DataFrame - Dataframe with template metrics - """ - msg = "Template metrics are not computed. Use the 'run()' function." - assert self._extension_data["metrics"] is not None, msg - return self._extension_data["metrics"] + # Returns + # ------- + # metrics : pd.DataFrame + # Dataframe with template metrics + # """ + # msg = "Template metrics are not computed. Use the 'run()' function." + # assert self.data["metrics"] is not None, msg + # return self.data["metrics"] - @staticmethod - def get_extension_function(): - return compute_template_metrics - -WaveformExtractor.register_extension(TemplateMetricsCalculator) +register_result_extension(ComputeTemplateMetrics) +compute_template_metrics = ComputeTemplateMetrics.function_factory() _default_function_kwargs = dict( @@ -213,114 +294,6 @@ def get_extension_function(): ) -def compute_template_metrics( - waveform_extractor, - load_if_exists: bool = False, - metric_names: Optional[list[str]] = None, - peak_sign: Optional[str] = "neg", - upsampling_factor: int = 10, - sparsity: Optional[ChannelSparsity] = None, - include_multi_channel_metrics: bool = False, - metrics_kwargs: dict = None, -): - """ - Compute template metrics including: - * peak_to_valley - * peak_trough_ratio - * halfwidth - * repolarization_slope - * recovery_slope - * num_positive_peaks - * num_negative_peaks - - Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): - * velocity_above - * velocity_below - * exp_decay - * spread - - Parameters - ---------- - waveform_extractor : WaveformExtractor - The waveform extractor used to compute template metrics - load_if_exists : bool, default: False - Whether to load precomputed template metrics, if they already exist. - metric_names : list or None, default: None - List of metrics to compute (see si.postprocessing.get_template_metric_names()) - peak_sign : {"neg", "pos"}, default: "neg" - Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. - upsampling_factor : int, default: 10 - The upsampling factor to upsample the templates - sparsity: ChannelSparsity or None, default: None - If None, template metrics are computed on the extremum channel only. - If sparsity is given, template metrics are computed on all sparse channels of each unit. - For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. - include_multi_channel_metrics: bool, default: False - Whether to compute multi-channel metrics - metrics_kwargs: dict - Additional arguments to pass to the metric functions. Including: - * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 - * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 - * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) - * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 - * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" - * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 - * spread_threshold: the threshold to compute the spread, default: 0.2 - * spread_smooth_um: the smoothing in um to compute the spread, default: 20 - * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used - - Returns - ------- - template_metrics : pd.DataFrame - Dataframe with the computed template metrics. - If "sparsity" is None, the index is the unit_id. - If "sparsity" is given, the index is a multi-index (unit_id, channel_id) - - Notes - ----- - If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, - so that one metric value will be computed per unit. - For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". - """ - if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): - tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) - else: - tmc = TemplateMetricsCalculator(waveform_extractor) - # For 2D metrics, external sparsity must be None, so that one metric value will be computed per unit. - if include_multi_channel_metrics or ( - metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) - ): - assert sparsity is None, ( - "If multi-channel metrics are computed, sparsity must be None, " - "so that each unit will correspond to 1 row of the output dataframe." - ) - assert ( - waveform_extractor.get_channel_locations().shape[1] == 2 - ), "If multi-channel metrics are computed, channel locations must be 2D." - default_kwargs = _default_function_kwargs.copy() - if metrics_kwargs is None: - metrics_kwargs = default_kwargs - else: - default_kwargs.update(metrics_kwargs) - metrics_kwargs = default_kwargs - tmc.set_params( - metric_names=metric_names, - peak_sign=peak_sign, - upsampling_factor=upsampling_factor, - sparsity=sparsity, - include_multi_channel_metrics=include_multi_channel_metrics, - metrics_kwargs=metrics_kwargs, - ) - tmc.run() - - metrics = tmc.get_data() - - return metrics def get_trough_and_peak_idx(template): diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 5febdf83f7..748e342dc6 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -1,116 +1,144 @@ import numpy as np -from ..core import WaveformExtractor -from ..core.waveform_extractor import BaseWaveformExtractorExtension +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from ..core.template_tools import _get_dense_templates_array -class TemplateSimilarityCalculator(BaseWaveformExtractorExtension): +class ComputeTemplateSimilarity(ResultExtension): """Compute similarity between templates with several methods. + Parameters ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object + sorting_result: SortingResult + The SortingResult object + method: str, default: "cosine_similarity" + The method to compute the similarity + + Returns + ------- + similarity: np.array + The similarity matrix """ extension_name = "similarity" + depend_on = ["fast_templates|templates", ] + need_recording = True + use_nodepipeline = False - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) def _set_params(self, method="cosine_similarity"): params = dict(method=method) - return params def _select_extension_data(self, unit_ids): # filter metrics dataframe - unit_indices = self.waveform_extractor.sorting.ids_to_indices(unit_ids) - new_similarity = self._extension_data["similarity"][unit_indices][:, unit_indices] + unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + new_similarity = self.data["similarity"][unit_indices][:, unit_indices] return dict(similarity=new_similarity) def _run(self): - similarity = _compute_template_similarity(self.waveform_extractor, method=self._params["method"]) - self._extension_data["similarity"] = similarity - - def get_data(self): - """ - Get the computed similarity. - - Returns - ------- - similarity : 2d np.array - 2d matrix with computed similarity values. - """ - msg = "Template similarity is not computed. Use the 'run()' function." - assert self._extension_data["similarity"] is not None, msg - return self._extension_data["similarity"] - - @staticmethod - def get_extension_function(): - return compute_template_similarity - - -WaveformExtractor.register_extension(TemplateSimilarityCalculator) - - -def _compute_template_similarity( - waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None -): + templates_array = _get_dense_templates_array(self.sorting_result, return_scaled=True) + similarity = compute_similarity_with_templates_array(templates_array, templates_array, method=self.params["method"]) + self.data["similarity"] = similarity + + # def get_data(self): + # """ + # Get the computed similarity. + + # Returns + # ------- + # similarity : 2d np.array + # 2d matrix with computed similarity values. + # """ + # msg = "Template similarity is not computed. Use the 'run()' function." + # assert self._extension_data["similarity"] is not None, msg + # return self._extension_data["similarity"] + + # @staticmethod + # def get_extension_function(): + # return compute_template_similarity + +register_result_extension(ComputeTemplateSimilarity) +compute_template_similarity = ComputeTemplateSimilarity.function_factory() + +def compute_similarity_with_templates_array(templates_array, other_templates_array, method): import sklearn.metrics.pairwise - templates = waveform_extractor.get_all_templates() - s = templates.shape if method == "cosine_similarity": - templates_flat = templates.reshape(s[0], -1) - if waveform_extractor_other is not None: - templates_other = waveform_extractor_other.get_all_templates() - s_other = templates_other.shape - templates_other_flat = templates_other.reshape(s_other[0], -1) - assert len(templates_flat[0]) == len(templates_other_flat[0]), ( - "Templates from second WaveformExtractor " "don't have the correct shape!" - ) - else: - templates_other_flat = None - similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, templates_other_flat) - # elif method == '': + assert templates_array.shape[0] == other_templates_array.shape[0] + templates_flat = templates_array.reshape(templates_array.shape[0], -1) + other_templates_flat = templates_array.reshape(other_templates_array.shape[0], -1) + similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, other_templates_flat) + else: raise ValueError(f"compute_template_similarity(method {method}) not exists") - + return similarity -def compute_template_similarity( - waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None -): - """Compute similarity between templates with several methods. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object - load_if_exists : bool, default: False - Whether to load precomputed similarity, if is already exists. - method: str, default: "cosine_similarity" - The method to compute the similarity - waveform_extractor_other: WaveformExtractor, default: None - A second waveform extractor object - - Returns - ------- - similarity: np.array - The similarity matrix - """ - if waveform_extractor_other is None: - if load_if_exists and waveform_extractor.is_extension(TemplateSimilarityCalculator.extension_name): - tmc = waveform_extractor.load_extension(TemplateSimilarityCalculator.extension_name) - else: - tmc = TemplateSimilarityCalculator(waveform_extractor) - tmc.set_params(method=method) - tmc.run() - similarity = tmc.get_data() - return similarity - else: - return _compute_template_similarity(waveform_extractor, waveform_extractor_other, method) +# TODO port the waveform_extractor_other concept that compare 2 SortingResult + + +# def _compute_template_similarity( +# waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None +# ): +# import sklearn.metrics.pairwise + +# templates = waveform_extractor.get_all_templates() +# s = templates.shape +# if method == "cosine_similarity": +# templates_flat = templates.reshape(s[0], -1) +# if waveform_extractor_other is not None: +# templates_other = waveform_extractor_other.get_all_templates() +# s_other = templates_other.shape +# templates_other_flat = templates_other.reshape(s_other[0], -1) +# assert len(templates_flat[0]) == len(templates_other_flat[0]), ( +# "Templates from second WaveformExtractor " "don't have the correct shape!" +# ) +# else: +# templates_other_flat = None +# similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, templates_other_flat) +# # elif method == '': +# else: +# raise ValueError(f"compute_template_similarity(method {method}) not exists") + +# return similarity + + +# def compute_template_similarity( +# waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None +# ): +# """Compute similarity between templates with several methods. + +# Parameters +# ---------- +# waveform_extractor: WaveformExtractor +# A waveform extractor object +# load_if_exists : bool, default: False +# Whether to load precomputed similarity, if is already exists. +# method: str, default: "cosine_similarity" +# The method to compute the similarity +# waveform_extractor_other: WaveformExtractor, default: None +# A second waveform extractor object + +# Returns +# ------- +# similarity: np.array +# The similarity matrix +# """ +# if waveform_extractor_other is None: +# if load_if_exists and waveform_extractor.is_extension(TemplateSimilarityCalculator.extension_name): +# tmc = waveform_extractor.load_extension(TemplateSimilarityCalculator.extension_name) +# else: +# tmc = TemplateSimilarityCalculator(waveform_extractor) +# tmc.set_params(method=method) +# tmc.run() +# similarity = tmc.get_data() +# return similarity +# else: +# return _compute_template_similarity(waveform_extractor, waveform_extractor_other, method) def check_equal_template_with_distribution_overlap( diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 30e5881024..430523cf99 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,30 +1,31 @@ import unittest -from spikeinterface import extract_waveforms, WaveformExtractor -from spikeinterface.extractors import toy_example -from spikeinterface.postprocessing import TemplateMetricsCalculator +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputeTemplateMetrics -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite -class TemplateMetricsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = TemplateMetricsCalculator - extension_data_names = ["metrics"] - extension_function_kwargs_list = [dict(), dict(upsampling_factor=2)] - exact_same_content = False +class TemplateMetricsTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeTemplateMetrics + extension_function_kwargs_list = [ + dict(), + dict(upsampling_factor=2), + dict(include_multi_channel_metrics=True), + ] - def test_sparse_metrics(self): - tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) - print(tm_sparse) + # def test_sparse_metrics(self): + # tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) + # print(tm_sparse) - def test_multi_channel_metrics(self): - tm_multi = self.extension_class.get_extension_function()(self.we1, include_multi_channel_metrics=True) - print(tm_multi) + # def test_multi_channel_metrics(self): + # tm_multi = self.extension_class.get_extension_function()(self.we1, include_multi_channel_metrics=True) + # print(tm_multi) if __name__ == "__main__": - test = TemplateMetricsExtensionTest() + test = TemplateMetricsTest() test.setUp() test.test_extension() - test.test_multi_channel_metrics() + # test.test_extension() + # test.test_multi_channel_metrics() diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 210954bbc4..646fccf0fa 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,28 +1,30 @@ import unittest -from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, TemplateSimilarityCalculator +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -class SimilarityExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = TemplateSimilarityCalculator - extension_data_names = ["similarity"] +class SimilarityExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeTemplateSimilarity + extension_function_kwargs_list = [ + dict(method="cosine_similarity"), + ] - # extend common test - def test_check_equal_template_with_distribution_overlap(self): - we = self.we1 - for unit_id0 in we.unit_ids: - waveforms0 = we.get_waveforms(unit_id0) - for unit_id1 in we.unit_ids: - if unit_id0 == unit_id1: - continue - waveforms1 = we.get_waveforms(unit_id1) - check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + # # extend common test + # def test_check_equal_template_with_distribution_overlap(self): + # we = self.we1 + # for unit_id0 in we.unit_ids: + # waveforms0 = we.get_waveforms(unit_id0) + # for unit_id1 in we.unit_ids: + # if unit_id0 == unit_id1: + # continue + # waveforms1 = we.get_waveforms(unit_id1) + # check_equal_template_with_distribution_overlap(waveforms0, waveforms1) +# TODO check_equal_template_with_distribution_overlap if __name__ == "__main__": test = SimilarityExtensionTest() test.setUp() test.test_extension() - test.test_check_equal_template_with_distribution_overlap() From d3ac13f1b798885a3dc049014a6dcb4592e31182 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 24 Jan 2024 18:36:26 +0100 Subject: [PATCH 027/192] Make spike_amplitudes use node_piepeline machinery --- src/spikeinterface/core/node_pipeline.py | 10 +- src/spikeinterface/core/sortingresult.py | 62 +++++- .../postprocessing/spike_amplitudes.py | 197 +++++++----------- .../tests/test_spike_amplitudes.py | 38 +--- 4 files changed, 149 insertions(+), 158 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index e8f0c0d5c3..faef731e34 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -41,6 +41,11 @@ ] +spike_peak_dtype = base_peak_dtype + [ + ("unit_index", "int64"), +] + + class PipelineNode: def __init__( self, @@ -189,7 +194,7 @@ def get_trace_margin(self): return 0 def get_dtype(self): - return base_peak_dtype + return spike_peak_dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks @@ -223,12 +228,13 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def sorting_to_peaks(sorting, extremum_channel_inds): spikes = sorting.to_spike_vector() - peaks = np.zeros(spikes.size, dtype=base_peak_dtype) + peaks = np.zeros(spikes.size, dtype=spike_peak_dtype) peaks["sample_index"] = spikes["sample_index"] extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]] peaks["amplitude"] = 0.0 peaks["segment_index"] = spikes["segment_index"] + peaks["unit_index"] = spikes["unit_index"] return peaks diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index ce46938a10..83d7f4e975 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -735,7 +735,27 @@ def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) ## extensions zone - def compute(self, extension_name, save=True, **params): + + + + def compute(self, input, save=True, **params): + """ + Compute one extension or several extension. + Internally calling compute_one_extension() or compute_several_extensions() depending th input type. + + Parameters + ---------- + input: str or dict + If the input is a string then compute one extension with compute_one_extension(extension_name=input, ...) + If the input is a dict then compute several extension with compute_several_extensions(extensions=input) + """ + if isinstance(input, str): + self.compute_one_extension(extension_name=input, save=save, **params) + elif isinstance(input, dict): + assert len(params) == 0, "Too many arguments for SortingResult.compute_several_extensions()" + self.compute_several_extensions(extensions=input, save=save) + + def compute_one_extension(self, extension_name, save=True, **params): """ Compute one extension @@ -761,6 +781,7 @@ def compute(self, extension_name, save=True, **params): -------- >>> extension = sortres.compute("waveforms", **some_params) + >>> extension = sortres.compute_one_extension("waveforms", **some_params) >>> wfs = extension.data["waveforms"] """ @@ -790,6 +811,36 @@ def compute(self, extension_name, save=True, **params): # OR return extension_instance.data + def compute_several_extensions(self, extensions, save=True): + """ + Compute several extensions + + Parameters + ---------- + extensions: dict + Key are extension_name and values are params. + save: bool, default True + It the extension can be saved then it is saved. + If not then the extension will only live in memory as long as the object is deleted. + save=False is convinient to try some parameters without changing an already saved extension. + + Returns + ------- + No return + + Examples + -------- + + >>> sortres.compute({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std", ]} }) + >>> sortres.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}}) + + """ + # TODO this is a simple implementation + # this will be improved with nodepipeline!!! + for extension_name, extension_params in extensions.items(): + self.compute_one_extension(self, extension_name, save=save, **extension_params) + + def get_saved_extension_names(self): """ Get extension saved in folder or zarr that can be loaded. @@ -1035,6 +1086,11 @@ def _set_params(self, **params): def _select_extension_data(self, unit_ids): # must be implemented in subclass raise NotImplementedError + + def _get_pipeline_nodes(self): + # must be implemented in subclass only if use_nodepipeline=True + raise NotImplementedError + # ####### @@ -1326,5 +1382,7 @@ def _save_params(self): extension_group = self._get_zarr_extension_group(mode="r+") extension_group.attrs["params"] = check_json(params_to_save) - + def get_pipeline_nodes(self): + assert self.use_nodepipeline, "ResultExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True" + return self._get_pipeline_nodes() diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index f9f15551a9..c0c2004028 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -1,5 +1,5 @@ import numpy as np -import shutil +import warnings from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs @@ -8,7 +8,7 @@ # from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension from spikeinterface.core.sortingresult import register_result_extension, ResultExtension - +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type class ComputeSpikeAmplitudes(ResultExtension): """ @@ -39,8 +39,7 @@ class ComputeSpikeAmplitudes(ResultExtension): extension_name = "spike_amplitudes" depend_on = ["fast_templates|templates", ] need_recording = True - # TODO: implement this as a pipeline - use_nodepipeline = False + use_nodepipeline = True def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) @@ -62,58 +61,47 @@ def _select_extension_data(self, unit_ids): return new_data - def _run(self, **job_kwargs): - if not self.sorting_result.has_recording(): - self.sorting_result.delete_extension(ComputeSpikeAmplitudes.extension_name) - raise ValueError("compute_spike_amplitudes() cannot run with a SortingResult in recordless mode.") - job_kwargs = fix_job_kwargs(job_kwargs) - sorting_result = self.sorting_result - recording = sorting_result.recording - sorting = sorting_result.sorting + def _get_pipeline_nodes(self): - all_spikes = sorting.to_spike_vector() - self._all_spikes = all_spikes + recording = self.sorting_result.recording + sorting = self.sorting_result.sorting peak_sign = self.params["peak_sign"] return_scaled = self.params["return_scaled"] - extremum_channels_index = get_template_extremum_channel(sorting_result, peak_sign=peak_sign, outputs="index") - peak_shifts = get_template_extremum_channel_peak_shift(sorting_result, peak_sign=peak_sign) - - # put extremum_channels_index and peak_shifts in vector way - extremum_channels_index = np.array( - [extremum_channels_index[unit_id] for unit_id in sorting.unit_ids], dtype="int64" - ) - peak_shifts = np.array([peak_shifts[unit_id] for unit_id in sorting.unit_ids], dtype="int64") + extremum_channels_indices = get_template_extremum_channel(self.sorting_result, peak_sign=peak_sign, outputs="index") + peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_result, peak_sign=peak_sign) if return_scaled: # check if has scaled values: if not recording.has_scaled_traces(): - print("Setting 'return_scaled' to False") + warnings.warn("Recording doesn't have scaled traces! Setting 'return_scaled' to False") return_scaled = False - # and run - func = _spike_amplitudes_chunk - init_func = _init_worker_spike_amplitudes - n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) - init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs + spike_retriever_node = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices + ) + spike_amplitudes_node = SpikeAmplitudeNode( + recording, + parents=[spike_retriever_node], + return_output=True, + peak_shifts=peak_shifts, + return_scaled=return_scaled, + ) + nodes = [spike_retriever_node, spike_amplitudes_node] + return nodes + + def _run(self, **job_kwargs): + # TODO later gather to disk when format="binary_folder" + job_kwargs = fix_job_kwargs(job_kwargs) + nodes = self.get_pipeline_nodes() + amps = run_node_pipeline( + self.sorting_result.recording, nodes, job_kwargs=job_kwargs, job_name="spike_amplitudes", gather_mode="memory" ) - # out = processor.run() - # amps, segments = zip(*out) - # amps = np.concatenate(amps) - # segments = np.concatenate(segments) - - # for segment_index in range(recording.get_num_segments()): - # mask = segments == segment_index - # amps_seg = amps[mask] - # self._extension_data[f"amplitude_segment_{segment_index}"] = amps_seg - amps = processor.run() - amps = np.concatenate(amps) self.data["amplitudes"] = amps + # def get_data(self, outputs="concatenated"): # """ # Get computed spike amplitudes. @@ -150,99 +138,68 @@ def _run(self, **job_kwargs): # amplitudes_by_unit[segment_index][unit_id] = amps # return amplitudes_by_unit - # @staticmethod - # def get_extension_function(): - # return compute_spike_amplitudes -# WaveformExtractor.register_extension(SpikeAmplitudesCalculator) register_result_extension(ComputeSpikeAmplitudes) compute_spike_amplitudes = ComputeSpikeAmplitudes.function_factory() -# def compute_spike_amplitudes( -# sorting_result, load_if_exists=False, peak_sign="neg", return_scaled=True, outputs="concatenated", **job_kwargs -# ): - -# if load_if_exists and sorting_result.has_extension(SpikeAmplitudesCalculator.extension_name): -# sac = sorting_result.load_extension(SpikeAmplitudesCalculator.extension_name) -# else: -# sac = SpikeAmplitudesCalculator(sorting_result) -# sac.set_params(peak_sign=peak_sign, return_scaled=return_scaled) -# sac.run(**job_kwargs) - -# amps = sac.get_data(outputs=outputs) -# return amps - - -# compute_spike_amplitudes.__doc__.format(_shared_job_kwargs_doc) -def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index, peak_shifts, return_scaled): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["sorting"] = sorting - worker_ctx["return_scaled"] = return_scaled - worker_ctx["peak_shifts"] = peak_shifts - worker_ctx["min_shift"] = np.min(peak_shifts) - worker_ctx["max_shifts"] = np.max(peak_shifts) - worker_ctx["all_spikes"] = sorting.to_spike_vector(concatenated=False) - worker_ctx["extremum_channels_index"] = extremum_channels_index - - return worker_ctx - - -def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - all_spikes = worker_ctx["all_spikes"] - recording = worker_ctx["recording"] - return_scaled = worker_ctx["return_scaled"] - peak_shifts = worker_ctx["peak_shifts"] - - seg_size = recording.get_num_samples(segment_index=segment_index) - - spike_times = all_spikes[segment_index]["sample_index"] - spike_labels = all_spikes[segment_index]["unit_index"] - - d = np.diff(spike_times) - assert np.all(d >= 0) - - i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) - n_spikes = i1 - i0 - amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype()) - - if i0 != i1: - # some spike in the chunk +class SpikeAmplitudeNode(PipelineNode): + def __init__( + self, + recording, + parents=None, + return_output=True, + peak_shifts=None, + return_scaled=True, + ): + PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) + self.return_scaled = return_scaled + if return_scaled and recording.has_scaled(): + self._dtype = np.float32 + self._gains = recording.get_channel_gains() + self._offsets = recording.get_channel_gains() + else: + self._dtype = recording.get_dtype() + self._gains = None + self._offsets = None + spike_retriever = find_parent_of_type(parents, SpikeRetriever) + assert isinstance( + spike_retriever, SpikeRetriever + ), "SpikeAmplitudeNode needs a single SpikeRetriever as a parent" + # put peak_shifts in vector way + self._peak_shifts = np.array(list(peak_shifts.values()), dtype="int64") + self._margin = np.max(np.abs(self._peak_shifts)) + self._kwargs.update( + peak_shifts=peak_shifts, + return_scaled=return_scaled, + ) - extremum_channels_index = worker_ctx["extremum_channels_index"] + def get_dtype(self): + return self._dtype - sample_inds = spike_times[i0:i1].copy() - labels = spike_labels[i0:i1] + def compute(self, traces, peaks): + sample_indices = peaks["sample_index"].copy() + unit_index = peaks["unit_index"] + chan_inds = peaks["channel_index"] # apply shifts per spike - sample_inds += peak_shifts[labels] - - # get channels per spike - chan_inds = extremum_channels_index[labels] - - # prevent border accident due to shift - sample_inds[sample_inds < 0] = 0 - sample_inds[sample_inds >= seg_size] = seg_size - 1 - - first = np.min(sample_inds) - last = np.max(sample_inds) - sample_inds -= first - - # load trace in memory - traces = recording.get_traces( - start_frame=first, end_frame=last + 1, segment_index=segment_index, return_scaled=return_scaled - ) + sample_indices += self._peak_shifts[unit_index] # and get amplitudes - amplitudes = traces[sample_inds, chan_inds] + amplitudes = traces[sample_indices, chan_inds] + + # and scale + if self._gains is not None: + traces = traces.astype("float32") * self._gains + self._offsets + amplitudes = amplitudes.astype('float32', copy=True) + amplitudes *= self._gains[chan_inds] + amplitudes += self._offsets[chan_inds] - # segments = np.zeros(amplitudes.size, dtype="int64") + segment_index + return amplitudes - # return amplitudes, segments - return amplitudes + def get_trace_margin(self): + return self._margin diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 60416570c1..86cff185ed 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -2,9 +2,6 @@ import numpy as np from spikeinterface.postprocessing import ComputeSpikeAmplitudes - -# from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite - from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite @@ -14,40 +11,13 @@ class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCa dict(), ] -# class SpikeAmplitudesExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): -# extension_class = SpikeAmplitudesCalculator -# extension_data_names = ["amplitude_segment_0"] -# extension_function_kwargs_list = [ -# dict(peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1), -# dict(peak_sign="neg", outputs="by_unit", chunk_size=10000, n_jobs=1), -# ] - -# def test_scaled(self): -# amplitudes_scaled = self.extension_class.get_extension_function()( -# self.we1, peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1, return_scaled=True -# ) -# amplitudes_unscaled = self.extension_class.get_extension_function()( -# self.we1, peak_sign="neg", outputs="concatenated", chunk_size=10000, n_jobs=1, return_scaled=False -# ) -# gain = self.we1.recording.get_channel_gains()[0] - -# assert np.allclose(amplitudes_scaled[0], amplitudes_unscaled[0] * gain) - -# def test_parallel(self): -# amplitudes1 = self.extension_class.get_extension_function()( -# self.we1, peak_sign="neg", load_if_exists=False, outputs="concatenated", chunk_size=10000, n_jobs=1 -# ) -# # TODO : fix multi processing for spike amplitudes!!!!!!! -# amplitudes2 = self.extension_class.get_extension_function()( -# self.we1, peak_sign="neg", load_if_exists=False, outputs="concatenated", chunk_size=10000, n_jobs=2 -# ) - -# assert np.array_equal(amplitudes1[0], amplitudes2[0]) - - if __name__ == "__main__": test = ComputeSpikeAmplitudesTest() test.setUp() test.test_extension() # test.test_scaled() # test.test_parallel() + + # for k, sorting_result in test.sorting_results.items(): + # print(sorting_result) + # print(sorting_result.get_extension("spike_amplitudes").data["amplitudes"].shape) From a3baef8b4f1ed8a527376ee8ca3435c612bac8e9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 24 Jan 2024 20:48:17 +0100 Subject: [PATCH 028/192] spike_amplitudes and spikelocations are using nodepipeline --- src/spikeinterface/postprocessing/__init__.py | 2 +- .../postprocessing/spike_amplitudes.py | 39 ++- .../postprocessing/spike_locations.py | 259 ++++++++---------- .../tests/test_spike_amplitudes.py | 5 +- .../tests/test_spike_locations.py | 34 +-- .../sortingcomponents/peak_localization.py | 13 +- 6 files changed, 169 insertions(+), 183 deletions(-) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 6f4112095d..c23e56355b 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -34,7 +34,7 @@ compute_isi_histograms_numba, ) -from .spike_locations import compute_spike_locations, SpikeLocationsCalculator +from .spike_locations import compute_spike_locations, ComputeSpikeLocations from .unit_localization import ( compute_unit_locations, diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index c0c2004028..63bbc47411 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -1,12 +1,10 @@ import numpy as np import warnings -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift -# from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension - from spikeinterface.core.sortingresult import register_result_extension, ResultExtension from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type @@ -16,7 +14,42 @@ class ComputeSpikeAmplitudes(ResultExtension): Computes the spike amplitudes. Need "templates" or "fast_templates" to be computed first. + Localize spikes in 2D or 3D with several methods given the template. + + Parameters + ---------- + sorting_result: SortingResult + A SortingResult object + ms_before : float, default: 0.5 + The left window, before a peak, in milliseconds + ms_after : float, default: 0.5 + The right window, after a peak, in milliseconds + spike_retriver_kwargs: dict + A dictionary to control the behavior for getting the maximum channel for each spike + This dictionary contains: + + * channel_from_template: bool, default: True + For each spike is the maximum channel computed from template or re estimated at every spikes + channel_from_template = True is old behavior but less acurate + channel_from_template = False is slower but more accurate + * radius_um: float, default: 50 + In case channel_from_template=False, this is the radius to get the true peak + * peak_sign, default: "neg" + In case channel_from_template=False, this is the peak sign. + method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + The localization method to use + method_kwargs : dict, default: dict() + Other kwargs depending on the method. + outputs : "concatenated" | "by_unit", default: "concatenated" + The output format + {} + Returns + ------- + spike_locations: np.array or list of dict + The spike locations. + - If "concatenated" all locations for all spikes and all units are concatenated + - If "by_unit", locations are returned as a list (for segments) of dictionaries (for units) 1. Determine the max channel per unit. 2. Then a "peak_shift" is estimated because for some sorters the spike index is not always at the peak. diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index dfa940b979..5b88b6a761 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -1,148 +1,25 @@ import numpy as np from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift -from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension -from spikeinterface.core.node_pipeline import SpikeRetriever +from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline -class SpikeLocationsCalculator(BaseWaveformExtractorExtension): - """ - Computes spike locations from WaveformExtractor. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - A waveform extractor object - """ - - extension_name = "spike_locations" - - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - - extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") - self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - - def _set_params( - self, - ms_before=0.5, - ms_after=0.5, - spike_retriver_kwargs=dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg", - ), - method="center_of_mass", - method_kwargs={}, - ): - params = dict( - ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs, method=method - ) - params.update(**method_kwargs) - return params - - def _select_extension_data(self, unit_ids): - old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - - spike_mask = np.isin(self.spikes["unit_index"], unit_inds) - new_spike_locations = self._extension_data["spike_locations"][spike_mask] - return dict(spike_locations=new_spike_locations) - - def _run(self, **job_kwargs): - """ - This function first transforms the sorting object into a `peaks` numpy array and then - uses the`sortingcomponents.peak_localization.localize_peaks()` function to triangulate - spike locations. - """ - from spikeinterface.sortingcomponents.peak_localization import _run_localization_from_peak_source - - job_kwargs = fix_job_kwargs(job_kwargs) - - we = self.waveform_extractor +# TODO job_kwargs - extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - params = self._params.copy() - spike_retriver_kwargs = params.pop("spike_retriver_kwargs") - spike_retriever = SpikeRetriever( - we.recording, we.sorting, extremum_channel_inds=extremum_channel_inds, **spike_retriver_kwargs - ) - spike_locations = _run_localization_from_peak_source(we.recording, spike_retriever, **params, **job_kwargs) - - self._extension_data["spike_locations"] = spike_locations - - def get_data(self, outputs="concatenated"): - """ - Get computed spike locations - - Parameters - ---------- - outputs : "concatenated" | "by_unit", default: "concatenated" - The output format - - Returns - ------- - spike_locations : np.array or dict - The spike locations as a structured array (outputs="concatenated") or - as a dict with units as key and spike locations as values. - """ - we = self.waveform_extractor - sorting = we.sorting - - if outputs == "concatenated": - return self._extension_data["spike_locations"] - - elif outputs == "by_unit": - locations_by_unit = [] - for segment_index in range(self.waveform_extractor.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left") - i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right") - spikes = self.spikes[i0:i1] - locations = self._extension_data["spike_locations"][i0:i1] - - locations_by_unit.append({}) - for unit_ind, unit_id in enumerate(sorting.unit_ids): - mask = spikes["unit_index"] == unit_ind - locations_by_unit[segment_index][unit_id] = locations[mask] - return locations_by_unit - - @staticmethod - def get_extension_function(): - return compute_spike_locations - - -WaveformExtractor.register_extension(SpikeLocationsCalculator) - - -def compute_spike_locations( - waveform_extractor, - load_if_exists=False, - ms_before=0.5, - ms_after=0.5, - spike_retriver_kwargs=dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg", - ), - method="center_of_mass", - method_kwargs={}, - outputs="concatenated", - **job_kwargs, -): +class ComputeSpikeLocations(ResultExtension): """ Localize spikes in 2D or 3D with several methods given the template. Parameters ---------- - waveform_extractor : WaveformExtractor - A waveform extractor object - load_if_exists : bool, default: False - Whether to load precomputed spike locations, if they already exist + sorting_result: SortingResult + A SortingResult object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 @@ -172,23 +49,113 @@ def compute_spike_locations( spike_locations: np.array or list of dict The spike locations. - If "concatenated" all locations for all spikes and all units are concatenated - - If "by_unit", locations are returned as a list (for segments) of dictionaries (for units) - """ - if load_if_exists and waveform_extractor.is_extension(SpikeLocationsCalculator.extension_name): - slc = waveform_extractor.load_extension(SpikeLocationsCalculator.extension_name) - else: - slc = SpikeLocationsCalculator(waveform_extractor) - slc.set_params( - ms_before=ms_before, - ms_after=ms_after, - spike_retriver_kwargs=spike_retriver_kwargs, - method=method, - method_kwargs=method_kwargs, + - If "by_unit", locations are returned as a list (for segments) of dictionaries (for units) """ + + extension_name = "spike_locations" + depend_on = ["fast_templates|templates", ] + need_recording = True + use_nodepipeline = True + + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) + + extremum_channel_inds = get_template_extremum_channel(self.sorting_result, outputs="index") + self.spikes = self.sorting_result.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + + def _set_params( + self, + ms_before=0.5, + ms_after=0.5, + spike_retriver_kwargs=None, + method="center_of_mass", + method_kwargs={}, + ): + spike_retriver_kwargs_ = dict( + channel_from_template=True, + radius_um=50, + peak_sign="neg", ) - slc.run(**job_kwargs) + if spike_retriver_kwargs is not None: + spike_retriver_kwargs_.update(spike_retriver_kwargs) + params = dict( + ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs_, method=method, + method_kwargs=method_kwargs + ) + return params + + def _select_extension_data(self, unit_ids): + old_unit_ids = self.sorting_result.unit_ids + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) + + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) + new_spike_locations = self.data["spike_locations"][spike_mask] + return dict(spike_locations=new_spike_locations) + + def _get_pipeline_nodes(self): + from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes - locs = slc.get_data(outputs=outputs) - return locs + recording = self.sorting_result.recording + sorting = self.sorting_result.sorting + peak_sign=self.params["spike_retriver_kwargs"]["peak_sign"] + extremum_channels_indices = get_template_extremum_channel(self.sorting_result, peak_sign=peak_sign, outputs="index") + retriever = SpikeRetriever( + recording, + sorting, + channel_from_template=True, + extremum_channel_inds=extremum_channels_indices, + ) + nodes = get_localization_pipeline_nodes( + recording, retriever, method=self.params["method"], ms_before=self.params["ms_before"], ms_after=self.params["ms_after"], **self.params["method_kwargs"] + ) + return nodes -compute_spike_locations.__doc__.format(_shared_job_kwargs_doc) + def _run(self, **job_kwargs): + # TODO later gather to disk when format="binary_folder" + job_kwargs = fix_job_kwargs(job_kwargs) + nodes = self.get_pipeline_nodes() + spike_locations = run_node_pipeline( + self.sorting_result.recording, nodes, job_kwargs=job_kwargs, job_name="spike_locations", gather_mode="memory" + ) + self.data["spike_locations"] = spike_locations + + # def get_data(self, outputs="concatenated"): + # """ + # Get computed spike locations + + # Parameters + # ---------- + # outputs : "concatenated" | "by_unit", default: "concatenated" + # The output format + + # Returns + # ------- + # spike_locations : np.array or dict + # The spike locations as a structured array (outputs="concatenated") or + # as a dict with units as key and spike locations as values. + # """ + # we = self.sorting_result + # sorting = we.sorting + + # if outputs == "concatenated": + # return self._extension_data["spike_locations"] + + # elif outputs == "by_unit": + # locations_by_unit = [] + # for segment_index in range(self.sorting_result.get_num_segments()): + # i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left") + # i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right") + # spikes = self.spikes[i0:i1] + # locations = self._extension_data["spike_locations"][i0:i1] + + # locations_by_unit.append({}) + # for unit_ind, unit_id in enumerate(sorting.unit_ids): + # mask = spikes["unit_index"] == unit_ind + # locations_by_unit[segment_index][unit_id] = locations[mask] + # return locations_by_unit + + +ComputeSpikeLocations.__doc__.format(_shared_job_kwargs_doc) + +register_result_extension(ComputeSpikeLocations) +compute_spike_locations = ComputeSpikeLocations.function_factory() diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 86cff185ed..c4c8bb7974 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -8,15 +8,14 @@ class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeSpikeAmplitudes extension_function_kwargs_list = [ - dict(), + dict(return_scaled=True), + dict(return_scaled=False), ] if __name__ == "__main__": test = ComputeSpikeAmplitudesTest() test.setUp() test.test_extension() - # test.test_scaled() - # test.test_parallel() # for k, sorting_result in test.sorting_results.items(): # print(sorting_result) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index d047a2f67e..98b8d19c2b 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -1,39 +1,25 @@ import unittest import numpy as np -from spikeinterface.postprocessing import SpikeLocationsCalculator +from spikeinterface.postprocessing import ComputeSpikeLocations +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite -class SpikeLocationsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = SpikeLocationsCalculator - extension_data_names = ["spike_locations"] + +class SpikeLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeSpikeLocations extension_function_kwargs_list = [ - dict( - method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=True) - ), - dict( - method="center_of_mass", chunk_size=10000, n_jobs=1, spike_retriver_kwargs=dict(channel_from_template=False) - ), - dict(method="center_of_mass", chunk_size=10000, n_jobs=1, outputs="by_unit"), - dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), - dict(method="monopolar_triangulation", chunk_size=10000, n_jobs=1, outputs="by_unit"), + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), # chunk_size=10000, n_jobs=1, + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), + dict(method="center_of_mass", ), + dict(method="monopolar_triangulation"), # , chunk_size=10000, n_jobs=1 + dict(method="grid_convolution"), # , chunk_size=10000, n_jobs=1 ] - def test_parallel(self): - locs_mono1 = self.extension_class.get_extension_function()( - self.we1, method="monopolar_triangulation", chunk_size=10000, n_jobs=1 - ) - locs_mono2 = self.extension_class.get_extension_function()( - self.we1, method="monopolar_triangulation", chunk_size=10000, n_jobs=2 - ) - - assert np.array_equal(locs_mono1[0], locs_mono2[0]) if __name__ == "__main__": test = SpikeLocationsExtensionTest() test.setUp() test.test_extension() - test.test_parallel() diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index 75c8f7f03f..c13517f80f 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -28,7 +28,7 @@ from .tools import get_prototype_spike -def _run_localization_from_peak_source( +def get_localization_pipeline_nodes( recording, peak_source, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs ): # use by localize_peaks() and compute_spike_locations() @@ -73,10 +73,7 @@ def _run_localization_from_peak_source( LocalizeGridConvolution(recording, parents=[peak_source, extract_dense_waveforms], **method_kwargs), ] - job_name = f"localize peaks using {method}" - peak_locations = run_node_pipeline(recording, pipeline_nodes, job_kwargs, job_name=job_name, squeeze_output=True) - - return peak_locations + return pipeline_nodes def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): @@ -104,10 +101,14 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ Array with estimated location for each spike. The dtype depends on the method. ("x", "y") or ("x", "y", "z", "alpha"). """ + _, job_kwargs = split_job_kwargs(kwargs) peak_retriever = PeakRetriever(recording, peaks) - peak_locations = _run_localization_from_peak_source( + pipeline_nodes = get_localization_pipeline_nodes( recording, peak_retriever, method=method, ms_before=ms_before, ms_after=ms_after, **kwargs ) + job_name = f"localize peaks using {method}" + peak_locations = run_node_pipeline(recording, pipeline_nodes, job_kwargs, job_name=job_name, squeeze_output=True) + return peak_locations From f2928cc37aaad8dd5c363828ad9d848beeed4639 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 25 Jan 2024 17:55:10 +0100 Subject: [PATCH 029/192] Make amplitude scaling with nodepipeline based on Alssio PR. --- src/spikeinterface/core/node_pipeline.py | 33 +- .../core/tests/test_node_pipeline.py | 3 +- src/spikeinterface/postprocessing/__init__.py | 2 +- .../postprocessing/amplitude_scalings.py | 686 ++++++++++-------- .../tests/test_amplitude_scalings.py | 59 +- 5 files changed, 419 insertions(+), 364 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index faef731e34..5ec5553a4b 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -166,10 +166,13 @@ class SpikeRetriever(PeakSource): peak_sign: "neg" | "pos", default: "neg" Peak sign to find the max channel. Used only when channel_from_template=False + include_spikes_in_margin: bool, default False + If not None then spikes in margin are added and an extra filed in dtype is added """ def __init__( - self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg", + include_spikes_in_margin=False, ): PipelineNode.__init__(self, recording, return_output=False) @@ -177,7 +180,13 @@ def __init__( assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary" - self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) + self._dtype = spike_peak_dtype + + self.include_spikes_in_margin = include_spikes_in_margin + if include_spikes_in_margin is not None: + self._dtype = spike_peak_dtype + [("in_margin", "bool")] + + self.peaks = sorting_to_peaks(sorting, extremum_channel_inds, self._dtype) if not channel_from_template: channel_distance = get_channel_distances(recording) @@ -189,24 +198,36 @@ def __init__( for segment_index in range(recording.get_num_segments()): i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) + def get_trace_margin(self): return 0 def get_dtype(self): - return spike_peak_dtype + return self._dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + if self.include_spikes_in_margin: + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]) + else: + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces local_peaks = local_peaks.copy() local_peaks["sample_index"] -= start_frame - max_margin + # handle flag for margin + if self.include_spikes_in_margin: + local_peaks["in_margin"][:] = False + mask = local_peaks["sample_index"] < max_margin + local_peaks["in_margin"][mask] = True + mask = local_peaks["sample_index"] >= traces.shape[0] - max_margin + local_peaks["in_margin"][mask] = True + if not self.channel_from_template: # handle channel spike per spike for i, peak in enumerate(local_peaks): @@ -226,9 +247,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) -def sorting_to_peaks(sorting, extremum_channel_inds): +def sorting_to_peaks(sorting, extremum_channel_inds, dtype): spikes = sorting.to_spike_vector() - peaks = np.zeros(spikes.size, dtype=spike_peak_dtype) + peaks = np.zeros(spikes.size, dtype=dtype) peaks["sample_index"] = spikes["sample_index"] extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]] diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index ca30a5f8c9..e3a793b2d5 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -14,6 +14,7 @@ PipelineNode, ExtractDenseWaveforms, sorting_to_peaks, + spike_peak_dtype ) @@ -78,7 +79,7 @@ def test_run_node_pipeline(): # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - peaks = sorting_to_peaks(sorting, extremum_channel_inds) + peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) peak_retriever = PeakRetriever(recording, peaks) # channel index is from template diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index c23e56355b..98b6ec6cee 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -42,7 +42,7 @@ compute_center_of_mass, ) -from .amplitude_scalings import compute_amplitude_scalings, AmplitudeScalingsCalculator +from .amplitude_scalings import compute_amplitude_scalings, ComputeAmplitudeScalings from .alignsorting import align_sorting, AlignSortingExtractor diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 2aaf4d20b9..331c122929 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -3,38 +3,81 @@ from spikeinterface.core import ChannelSparsity, get_chunk_with_margin from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs -from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift -from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension + +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type + +from ..core.template_tools import _get_dense_templates_array, _get_nbefore # DEBUG = True -class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): +# TODO extra sparsity and job_kwargs handling + +class ComputeAmplitudeScalings(ResultExtension): """ - Computes amplitude scalings from WaveformExtractor. + Computes the amplitude scalings from a WaveformExtractor. + + Parameters + ---------- + sorting_result: SortingResult + A SortingResult object + sparsity: ChannelSparsity or None, default: None + If waveforms are not sparse, sparsity is required if the number of channels is greater than + `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. + max_dense_channels: int, default: 16 + Maximum number of channels to allow running without sparsity. To compute amplitude scaling using + dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. + ms_before : float or None, default: None + The cut out to apply before the spike peak to extract local waveforms. + If None, the WaveformExtractor ms_before is used. + ms_after : float or None, default: None + The cut out to apply after the spike peak to extract local waveforms. + If None, the WaveformExtractor ms_after is used. + handle_collisions: bool, default: True + Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes + (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a + multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. + delta_collision_ms: float, default: 2 + The maximum time difference in ms before and after a spike to gather colliding spikes. + load_if_exists : bool, default: False + Whether to load precomputed spike amplitudes, if they already exist. + outputs: "concatenated" | "by_unit", default: "concatenated" + How the output should be returned + {} + + Returns + ------- + amplitude_scalings: np.array or list of dict + The amplitude scalings. + - If "concatenated" all amplitudes for all spikes and all units are concatenated + - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) """ extension_name = "amplitude_scalings" - handle_sparsity = True + depend_on = ["fast_templates|templates", ] + need_recording = True + use_nodepipeline = True - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) - extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") - self.spikes = self.waveform_extractor.sorting.to_spike_vector( - extremum_channel_inds=extremum_channel_inds, use_cache=False - ) + # extremum_channel_inds = get_template_extremum_channel(self.sorting_result, outputs="index") + # self.spikes = self.sorting_result.sorting.to_spike_vector( + # extremum_channel_inds=extremum_channel_inds, use_cache=False + # ) self.collisions = None def _set_params( self, - sparsity, - max_dense_channels, - ms_before, - ms_after, - handle_collisions, - delta_collision_ms, + sparsity=None, + max_dense_channels=16, + ms_before=None, + ms_after=None, + handle_collisions=True, + delta_collision_ms=2, ): params = dict( sparsity=sparsity, @@ -47,329 +90,339 @@ def _set_params( return params def _select_extension_data(self, unit_ids): - old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) - spike_mask = np.isin(self.spikes["unit_index"], unit_inds) - new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] - return dict(amplitude_scalings=new_amplitude_scalings) + spikes = self.sorting_result.sorting.to_spike_vector() + keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - def _run(self, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - we = self.waveform_extractor - recording = we.recording - nbefore = we.nbefore - nafter = we.nafter - ms_before = self._params["ms_before"] - ms_after = self._params["ms_after"] + new_data = dict() + new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_spike_mask] + if self.params["handle_collisions"]: + new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask] + return new_data - # collisions - handle_collisions = self._params["handle_collisions"] - delta_collision_ms = self._params["delta_collision_ms"] - delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) - return_scaled = we._params["return_scaled"] + def _get_pipeline_nodes(self): - if ms_before is not None: + recording = self.sorting_result.recording + sorting = self.sorting_result.sorting + + # TODO return_scaled is not any more a property of SortingResult this is hard coded for now + return_scaled = True + + all_templates = _get_dense_templates_array(self.sorting_result, return_scaled=return_scaled) + nbefore = _get_nbefore(self.sorting_result) + nafter = all_templates.shape[1] - nbefore + + # if ms_before / ms_after are set in params then the original templates are shorten + if self.params["ms_before"] is not None: + cut_out_before = int(self.params["ms_before"] * self.sorting_result.sampling_frequency / 1000.0) assert ( - ms_before <= we._params["ms_before"] - ), f"`ms_before` must be smaller than `ms_before` used in WaveformExractor: {we._params['ms_before']}" - if ms_after is not None: + cut_out_before <= nbefore + ), f"`ms_before` must be smaller than `ms_before` used in ComputeTemplates: {nbefore}" + else: + cut_out_before = nbefore + + if self.params["ms_after"] is not None: + cut_out_after = int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) assert ( - ms_after <= we._params["ms_after"] + cut_out_after <= nafter ), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}" + else: + cut_out_after = nafter - cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore - cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter + peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" + extremum_channels_indices = get_template_extremum_channel(self.sorting_result, peak_sign=peak_sign, outputs="index") - if we.is_sparse() and self._params["sparsity"] is None: - sparsity = we.sparsity - elif we.is_sparse() and self._params["sparsity"] is not None: - sparsity = self._params["sparsity"] + # collisions + handle_collisions = self.params["handle_collisions"] + delta_collision_ms = self.params["delta_collision_ms"] + delta_collision_samples = int(delta_collision_ms / 1000 * self.sorting_result.sampling_frequency) + + if self.sorting_result.is_sparse() and self.params["sparsity"] is None: + sparsity = self.sorting_result.sparsity + elif self.sorting_result.is_sparse() and self.params["sparsity"] is not None: + raise NotImplementedError + sparsity = self.params["sparsity"] # assert provided sparsity is sparser than the one in the waveform extractor waveform_sparsity = we.sparsity assert np.all( np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0 ), "The provided sparsity needs to be sparser than the one in the waveform extractor!" - elif not we.is_sparse() and self._params["sparsity"] is not None: - sparsity = self._params["sparsity"] + elif not self.sorting_result.is_sparse() and self.params["sparsity"] is not None: + raise NotImplementedError + # sparsity = self.params["sparsity"] else: - if self._params["max_dense_channels"] is not None: - assert recording.get_num_channels() <= self._params["max_dense_channels"], "" - sparsity = ChannelSparsity.create_dense(we) + if self.params["max_dense_channels"] is not None: + assert recording.get_num_channels() <= self.params["max_dense_channels"], "" + sparsity = ChannelSparsity.create_dense(self.sorting_result) sparsity_mask = sparsity.mask - all_templates = we.get_all_templates() - - # precompute segment slice - segment_slices = [] - for segment_index in range(we.get_num_segments()): - i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1]) - segment_slices.append(slice(i0, i1)) - - # and run - func = _amplitude_scalings_chunk - init_func = _init_worker_amplitude_scalings - n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) - job_kwargs["n_jobs"] = n_jobs - init_args = ( + + spike_retriever_node = SpikeRetriever( recording, - self.spikes, - all_templates, - segment_slices, - sparsity_mask, - nbefore, - nafter, - cut_out_before, - cut_out_after, - return_scaled, - handle_collisions, - delta_collision_samples, + sorting, + channel_from_template=True, + extremum_channel_inds=extremum_channels_indices, + include_spikes_in_margin=True, ) - processor = ChunkRecordingExecutor( + amplitude_scalings_node = AmplitudeScalingNode( recording, - func, - init_func, - init_args, - handle_returns=True, - job_name="extract amplitude scalings", - **job_kwargs, + parents=[spike_retriever_node], + return_output=True, + all_templates=all_templates, + sparsity_mask=sparsity_mask, + nbefore=nbefore, + nafter=nafter, + cut_out_before=cut_out_before, + cut_out_after=cut_out_after, + return_scaled=return_scaled, + handle_collisions=handle_collisions, + delta_collision_samples=delta_collision_samples, ) - out = processor.run() - (amp_scalings, collisions) = zip(*out) - amp_scalings = np.concatenate(amp_scalings) + nodes = [spike_retriever_node, amplitude_scalings_node] + return nodes - collisions_dict = {} - if handle_collisions: - for collision in collisions: - collisions_dict.update(collision) - self.collisions = collisions_dict - # Note: collisions are note in _extension_data because they are not pickable. We only store the indices - self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - - self._extension_data["amplitude_scalings"] = amp_scalings - - def get_data(self, outputs="concatenated"): - """ - Get computed spike amplitudes. - Parameters - ---------- - outputs : "concatenated" | "by_unit", default: "concatenated" - The output format - - Returns - ------- - spike_amplitudes : np.array or dict - The spike amplitudes as an array (outputs="concatenated") or - as a dict with units as key and spike amplitudes as values. - """ - we = self.waveform_extractor - sorting = we.sorting - - if outputs == "concatenated": - return self._extension_data[f"amplitude_scalings"] - elif outputs == "by_unit": - amplitudes_by_unit = [] - for segment_index in range(we.get_num_segments()): - amplitudes_by_unit.append({}) - segment_mask = self.spikes["segment_index"] == segment_index - spikes_segment = self.spikes[segment_mask] - amp_scalings_segment = self._extension_data[f"amplitude_scalings"][segment_mask] - for unit_index, unit_id in enumerate(sorting.unit_ids): - unit_mask = spikes_segment["unit_index"] == unit_index - amp_scalings = amp_scalings_segment[unit_mask] - amplitudes_by_unit[segment_index][unit_id] = amp_scalings - return amplitudes_by_unit - - @staticmethod - def get_extension_function(): - return compute_amplitude_scalings - - -WaveformExtractor.register_extension(AmplitudeScalingsCalculator) - - -def compute_amplitude_scalings( - waveform_extractor, - sparsity=None, - max_dense_channels=16, - ms_before=None, - ms_after=None, - handle_collisions=True, - delta_collision_ms=2, - load_if_exists=False, - outputs="concatenated", - **job_kwargs, -): - """ - Computes the amplitude scalings from a WaveformExtractor. + def _run(self, **job_kwargs): + job_kwargs = fix_job_kwargs(job_kwargs) + nodes = self.get_pipeline_nodes() + amp_scalings, collision_mask = run_node_pipeline( + self.sorting_result.recording, nodes, job_kwargs=job_kwargs, job_name="amplitude_scalings", gather_mode="memory" + ) + self.data["amplitude_scalings"] = amp_scalings + if self.params["handle_collisions"]: + self.data["collision_mask"] = collision_mask + # TODO: make collisions "global" + # for collision in collisions: + # collisions_dict.update(collision) + # self.collisions = collisions_dict + # # Note: collisions are note in _extension_data because they are not pickable. We only store the indices + # self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) + + # def get_data(self, outputs="concatenated"): + # """ + # Get computed spike amplitudes. + # Parameters + # ---------- + # outputs : "concatenated" | "by_unit", default: "concatenated" + # The output format + + # Returns + # ------- + # spike_amplitudes : np.array or dict + # The spike amplitudes as an array (outputs="concatenated") or + # as a dict with units as key and spike amplitudes as values. + # """ + # we = self.sorting_result + # sorting = we.sorting + + # if outputs == "concatenated": + # return self._extension_data[f"amplitude_scalings"] + # elif outputs == "by_unit": + # amplitudes_by_unit = [] + # for segment_index in range(we.get_num_segments()): + # amplitudes_by_unit.append({}) + # segment_mask = self.spikes["segment_index"] == segment_index + # spikes_segment = self.spikes[segment_mask] + # amp_scalings_segment = self._extension_data[f"amplitude_scalings"][segment_mask] + # for unit_index, unit_id in enumerate(sorting.unit_ids): + # unit_mask = spikes_segment["unit_index"] == unit_index + # amp_scalings = amp_scalings_segment[unit_mask] + # amplitudes_by_unit[segment_index][unit_id] = amp_scalings + # return amplitudes_by_unit + + +register_result_extension(ComputeAmplitudeScalings) +compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory() + +# def compute_amplitude_scalings( +# waveform_extractor, +# sparsity=None, +# max_dense_channels=16, +# ms_before=None, +# ms_after=None, +# handle_collisions=True, +# delta_collision_ms=2, +# load_if_exists=False, +# outputs="concatenated", +# **job_kwargs, +# ): +# """ +# Computes the amplitude scalings from a WaveformExtractor. - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor object - sparsity: ChannelSparsity or None, default: None - If waveforms are not sparse, sparsity is required if the number of channels is greater than - `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. - max_dense_channels: int, default: 16 - Maximum number of channels to allow running without sparsity. To compute amplitude scaling using - dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. - ms_before : float or None, default: None - The cut out to apply before the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_before is used. - ms_after : float or None, default: None - The cut out to apply after the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_after is used. - handle_collisions: bool, default: True - Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes - (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a - multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. - delta_collision_ms: float, default: 2 - The maximum time difference in ms before and after a spike to gather colliding spikes. - load_if_exists : bool, default: False - Whether to load precomputed spike amplitudes, if they already exist. - outputs: "concatenated" | "by_unit", default: "concatenated" - How the output should be returned - {} +# Parameters +# ---------- +# waveform_extractor: WaveformExtractor +# The waveform extractor object +# sparsity: ChannelSparsity or None, default: None +# If waveforms are not sparse, sparsity is required if the number of channels is greater than +# `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. +# max_dense_channels: int, default: 16 +# Maximum number of channels to allow running without sparsity. To compute amplitude scaling using +# dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. +# ms_before : float or None, default: None +# The cut out to apply before the spike peak to extract local waveforms. +# If None, the WaveformExtractor ms_before is used. +# ms_after : float or None, default: None +# The cut out to apply after the spike peak to extract local waveforms. +# If None, the WaveformExtractor ms_after is used. +# handle_collisions: bool, default: True +# Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes +# (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a +# multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. +# delta_collision_ms: float, default: 2 +# The maximum time difference in ms before and after a spike to gather colliding spikes. +# load_if_exists : bool, default: False +# Whether to load precomputed spike amplitudes, if they already exist. +# outputs: "concatenated" | "by_unit", default: "concatenated" +# How the output should be returned +# {} + +# Returns +# ------- +# amplitude_scalings: np.array or list of dict +# The amplitude scalings. +# - If "concatenated" all amplitudes for all spikes and all units are concatenated +# - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) +# """ +# if load_if_exists and waveform_extractor.is_extension(AmplitudeScalingsCalculator.extension_name): +# sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) +# else: +# sac = AmplitudeScalingsCalculator(waveform_extractor) +# sac.set_params( +# sparsity=sparsity, +# max_dense_channels=max_dense_channels, +# ms_before=ms_before, +# ms_after=ms_after, +# handle_collisions=handle_collisions, +# delta_collision_ms=delta_collision_ms, +# ) +# sac.run(**job_kwargs) - Returns - ------- - amplitude_scalings: np.array or list of dict - The amplitude scalings. - - If "concatenated" all amplitudes for all spikes and all units are concatenated - - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) - """ - if load_if_exists and waveform_extractor.is_extension(AmplitudeScalingsCalculator.extension_name): - sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) - else: - sac = AmplitudeScalingsCalculator(waveform_extractor) - sac.set_params( - sparsity=sparsity, - max_dense_channels=max_dense_channels, - ms_before=ms_before, - ms_after=ms_after, +# amps = sac.get_data(outputs=outputs) +# return amps + + +# compute_amplitude_scalings.__doc__.format(_shared_job_kwargs_doc) + +class AmplitudeScalingNode(PipelineNode): + def __init__( + self, + recording, + parents, + return_output, + all_templates, + sparsity_mask, + nbefore, + nafter, + cut_out_before, + cut_out_after, + return_scaled, + handle_collisions, + delta_collision_samples, + ): + PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) + self.return_scaled = return_scaled + if return_scaled and recording.has_scaled(): + self._dtype = np.float32 + self._gains = recording.get_channel_gains() + self._offsets = recording.get_channel_gains() + else: + self._dtype = recording.get_dtype() + self._gains = None + self._offsets = None + spike_retriever = find_parent_of_type(parents, SpikeRetriever) + assert isinstance( + spike_retriever, SpikeRetriever + ), "SpikeAmplitudeNode needs a single SpikeRetriever as a parent" + assert spike_retriever.include_spikes_in_margin, "Need SpikeRetriever with include_spikes_in_margin=True" + if not handle_collisions: + self._margin = max(nbefore, nafter) + else: + # in this case we extend the margin to be able to get with collisions outside the chunk + margin_waveforms = max(nbefore, nafter) + max_margin_collisions = delta_collision_samples + margin_waveforms + self._margin = max_margin_collisions + + self._all_templates = all_templates + self._sparsity_mask = sparsity_mask + self._nbefore = nbefore + self._nafter = nafter + self._cut_out_before = cut_out_before + self._cut_out_after = cut_out_after + self._handle_collisions = handle_collisions + self._delta_collision_samples = delta_collision_samples + + self._kwargs.update( + all_templates=all_templates, + sparsity_mask=sparsity_mask, + nbefore=nbefore, + nafter=nafter, + cut_out_before=cut_out_before, + cut_out_after=cut_out_after, + return_scaled=return_scaled, handle_collisions=handle_collisions, - delta_collision_ms=delta_collision_ms, + delta_collision_samples=delta_collision_samples, ) - sac.run(**job_kwargs) - amps = sac.get_data(outputs=outputs) - return amps + def get_dtype(self): + return self._dtype + def compute(self, traces, peaks): + from scipy.stats import linregress -compute_amplitude_scalings.__doc__.format(_shared_job_kwargs_doc) + # scale traces with margin to match scaling of templates + if self._gains is not None: + traces = traces.astype("float32") * self._gains + self._offsets + + all_templates = self._all_templates + sparsity_mask = self._sparsity_mask + nbefore = self._nbefore + cut_out_before = self._cut_out_before + cut_out_after = self._cut_out_after + handle_collisions = self._handle_collisions + delta_collision_samples = self._delta_collision_samples + + # local_spikes_w_margin = peaks + # i0 = np.searchsorted(local_spikes_w_margin["sample_index"], left_margin) + # i1 = np.searchsorted(local_spikes_w_margin["sample_index"], traces.shape[0] - right_margin) + # local_spikes = local_spikes_w_margin[i0:i1] + + local_spikes_w_margin = peaks + local_spikes = local_spikes_w_margin[~peaks["in_margin"]] -def _init_worker_amplitude_scalings( - recording, - spikes, - all_templates, - segment_slices, - sparsity_mask, - nbefore, - nafter, - cut_out_before, - cut_out_after, - return_scaled, - handle_collisions, - delta_collision_samples, -): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["spikes"] = spikes - worker_ctx["all_templates"] = all_templates - worker_ctx["segment_slices"] = segment_slices - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["cut_out_before"] = cut_out_before - worker_ctx["cut_out_after"] = cut_out_after - worker_ctx["return_scaled"] = return_scaled - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["handle_collisions"] = handle_collisions - worker_ctx["delta_collision_samples"] = delta_collision_samples - - if not handle_collisions: - worker_ctx["margin"] = max(nbefore, nafter) - else: - # in this case we extend the margin to be able to get with collisions outside the chunk - margin_waveforms = max(nbefore, nafter) - max_margin_collisions = delta_collision_samples + margin_waveforms - worker_ctx["margin"] = max_margin_collisions - - return worker_ctx - - -def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx): - # from sklearn.linear_model import LinearRegression - from scipy.stats import linregress - - # recover variables of the worker - spikes = worker_ctx["spikes"] - recording = worker_ctx["recording"] - all_templates = worker_ctx["all_templates"] - segment_slices = worker_ctx["segment_slices"] - sparsity_mask = worker_ctx["sparsity_mask"] - nbefore = worker_ctx["nbefore"] - cut_out_before = worker_ctx["cut_out_before"] - cut_out_after = worker_ctx["cut_out_after"] - margin = worker_ctx["margin"] - return_scaled = worker_ctx["return_scaled"] - handle_collisions = worker_ctx["handle_collisions"] - delta_collision_samples = worker_ctx["delta_collision_samples"] - - spikes_in_segment = spikes[segment_slices[segment_index]] - - i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) - - if i0 != i1: - local_spikes = spikes_in_segment[i0:i1] - traces_with_margin, left, right = get_chunk_with_margin( - recording._recording_segments[segment_index], start_frame, end_frame, channel_indices=None, margin=margin - ) - # scale traces with margin to match scaling of templates - if return_scaled and recording.has_scaled(): - gains = recording.get_property("gain_to_uV") - offsets = recording.get_property("offset_to_uV") - traces_with_margin = traces_with_margin.astype("float32") * gains + offsets # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin, i1_margin = np.searchsorted( - spikes_in_segment["sample_index"], [start_frame - left, end_frame + right] - ) - local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] - collisions_local = find_collisions( - local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask - ) + collisions = find_collisions(local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask) else: - collisions_local = {} + collisions = {} # compute the scaling for each spike scalings = np.zeros(len(local_spikes), dtype=float) - # collision_global transforms local spike index to global spike index - collisions_global = {} + spike_collision_mask = np.zeros(len(local_spikes), dtype=bool) + for spike_index, spike in enumerate(local_spikes): - if spike_index in collisions_local.keys(): + if spike_index in collisions.keys(): # we deal with overlapping spikes later continue unit_index = spike["unit_index"] - sample_index = spike["sample_index"] + sample_centered = spike["sample_index"] (sparse_indices,) = np.nonzero(sparsity_mask[unit_index]) template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] - sample_centered = sample_index - start_frame - cut_out_start = left + sample_centered - cut_out_before - cut_out_end = left + sample_centered + cut_out_after - if sample_index - cut_out_before < 0: - local_waveform = traces_with_margin[:cut_out_end, sparse_indices] - template = template[cut_out_before - sample_index :] - elif sample_index + cut_out_after > end_frame + right: - local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - (end_frame + right))] + cut_out_start = sample_centered - cut_out_before + cut_out_end = sample_centered + cut_out_after + if sample_centered - cut_out_before < 0: + local_waveform = traces[:cut_out_end, sparse_indices] + template = template[cut_out_before - sample_centered :] + elif sample_centered + cut_out_after > traces.shape[0]: + local_waveform = traces[cut_out_start:, sparse_indices] + template = template[: -(sample_centered + cut_out_after - (traces.shape[0]))] else: - local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] + local_waveform = traces[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape # here we use linregress, which is equivalent to using sklearn LinearRegression with fit_intercept=True @@ -377,22 +430,23 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # X = template.flatten()[:, np.newaxis] # reg = LinearRegression(positive=True, fit_intercept=True).fit(X, y) # scalings[spike_index] = reg.coef_[0] + + # closed form: W = (X' * X)^-1 X' y + # y = local_waveform.flatten()[:, None] + # X = np.ones((len(y), 2)) + # X[:, 0] = template.flatten() + # W = np.linalg.inv(X.T @ X) @ X.T @ y + # scalings[spike_index] = W[0, 0] + linregress_res = linregress(template.flatten(), local_waveform.flatten()) scalings[spike_index] = linregress_res[0] # deal with collisions - if len(collisions_local) > 0: - num_spikes_in_previous_segments = int( - np.sum([len(spikes[segment_slices[s]]) for s in range(segment_index)]) - ) - for spike_index, collision in collisions_local.items(): + if len(collisions) > 0: + for spike_index, collision in collisions.items(): scaled_amps = fit_collision( collision, - traces_with_margin, - start_frame, - end_frame, - left, - right, + traces, nbefore, all_templates, sparsity_mask, @@ -401,14 +455,16 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) ) # the scaling for the current spike is at index 0 scalings[spike_index] = scaled_amps[0] + spike_collision_mask[spike_index] = True + + # TODO: switch to collision mask and return that (to use concatenation) + return (scalings, spike_collision_mask) + + def get_trace_margin(self): + return self._margin + - # make collision_dict indices "absolute" by adding i0 and the cumulative number of spikes in previous segments - collisions_global.update({spike_index + i0 + num_spikes_in_previous_segments: collision}) - else: - scalings = np.array([]) - collisions_global = {} - return (scalings, collisions_global) ### Collision handling ### @@ -494,10 +550,6 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_m def fit_collision( collision, traces_with_margin, - start_frame, - end_frame, - left, - right, nbefore, all_templates, sparsity_mask, @@ -542,8 +594,8 @@ def fit_collision( from sklearn.linear_model import LinearRegression # make center of the spike externally - sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left) - sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) + sample_first_centered = np.min(collision["sample_index"]) + sample_last_centered = np.max(collision["sample_index"]) # construct sparsity as union between units' sparsity common_sparse_mask = np.zeros(sparsity_mask.shape[1], dtype="int") @@ -562,7 +614,7 @@ def fit_collision( for i, spike in enumerate(collision): full_template = np.zeros_like(local_waveform) # center wrt cutout traces - sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start + sample_centered = spike["sample_index"] - local_waveform_start template = all_templates[spike["unit_index"]][:, sparse_indices] template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] # deal with borders @@ -590,9 +642,9 @@ def fit_collision( # ---------- # we : WaveformExtractor # The WaveformExtractor object. -# sparsity : ChannelSparsity, default: None +# sparsity : ChannelSparsity, default=None # The ChannelSparsity. If None, only main channels are plotted. -# num_collisions : int, default: None +# num_collisions : int, default=None # Number of collisions to plot. If None, all collisions are plotted. # """ # assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index 4fac98078f..b81b905efe 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -1,51 +1,33 @@ import unittest import numpy as np -from spikeinterface import compute_sparsity -from spikeinterface.postprocessing import AmplitudeScalingsCalculator -from spikeinterface.postprocessing.tests.common_extension_tests import ( - WaveformExtensionCommonTestSuite, -) +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputeAmplitudeScalings -class AmplitudeScalingsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = AmplitudeScalingsCalculator - extension_data_names = ["amplitude_scalings"] + + +class AmplitudeScalingsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputeAmplitudeScalings extension_function_kwargs_list = [ - dict(outputs="concatenated", chunk_size=10000, n_jobs=1), - dict(outputs="concatenated", chunk_size=10000, n_jobs=1, ms_before=0.5, ms_after=0.5), - dict(outputs="by_unit", chunk_size=10000, n_jobs=1), - dict(outputs="concatenated", chunk_size=10000, n_jobs=-1), - dict(outputs="concatenated", chunk_size=10000, n_jobs=2, ms_before=0.5, ms_after=0.5), + dict(), + dict(ms_before=0.5, ms_after=0.5), ] - def test_scaling_parallel(self): - scalings1 = self.extension_class.get_extension_function()( - self.we1, - outputs="concatenated", - chunk_size=10000, - n_jobs=1, - ) - scalings2 = self.extension_class.get_extension_function()( - self.we1, - outputs="concatenated", - chunk_size=10000, - n_jobs=2, - ) - np.testing.assert_array_equal(scalings1, scalings2) - def test_scaling_values(self): - scalings1 = self.extension_class.get_extension_function()( - self.we1, - outputs="by_unit", - chunk_size=10000, - n_jobs=1, - ) - # since this is GT spikes, the rounded median must be 1 - for u, scalings in scalings1[0].items(): + key0 = next(iter(self.sorting_results.keys())) + sorting_result = self.sorting_results[key0] + + spikes = sorting_result.sorting.to_spike_vector() + + ext = sorting_result.get_extension("amplitude_scalings") + ext.data["amplitude_scalings"] + for unit_index, unit_id in enumerate(sorting_result.unit_ids): + mask = spikes["unit_index"] == unit_index + scalings = ext.data["amplitude_scalings"][mask] median_scaling = np.median(scalings) - print(u, median_scaling) + print(unit_index, median_scaling) np.testing.assert_array_equal(np.round(median_scaling), 1) @@ -53,5 +35,4 @@ def test_scaling_values(self): test = AmplitudeScalingsExtensionTest() test.setUp() test.test_extension() - # test.test_scaling_values() - # test.test_scaling_parallel() + test.test_scaling_values() From 84b6b3ec0061e934c9d682b3ff896ff37f691881 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 26 Jan 2024 08:33:39 +0100 Subject: [PATCH 030/192] wip --- .../tests/common_extension_tests.py | 241 +----------------- .../tests/test_amplitude_scalings.py | 16 +- 2 files changed, 13 insertions(+), 244 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 7a45a2f6af..94a56db7e2 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -5,9 +5,6 @@ import platform from pathlib import Path -# from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity -# from spikeinterface.core.generate import generate_ground_truth_recording - from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import start_sorting_result from spikeinterface.core import estimate_sparsity @@ -56,6 +53,8 @@ class ResultExtensionCommonTestSuite: """ Common tests with class approach to compute extension on several cases (3 format x 2 sparsity) + This is done a a list of differents parameters (extension_function_kwargs_list). + This automatically precompute extension dependencies with default params before running computation. This also test the select_units() ability. @@ -73,7 +72,7 @@ def setUp(self): for format in ("memory", "binary_folder", "zarr"): sparsity_ = sparsity if sparse else None sorting_result = get_sorting_result(recording, sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name) - key = f"spare{sparse}_{format}" + key = f"sparse{sparse}_{format}" self.sorting_results[key] = sorting_result def tearDown(self): @@ -116,237 +115,3 @@ def test_extension(self): print() print(self.extension_name, key) self._check_one(sorting_result) - - - -class WaveformExtensionCommonTestSuite: - """ - This class runs common tests for extensions. - """ - - extension_class = None - extension_data_names = [] - extension_function_kwargs_list = None - - # this flag enables us to check that all backends have the same contents - exact_same_content = True - - def _clean_all_folders(self): - for name in ( - "toy_rec_1seg", - "toy_sorting_1seg", - "toy_waveforms_1seg", - "toy_rec_2seg", - "toy_sorting_2seg", - "toy_waveforms_2seg", - "toy_sorting_2seg.zarr", - "toy_sorting_2seg_sparse", - ): - if (cache_folder / name).is_dir(): - shutil.rmtree(cache_folder / name) - - for name in ("toy_waveforms_1seg", "toy_waveforms_2seg", "toy_sorting_2seg_sparse"): - for ext in self.extension_data_names: - folder = self.cache_folder / f"{name}_{ext}_selected" - if folder.exists(): - shutil.rmtree(folder) - - def setUp(self): - self.cache_folder = cache_folder - self._clean_all_folders() - - # 1-segment - recording, sorting = generate_ground_truth_recording( - durations=[10], - sampling_frequency=30000, - num_channels=12, - num_units=10, - dtype="float32", - seed=91, - generate_sorting_kwargs=dict(add_spikes_on_borders=True), - noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), - ) - - # add gains and offsets and save - gain = 0.1 - recording.set_channel_gains(gain) - recording.set_channel_offsets(0) - - recording = recording.save(folder=cache_folder / "toy_rec_1seg") - sorting = sorting.save(folder=cache_folder / "toy_sorting_1seg") - - we1 = extract_waveforms( - recording, - sorting, - cache_folder / "toy_waveforms_1seg", - max_spikes_per_unit=500, - sparse=False, - n_jobs=1, - chunk_size=30000, - overwrite=True, - ) - self.we1 = we1 - self.sparsity1 = compute_sparsity(we1, method="radius", radius_um=50) - - # 2-segments - recording, sorting = generate_ground_truth_recording( - durations=[10, 5], - sampling_frequency=30000, - num_channels=12, - num_units=10, - dtype="float32", - seed=91, - generate_sorting_kwargs=dict(add_spikes_on_borders=True), - noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), - ) - recording.set_channel_gains(gain) - recording.set_channel_offsets(0) - recording = recording.save(folder=cache_folder / "toy_rec_2seg") - sorting = sorting.save(folder=cache_folder / "toy_sorting_2seg") - - we2 = extract_waveforms( - recording, - sorting, - cache_folder / "toy_waveforms_2seg", - max_spikes_per_unit=500, - sparse=False, - n_jobs=1, - chunk_size=30000, - overwrite=True, - ) - self.we2 = we2 - - # make we read-only - if platform.system() != "Windows": - we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" - if not we_ro_folder.is_dir(): - shutil.copytree(we2.folder, we_ro_folder) - # change permissions (R+X) - we_ro_folder.chmod(0o555) - self.we_ro = load_waveforms(we_ro_folder) - - self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) - we_memory = extract_waveforms( - recording, - sorting, - mode="memory", - sparse=False, - max_spikes_per_unit=500, - n_jobs=1, - chunk_size=30000, - ) - self.we_memory2 = we_memory - - self.we_zarr2 = we_memory.save(folder=cache_folder / "toy_sorting_2seg", overwrite=True, format="zarr") - - # use best channels for PC-concatenated - sparsity = compute_sparsity(we_memory, method="best_channels", num_channels=2) - self.we_sparse = we_memory.save( - folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True - ) - - def tearDown(self): - # delete object to release memmap - del self.we1, self.we2, self.we_memory2, self.we_zarr2, self.we_sparse - if hasattr(self, "we_ro"): - del self.we_ro - - # allow pytest to delete RO folder - if platform.system() != "Windows": - we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" - we_ro_folder.chmod(0o777) - - self._clean_all_folders() - - def _test_extension_folder(self, we, in_memory=False): - if self.extension_function_kwargs_list is None: - extension_function_kwargs_list = [dict()] - else: - extension_function_kwargs_list = self.extension_function_kwargs_list - for ext_kwargs in extension_function_kwargs_list: - compute_func = self.extension_class.get_extension_function() - _ = compute_func(we, load_if_exists=False, **ext_kwargs) - - # reload as an extension from we - assert self.extension_class.extension_name in we.get_available_extension_names() - assert we.has_extension(self.extension_class.extension_name) - ext = we.load_extension(self.extension_class.extension_name) - assert isinstance(ext, self.extension_class) - for ext_name in self.extension_data_names: - assert ext_name in ext._extension_data - - if not in_memory: - ext_loaded = self.extension_class.load(we.folder, we) - for ext_name in self.extension_data_names: - assert ext_name in ext_loaded._extension_data - - # test select units - # print('test select units', we.format) - if we.format == "binary": - new_folder = cache_folder / f"{we.folder.stem}_{self.extension_class.extension_name}_selected" - if new_folder.is_dir(): - shutil.rmtree(new_folder) - we_new = we.select_units( - unit_ids=we.sorting.unit_ids[::2], - new_folder=new_folder, - ) - # check that extension is present after select_units() - assert self.extension_class.extension_name in we_new.get_available_extension_names() - elif we.folder is None: - # test select units in-memory and zarr - we_new = we.select_units(unit_ids=we.sorting.unit_ids[::2]) - # check that extension is present after select_units() - assert self.extension_class.extension_name in we_new.get_available_extension_names() - if we.format == "zarr": - # select_units() not supported for Zarr - pass - - def test_extension(self): - print("Test extension", self.extension_class) - # 1 segment - print("1 segment", self.we1) - self._test_extension_folder(self.we1) - - # 2 segment - print("2 segment", self.we2) - self._test_extension_folder(self.we2) - # memory - print("Memory", self.we_memory2) - self._test_extension_folder(self.we_memory2, in_memory=True) - # zarr - # @alessio : this need to be fixed the PCA extention do not work wih zarr - print("Zarr", self.we_zarr2) - self._test_extension_folder(self.we_zarr2) - - # sparse - print("Sparse", self.we_sparse) - self._test_extension_folder(self.we_sparse) - - if self.exact_same_content: - # check content is the same across modes: memory/content/zarr - - for ext in self.we2.get_available_extension_names(): - print(f"Testing data for {ext}") - ext_memory = self.we_memory2.load_extension(ext) - ext_folder = self.we2.load_extension(ext) - ext_zarr = self.we_zarr2.load_extension(ext) - - for ext_data_name, ext_data_mem in ext_memory._extension_data.items(): - ext_data_folder = ext_folder._extension_data[ext_data_name] - ext_data_zarr = ext_zarr._extension_data[ext_data_name] - if isinstance(ext_data_mem, np.ndarray): - np.testing.assert_array_equal(ext_data_mem, ext_data_folder) - np.testing.assert_array_equal(ext_data_mem, ext_data_zarr) - elif isinstance(ext_data_mem, pd.DataFrame): - assert ext_data_mem.equals(ext_data_folder) - assert ext_data_mem.equals(ext_data_zarr) - else: - print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") - - # read-only - Extension is memory only - if platform.system() != "Windows": - _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) - assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() - ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) - assert ext_ro.format == "memory" - assert ext_ro.extension_folder is None diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index b81b905efe..e0abd07edc 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -11,25 +11,29 @@ class AmplitudeScalingsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeAmplitudeScalings extension_function_kwargs_list = [ - dict(), - dict(ms_before=0.5, ms_after=0.5), + dict(handle_collisions=True), + dict(handle_collisions=False), ] def test_scaling_values(self): - key0 = next(iter(self.sorting_results.keys())) - sorting_result = self.sorting_results[key0] + sorting_result = self.sorting_results["sparseTrue_memory"] spikes = sorting_result.sorting.to_spike_vector() ext = sorting_result.get_extension("amplitude_scalings") - ext.data["amplitude_scalings"] + for unit_index, unit_id in enumerate(sorting_result.unit_ids): mask = spikes["unit_index"] == unit_index scalings = ext.data["amplitude_scalings"][mask] median_scaling = np.median(scalings) - print(unit_index, median_scaling) + # print(unit_index, median_scaling) np.testing.assert_array_equal(np.round(median_scaling), 1) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.hist(ext.data["amplitude_scalings"]) + # plt.show() + if __name__ == "__main__": test = AmplitudeScalingsExtensionTest() From 0b90f6b97a3ce3fc665ad04d3ae12ed562b68205 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 26 Jan 2024 09:10:21 +0100 Subject: [PATCH 031/192] start to port principal components --- src/spikeinterface/postprocessing/__init__.py | 2 +- .../postprocessing/principal_component.py | 221 ++++------- .../tests/test_principal_component.py | 369 +++++++++--------- 3 files changed, 275 insertions(+), 317 deletions(-) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 98b6ec6cee..a56457e34c 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -11,7 +11,7 @@ ) from .principal_component import ( - WaveformPrincipalComponent, + ComputePrincipalComponents, compute_principal_components, ) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 9e822a5d1a..8f8fdfb3b0 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -7,65 +7,92 @@ import numpy as np +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension from spikeinterface.core.globals import get_global_tmp_folder _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"] -class WaveformPrincipalComponent(BaseWaveformExtractorExtension): +# TODO handle extra sparsity + +class ComputePrincipalComponents(ResultExtension): """ - Class to extract principal components from a WaveformExtractor object. + Compute PC scores from waveform extractor. The PCA projections are pre-computed only + on the sampled waveforms available from the WaveformExtractor. + + Parameters + ---------- + sorting_result: SortingResult + A SortingResult object + n_components: int, default: 5 + Number of components fo PCA + mode: "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" + The PCA mode: + - "by_channel_local": a local PCA is fitted for each channel (projection by channel) + - "by_channel_global": a global PCA is fitted for all channels (projection by channel) + - "concatenated": channels are concatenated and a global PCA is fitted + sparsity: ChannelSparsity or None, default: None + The sparsity to apply to waveforms. + If sorting_result is already sparse, the default sparsity will be used + whiten: bool, default: True + If True, waveforms are pre-whitened + dtype: dtype, default: "float32" + Dtype of the pc scores + tmp_folder: str or Path or None, default: None + The temporary folder to use for parallel computation. If you run several `compute_principal_components` + functions in parallel with mode "by_channel_local", you need to specify a different `tmp_folder` for each call, + to avoid overwriting to the same folder + + + Examples + -------- + >>> we = si.extract_waveforms(recording, sorting, folder='waveforms') + >>> pc = st.compute_principal_components(we, n_components=3, mode='by_channel_local') + >>> # get pre-computed projections for unit_id=1 + >>> projections = pc.get_projections(unit_id=1) + >>> # get all pre-computed projections and labels + >>> all_projections, all_labels = pc.get_all_projections() + >>> # retrieve fitted pca model(s) + >>> pca_model = pc.get_pca_model() + >>> # compute projections on new waveforms + >>> proj_new = pc.project_new(new_waveforms) + >>> # run for all spikes in the SortingExtractor + >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ extension_name = "principal_components" - handle_sparsity = True - - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - - @classmethod - def create(cls, waveform_extractor): - pc = WaveformPrincipalComponent(waveform_extractor) - return pc - - def __repr__(self): - we = self.waveform_extractor - clsname = self.__class__.__name__ - nseg = we.get_num_segments() - nchan = we.get_num_channels() - txt = f"{clsname}: {nchan} channels - {nseg} segments" - if len(self._params) > 0: - mode = self._params["mode"] - n_components = self._params["n_components"] - txt = txt + f"\n mode: {mode} n_components: {n_components}" - if self._params["sparsity"] is not None: - txt += " - sparse" - return txt + depend_on = ["waveforms", ] + need_recording = False + use_nodepipeline = False + + def __init__(self, sorting_result): + ResultExtension.__init__(self, sorting_result) def _set_params( self, n_components=5, mode="by_channel_local", whiten=True, dtype="float32", sparsity=None, tmp_folder=None ): assert mode in _possible_modes, "Invalid mode!" - if self.waveform_extractor.is_sparse(): - assert sparsity is None, "WaveformExtractor is already sparse, sparsity must be None" + if sparsity is not None: + # TODO alessio: implement local sparsity or not ?? + raise NotImplementedError - # the sparsity in params is ONLY the injected sparsity and not the waveform_extractor one + # the sparsity in params is ONLY the injected sparsity and not the sorting_result one params = dict( - n_components=int(n_components), - mode=str(mode), - whiten=bool(whiten), - dtype=np.dtype(dtype).str, - sparsity=sparsity, - tmp_folder=tmp_folder, + n_components=n_components, + mode=mode, + whiten=whiten, + dtype=np.dtype(dtype), + # sparsity=sparsity, + # tmp_folder=tmp_folder, ) - return params def _select_extension_data(self, unit_ids): + raise NotImplementedError + new_extension_data = dict() for unit_id in unit_ids: new_extension_data[f"pca_{unit_id}"] = self._extension_data[f"pca_{unit_id}"] @@ -112,7 +139,7 @@ def get_pca_model(self): mode = self._params["mode"] if mode == "by_channel_local": pca_models = [] - for chan_id in self.waveform_extractor.channel_ids: + for chan_id in self.sorting_result.channel_ids: pca_models.append(self._extension_data[f"pca_model_{mode}_{chan_id}"]) else: pca_models = self._extension_data[f"pca_model_{mode}"] @@ -140,14 +167,14 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): The PCA projections (num_all_waveforms, num_components, num_channels) """ if unit_ids is None: - unit_ids = self.waveform_extractor.sorting.unit_ids + unit_ids = self.sorting_result.sorting.unit_ids all_labels = [] #  can be unit_id or unit_index all_projections = [] for unit_index, unit_id in enumerate(unit_ids): proj = self.get_projections(unit_id, sparse=False) if channel_ids is not None: - chan_inds = self.waveform_extractor.channel_ids_to_indices(channel_ids) + chan_inds = self.sorting_result.channel_ids_to_indices(channel_ids) proj = proj[:, :, chan_inds] n = proj.shape[0] if outputs == "id": @@ -185,11 +212,11 @@ def project_new(self, new_waveforms, unit_id=None, sparse=False): mode = p["mode"] sparsity = p["sparsity"] - wfs0 = self.waveform_extractor.get_waveforms(unit_id=self.waveform_extractor.sorting.unit_ids[0]) + wfs0 = self.sorting_result.get_waveforms(unit_id=self.sorting_result.sorting.unit_ids[0]) assert ( wfs0.shape[1] == new_waveforms.shape[1] ), "Mismatch in number of samples between waveforms used to fit the pca model and 'new_waveforms'" - num_channels = len(self.waveform_extractor.channel_ids) + num_channels = len(self.sorting_result.channel_ids) # check waveform shapes if sparsity is not None: @@ -231,8 +258,8 @@ def project_new(self, new_waveforms, unit_id=None, sparse=False): return projections def get_sparsity(self): - if self.waveform_extractor.is_sparse(): - return self.waveform_extractor.sparsity + if self.sorting_result.is_sparse(): + return self.sorting_result.sparsity return self._params["sparsity"] def _run(self, **job_kwargs): @@ -245,7 +272,7 @@ def _run(self, **job_kwargs): in extension subfolder. """ p = self._params - we = self.waveform_extractor + we = self.sorting_result num_chans = we.get_num_channels() # update job_kwargs with global ones @@ -313,7 +340,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): """ job_kwargs = fix_job_kwargs(job_kwargs) p = self._params - we = self.waveform_extractor + we = self.sorting_result sorting = we.sorting assert ( we.has_recording() @@ -374,7 +401,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): from sklearn.decomposition import IncrementalPCA from concurrent.futures import ProcessPoolExecutor - we = self.waveform_extractor + we = self.sorting_result p = self._params unit_ids = we.unit_ids @@ -447,7 +474,7 @@ def _run_by_channel_local(self, projection_memmap, n_jobs, progress_bar): """ from sklearn.exceptions import NotFittedError - we = self.waveform_extractor + we = self.sorting_result unit_ids = we.unit_ids pca_model = self._fit_by_channel_local(n_jobs, progress_bar) @@ -477,7 +504,7 @@ def _run_by_channel_local(self, projection_memmap, n_jobs, progress_bar): ) def _fit_by_channel_global(self, progress_bar): - we = self.waveform_extractor + we = self.sorting_result p = self._params unit_ids = we.unit_ids @@ -513,7 +540,7 @@ def _run_by_channel_global(self, projection_objects, n_jobs, progress_bar): The transform is applied by channel. The output is then (n_spike, n_components, n_channels) """ - we = self.waveform_extractor + we = self.sorting_result unit_ids = we.unit_ids pca_model = self._fit_by_channel_global(progress_bar) @@ -533,7 +560,7 @@ def _run_by_channel_global(self, projection_objects, n_jobs, progress_bar): projection_objects[unit_id][:, :, chan_ind] = proj def _fit_concatenated(self, progress_bar): - we = self.waveform_extractor + we = self.sorting_result p = self._params unit_ids = we.unit_ids @@ -572,7 +599,7 @@ def _run_concatenated(self, projection_objects, n_jobs, progress_bar): In this mode the waveforms are concatenated and there is a global fit_transform at once. """ - we = self.waveform_extractor + we = self.sorting_result p = self._params unit_ids = we.unit_ids @@ -593,7 +620,7 @@ def _run_concatenated(self, projection_objects, n_jobs, progress_bar): def _get_sparse_waveforms(self, unit_id): # get waveforms : dense or sparse - we = self.waveform_extractor + we = self.sorting_result sparsity = self._params["sparsity"] if we.is_sparse(): # natural sparsity @@ -601,11 +628,11 @@ def _get_sparse_waveforms(self, unit_id): channel_inds = we.sparsity.unit_id_to_channel_indices[unit_id] elif sparsity is not None: # injected sparsity - wfs = self.waveform_extractor.get_waveforms(unit_id, sparsity=sparsity, lazy=False) + wfs = self.sorting_result.get_waveforms(unit_id, sparsity=sparsity, lazy=False) channel_inds = sparsity.unit_id_to_channel_indices[unit_id] else: # dense - wfs = self.waveform_extractor.get_waveforms(unit_id, sparsity=None, lazy=False) + wfs = self.sorting_result.get_waveforms(unit_id, sparsity=None, lazy=False) channel_inds = np.arange(we.channel_ids.size, dtype=int) return wfs, channel_inds @@ -685,88 +712,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte return worker_ctx -WaveformPrincipalComponent.run_for_all_spikes.__doc__ = WaveformPrincipalComponent.run_for_all_spikes.__doc__.format( - _shared_job_kwargs_doc -) - -WaveformExtractor.register_extension(WaveformPrincipalComponent) - - -def compute_principal_components( - waveform_extractor, - load_if_exists=False, - n_components=5, - mode="by_channel_local", - sparsity=None, - whiten=True, - dtype="float32", - tmp_folder=None, - **job_kwargs, -): - """ - Compute PC scores from waveform extractor. The PCA projections are pre-computed only - on the sampled waveforms available from the WaveformExtractor. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor - load_if_exists: bool - If True and pc scores are already in the waveform extractor folders, pc scores are loaded and not recomputed. - n_components: int, default: 5 - Number of components fo PCA - mode: "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" - The PCA mode: - - "by_channel_local": a local PCA is fitted for each channel (projection by channel) - - "by_channel_global": a global PCA is fitted for all channels (projection by channel) - - "concatenated": channels are concatenated and a global PCA is fitted - sparsity: ChannelSparsity or None, default: None - The sparsity to apply to waveforms. - If waveform_extractor is already sparse, the default sparsity will be used - whiten: bool, default: True - If True, waveforms are pre-whitened - dtype: dtype, default: "float32" - Dtype of the pc scores - tmp_folder: str or Path or None, default: None - The temporary folder to use for parallel computation. If you run several `compute_principal_components` - functions in parallel with mode "by_channel_local", you need to specify a different `tmp_folder` for each call, - to avoid overwriting to the same folder - n_jobs: int, default: 1 - Number of jobs used to fit the PCA model (if mode is "by_channel_local") - progress_bar: bool, default: False - If True, a progress bar is shown - - Returns - ------- - pc: WaveformPrincipalComponent - The waveform principal component object - - Examples - -------- - >>> we = si.extract_waveforms(recording, sorting, folder='waveforms') - >>> pc = st.compute_principal_components(we, n_components=3, mode='by_channel_local') - >>> # get pre-computed projections for unit_id=1 - >>> projections = pc.get_projections(unit_id=1) - >>> # get all pre-computed projections and labels - >>> all_projections, all_labels = pc.get_all_projections() - >>> # retrieve fitted pca model(s) - >>> pca_model = pc.get_pca_model() - >>> # compute projections on new waveforms - >>> proj_new = pc.project_new(new_waveforms) - >>> # run for all spikes in the SortingExtractor - >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") - """ - - if load_if_exists and waveform_extractor.has_extension(WaveformPrincipalComponent.extension_name): - pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) - else: - pc = WaveformPrincipalComponent.create(waveform_extractor) - pc.set_params( - n_components=n_components, mode=mode, whiten=whiten, dtype=dtype, sparsity=sparsity, tmp_folder=tmp_folder - ) - pc.run(**job_kwargs) +# WaveformPrincipalComponent.run_for_all_spikes.__doc__ = WaveformPrincipalComponent.run_for_all_spikes.__doc__.format( +# _shared_job_kwargs_doc +# ) - return pc +register_result_extension(ComputePrincipalComponents) +compute_principal_components = ComputePrincipalComponents.function_factory() def partial_fit_one_channel(args): diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 0b7e8b4602..696f316473 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -4,198 +4,205 @@ import numpy as np -from spikeinterface import compute_sparsity -from spikeinterface.postprocessing import WaveformPrincipalComponent, compute_principal_components -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputePrincipalComponents, compute_principal_components +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "postprocessing" -else: - cache_folder = Path("cache_folder") / "postprocessing" + + + +# from spikeinterface import compute_sparsity +# from spikeinterface.postprocessing import WaveformPrincipalComponent, compute_principal_components +# from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite + +# if hasattr(pytest, "global_test_folder"): +# cache_folder = pytest.global_test_folder / "postprocessing" +# else: +# cache_folder = Path("cache_folder") / "postprocessing" DEBUG = False -class PrincipalComponentsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = WaveformPrincipalComponent - extension_data_names = ["pca_0", "pca_1"] +class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): + extension_class = ComputePrincipalComponents extension_function_kwargs_list = [ dict(mode="by_channel_local"), - dict(mode="by_channel_local", n_jobs=2), + dict(mode="by_channel_local"), dict(mode="by_channel_global"), dict(mode="concatenated"), ] - def test_shapes(self): - nchan1 = self.we1.recording.get_num_channels() - for mode in ("by_channel_local", "by_channel_global"): - _ = self.extension_class.get_extension_function()(self.we1, mode=mode, n_components=5) - pc = self.we1.load_extension(self.extension_class.extension_name) - for unit_id in self.we1.sorting.unit_ids: - proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, nchan1) - for mode in ("concatenated",): - _ = self.extension_class.get_extension_function()(self.we2, mode=mode, n_components=3) - pc = self.we2.load_extension(self.extension_class.extension_name) - for unit_id in self.we2.sorting.unit_ids: - proj = pc.get_projections(unit_id) - assert proj.shape[1] == 3 - - def test_compute_for_all_spikes(self): - we = self.we1 - pc = self.extension_class.get_extension_function()(we, load_if_exists=True) - print(pc) - - pc_file1 = pc.extension_folder / "all_pc1.npy" - pc.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) - all_pc1 = np.load(pc_file1) - - pc_file2 = pc.extension_folder / "all_pc2.npy" - pc.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) - all_pc2 = np.load(pc_file2) - - assert np.array_equal(all_pc1, all_pc2) - - # test with sparsity - sparsity = compute_sparsity(we, method="radius", radius_um=50) - we_copy = we.save(folder=cache_folder / "we_copy") - pc_sparse = self.extension_class.get_extension_function()(we_copy, sparsity=sparsity, load_if_exists=False) - pc_file_sparse = pc.extension_folder / "all_pc_sparse.npy" - pc_sparse.run_for_all_spikes(pc_file_sparse, chunk_size=10000, n_jobs=1) - all_pc_sparse = np.load(pc_file_sparse) - all_spikes_seg0 = we_copy.sorting.to_spike_vector(concatenated=False)[0] - for unit_index, unit_id in enumerate(we.unit_ids): - sparse_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] - pc_unit = all_pc_sparse[all_spikes_seg0["unit_index"] == unit_index] - assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) - - def test_sparse(self): - we = self.we2 - unit_ids = we.unit_ids - num_channels = we.get_num_channels() - pc = self.extension_class(we) - - sparsity_radius = compute_sparsity(we, method="radius", radius_um=50) - sparsity_best = compute_sparsity(we, method="best_channels", num_channels=2) - sparsities = [sparsity_radius, sparsity_best] - print(sparsities) - - for mode in ("by_channel_local", "by_channel_global"): - for sparsity in sparsities: - pc.set_params(n_components=5, mode=mode, sparsity=sparsity) - pc.run() - for i, unit_id in enumerate(unit_ids): - proj_sparse = pc.get_projections(unit_id, sparse=True) - assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) - proj_dense = pc.get_projections(unit_id, sparse=False) - assert proj_dense.shape[1:] == (5, num_channels) - - # test project_new - unit_id = 3 - new_wfs = we.get_waveforms(unit_id) - new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) - assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) - new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) - assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) - - if DEBUG: - import matplotlib.pyplot as plt - - plt.ion() - cmap = plt.get_cmap("jet", len(unit_ids)) - fig, axs = plt.subplots(nrows=len(unit_ids), ncols=num_channels) - for i, unit_id in enumerate(unit_ids): - comp = pc.get_projections(unit_id) - print(comp.shape) - for chan_ind in range(num_channels): - ax = axs[i, chan_ind] - ax.scatter(comp[:, 0, chan_ind], comp[:, 1, chan_ind], color=cmap(i)) - ax.set_title(f"{mode}-{sparsity.unit_id_to_channel_ids[unit_id]}") - if i == 0: - ax.set_xlabel(f"Ch{chan_ind}") - plt.show() - - for mode in ("concatenated",): - # concatenated is only compatible with "best" - pc.set_params(n_components=5, mode=mode, sparsity=sparsity_best) - print(pc) - pc.run() - for i, unit_id in enumerate(unit_ids): - proj = pc.get_projections(unit_id) - assert proj.shape[1] == 5 - - # test project_new - unit_id = 3 - new_wfs = we.get_waveforms(unit_id) - new_proj = pc.project_new(new_wfs, unit_id) - assert new_proj.shape == (len(new_wfs), 5) - - def test_project_new(self): - from sklearn.decomposition import IncrementalPCA - - we = self.we1 - if we.has_extension("principal_components"): - we.delete_extension("principal_components") - we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp") - - wfs0 = we.get_waveforms(unit_id=we.unit_ids[0]) - n_samples = wfs0.shape[1] - n_channels = wfs0.shape[2] - n_components = 5 - - # local - pc_local = compute_principal_components( - we, n_components=n_components, load_if_exists=True, mode="by_channel_local" - ) - pc_local_par = compute_principal_components( - we_cp, n_components=n_components, load_if_exists=True, mode="by_channel_local", n_jobs=2, progress_bar=True - ) - - all_pca = pc_local.get_pca_model() - all_pca_par = pc_local_par.get_pca_model() - - assert len(all_pca) == we.get_num_channels() - assert len(all_pca_par) == we.get_num_channels() - - for pc, pc_par in zip(all_pca, all_pca_par): - assert np.allclose(pc.components_, pc_par.components_) - - # project - new_waveforms = np.random.randn(100, n_samples, n_channels) - new_proj = pc_local.project_new(new_waveforms) - - assert new_proj.shape == (100, n_components, n_channels) - - # global - we.delete_extension("principal_components") - pc_global = compute_principal_components( - we, n_components=n_components, load_if_exists=True, mode="by_channel_global" - ) - - all_pca = pc_global.get_pca_model() - assert isinstance(all_pca, IncrementalPCA) - - # project - new_waveforms = np.random.randn(100, n_samples, n_channels) - new_proj = pc_global.project_new(new_waveforms) - - assert new_proj.shape == (100, n_components, n_channels) - - # concatenated - we.delete_extension("principal_components") - pc_concatenated = compute_principal_components( - we, n_components=n_components, load_if_exists=True, mode="concatenated" - ) - - all_pca = pc_concatenated.get_pca_model() - assert isinstance(all_pca, IncrementalPCA) - - # project - new_waveforms = np.random.randn(100, n_samples, n_channels) - new_proj = pc_concatenated.project_new(new_waveforms) - - assert new_proj.shape == (100, n_components) + # TODO : put back theses tests + + # def test_shapes(self): + # nchan1 = self.we1.recording.get_num_channels() + # for mode in ("by_channel_local", "by_channel_global"): + # _ = self.extension_class.get_extension_function()(self.we1, mode=mode, n_components=5) + # pc = self.we1.load_extension(self.extension_class.extension_name) + # for unit_id in self.we1.sorting.unit_ids: + # proj = pc.get_projections(unit_id) + # assert proj.shape[1:] == (5, nchan1) + # for mode in ("concatenated",): + # _ = self.extension_class.get_extension_function()(self.we2, mode=mode, n_components=3) + # pc = self.we2.load_extension(self.extension_class.extension_name) + # for unit_id in self.we2.sorting.unit_ids: + # proj = pc.get_projections(unit_id) + # assert proj.shape[1] == 3 + + # def test_compute_for_all_spikes(self): + # we = self.we1 + # pc = self.extension_class.get_extension_function()(we, load_if_exists=True) + # print(pc) + + # pc_file1 = pc.extension_folder / "all_pc1.npy" + # pc.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) + # all_pc1 = np.load(pc_file1) + + # pc_file2 = pc.extension_folder / "all_pc2.npy" + # pc.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) + # all_pc2 = np.load(pc_file2) + + # assert np.array_equal(all_pc1, all_pc2) + + # # test with sparsity + # sparsity = compute_sparsity(we, method="radius", radius_um=50) + # we_copy = we.save(folder=cache_folder / "we_copy") + # pc_sparse = self.extension_class.get_extension_function()(we_copy, sparsity=sparsity, load_if_exists=False) + # pc_file_sparse = pc.extension_folder / "all_pc_sparse.npy" + # pc_sparse.run_for_all_spikes(pc_file_sparse, chunk_size=10000, n_jobs=1) + # all_pc_sparse = np.load(pc_file_sparse) + # all_spikes_seg0 = we_copy.sorting.to_spike_vector(concatenated=False)[0] + # for unit_index, unit_id in enumerate(we.unit_ids): + # sparse_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] + # pc_unit = all_pc_sparse[all_spikes_seg0["unit_index"] == unit_index] + # assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) + + # def test_sparse(self): + # we = self.we2 + # unit_ids = we.unit_ids + # num_channels = we.get_num_channels() + # pc = self.extension_class(we) + + # sparsity_radius = compute_sparsity(we, method="radius", radius_um=50) + # sparsity_best = compute_sparsity(we, method="best_channels", num_channels=2) + # sparsities = [sparsity_radius, sparsity_best] + # print(sparsities) + + # for mode in ("by_channel_local", "by_channel_global"): + # for sparsity in sparsities: + # pc.set_params(n_components=5, mode=mode, sparsity=sparsity) + # pc.run() + # for i, unit_id in enumerate(unit_ids): + # proj_sparse = pc.get_projections(unit_id, sparse=True) + # assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + # proj_dense = pc.get_projections(unit_id, sparse=False) + # assert proj_dense.shape[1:] == (5, num_channels) + + # # test project_new + # unit_id = 3 + # new_wfs = we.get_waveforms(unit_id) + # new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) + # assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + # new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) + # assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) + + # if DEBUG: + # import matplotlib.pyplot as plt + + # plt.ion() + # cmap = plt.get_cmap("jet", len(unit_ids)) + # fig, axs = plt.subplots(nrows=len(unit_ids), ncols=num_channels) + # for i, unit_id in enumerate(unit_ids): + # comp = pc.get_projections(unit_id) + # print(comp.shape) + # for chan_ind in range(num_channels): + # ax = axs[i, chan_ind] + # ax.scatter(comp[:, 0, chan_ind], comp[:, 1, chan_ind], color=cmap(i)) + # ax.set_title(f"{mode}-{sparsity.unit_id_to_channel_ids[unit_id]}") + # if i == 0: + # ax.set_xlabel(f"Ch{chan_ind}") + # plt.show() + + # for mode in ("concatenated",): + # # concatenated is only compatible with "best" + # pc.set_params(n_components=5, mode=mode, sparsity=sparsity_best) + # print(pc) + # pc.run() + # for i, unit_id in enumerate(unit_ids): + # proj = pc.get_projections(unit_id) + # assert proj.shape[1] == 5 + + # # test project_new + # unit_id = 3 + # new_wfs = we.get_waveforms(unit_id) + # new_proj = pc.project_new(new_wfs, unit_id) + # assert new_proj.shape == (len(new_wfs), 5) + + # def test_project_new(self): + # from sklearn.decomposition import IncrementalPCA + + # we = self.we1 + # if we.has_extension("principal_components"): + # we.delete_extension("principal_components") + # we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp") + + # wfs0 = we.get_waveforms(unit_id=we.unit_ids[0]) + # n_samples = wfs0.shape[1] + # n_channels = wfs0.shape[2] + # n_components = 5 + + # # local + # pc_local = compute_principal_components( + # we, n_components=n_components, load_if_exists=True, mode="by_channel_local" + # ) + # pc_local_par = compute_principal_components( + # we_cp, n_components=n_components, load_if_exists=True, mode="by_channel_local", n_jobs=2, progress_bar=True + # ) + + # all_pca = pc_local.get_pca_model() + # all_pca_par = pc_local_par.get_pca_model() + + # assert len(all_pca) == we.get_num_channels() + # assert len(all_pca_par) == we.get_num_channels() + + # for pc, pc_par in zip(all_pca, all_pca_par): + # assert np.allclose(pc.components_, pc_par.components_) + + # # project + # new_waveforms = np.random.randn(100, n_samples, n_channels) + # new_proj = pc_local.project_new(new_waveforms) + + # assert new_proj.shape == (100, n_components, n_channels) + + # # global + # we.delete_extension("principal_components") + # pc_global = compute_principal_components( + # we, n_components=n_components, load_if_exists=True, mode="by_channel_global" + # ) + + # all_pca = pc_global.get_pca_model() + # assert isinstance(all_pca, IncrementalPCA) + + # # project + # new_waveforms = np.random.randn(100, n_samples, n_channels) + # new_proj = pc_global.project_new(new_waveforms) + + # assert new_proj.shape == (100, n_components, n_channels) + + # # concatenated + # we.delete_extension("principal_components") + # pc_concatenated = compute_principal_components( + # we, n_components=n_components, load_if_exists=True, mode="concatenated" + # ) + + # all_pca = pc_concatenated.get_pca_model() + # assert isinstance(all_pca, IncrementalPCA) + + # # project + # new_waveforms = np.random.randn(100, n_samples, n_channels) + # new_proj = pc_concatenated.project_new(new_waveforms) + + # assert new_proj.shape == (100, n_components) if __name__ == "__main__": @@ -205,4 +212,4 @@ def test_project_new(self): # test.test_shapes() # test.test_compute_for_all_spikes() # test.test_sparse() - test.test_project_new() + # test.test_project_new() From 7afacec0109d3a40fa2996aad788e5218cc809d6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 26 Jan 2024 17:34:22 +0100 Subject: [PATCH 032/192] wip --- .../postprocessing/principal_component.py | 44 +++++++++++++------ .../tests/test_principal_component.py | 1 - 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 8f8fdfb3b0..6819e85a21 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -119,7 +119,7 @@ def get_projections(self, unit_id, sparse=False): In case sparsity is used, only the projections on sparse channels are returned. """ projections = self._extension_data[f"pca_{unit_id}"] - mode = self._params["mode"] + mode = self.params["mode"] if mode in ("by_channel_local", "by_channel_global") and sparse: sparsity = self.get_sparsity() if sparsity is not None: @@ -136,7 +136,7 @@ def get_pca_model(self): * if mode is "by_channel_local", "pca_model" is a list of PCA model by channel * if mode is "by_channel_global" or "concatenated", "pca_model" is a single PCA model """ - mode = self._params["mode"] + mode = self.params["mode"] if mode == "by_channel_local": pca_models = [] for chan_id in self.sorting_result.channel_ids: @@ -208,7 +208,7 @@ def project_new(self, new_waveforms, unit_id=None, sparse=False): Projections of new waveforms on PCA compoents """ - p = self._params + p = self.params mode = p["mode"] sparsity = p["sparsity"] @@ -260,7 +260,7 @@ def project_new(self, new_waveforms, unit_id=None, sparse=False): def get_sparsity(self): if self.sorting_result.is_sparse(): return self.sorting_result.sparsity - return self._params["sparsity"] + return self.params["sparsity"] def _run(self, **job_kwargs): """ @@ -271,15 +271,33 @@ def _run(self, **job_kwargs): This will be cached in the same folder than WaveformExtarctor in extension subfolder. """ - p = self._params - we = self.sorting_result - num_chans = we.get_num_channels() + p = self.params + # we = self.sorting_result + # num_chans = we.get_num_channels() # update job_kwargs with global ones job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] + ext = self.sorting_result.get_extension("waveforms") + waveforms = ext.data["waveforms"] + + + spikes = self.sorting_result.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + + # prepare buffer + n_components = self.params["n_components"] + if p["mode"] in ("by_channel_local", "by_channel_global"): + shape = (waveforms.shape[0], n_components, waveforms.shape[2]) + elif p["mode"] == "concatenated": + shape = (waveforms.shape[0], n_components) + pca_projection = np.zeros(shape, dtype="float32") + + # fit PCA models + #### + # prepare memmap files with npy projection_objects = {} unit_ids = we.unit_ids @@ -339,7 +357,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): {} """ job_kwargs = fix_job_kwargs(job_kwargs) - p = self._params + p = self.params we = self.sorting_result sorting = we.sorting assert ( @@ -402,7 +420,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): from concurrent.futures import ProcessPoolExecutor we = self.sorting_result - p = self._params + p = self.params unit_ids = we.unit_ids channel_ids = we.channel_ids @@ -505,7 +523,7 @@ def _run_by_channel_local(self, projection_memmap, n_jobs, progress_bar): def _fit_by_channel_global(self, progress_bar): we = self.sorting_result - p = self._params + p = self.params unit_ids = we.unit_ids # there is one unique PCA accross channels @@ -561,7 +579,7 @@ def _run_by_channel_global(self, projection_objects, n_jobs, progress_bar): def _fit_concatenated(self, progress_bar): we = self.sorting_result - p = self._params + p = self.params unit_ids = we.unit_ids sparsity = self.get_sparsity() @@ -600,7 +618,7 @@ def _run_concatenated(self, projection_objects, n_jobs, progress_bar): a global fit_transform at once. """ we = self.sorting_result - p = self._params + p = self.params unit_ids = we.unit_ids @@ -621,7 +639,7 @@ def _run_concatenated(self, projection_objects, n_jobs, progress_bar): def _get_sparse_waveforms(self, unit_id): # get waveforms : dense or sparse we = self.sorting_result - sparsity = self._params["sparsity"] + sparsity = self.params["sparsity"] if we.is_sparse(): # natural sparsity wfs = we.get_waveforms(unit_id, lazy=False) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 696f316473..d40b683a1c 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -26,7 +26,6 @@ class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputePrincipalComponents extension_function_kwargs_list = [ - dict(mode="by_channel_local"), dict(mode="by_channel_local"), dict(mode="by_channel_global"), dict(mode="concatenated"), From b91a27ea93f2579705705d95576884640d74a3cf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 29 Jan 2024 14:09:28 +0100 Subject: [PATCH 033/192] Porting pincipal components to SortingRseult --- .../postprocessing/principal_component.py | 695 +++++++----------- .../tests/common_extension_tests.py | 22 +- .../tests/test_principal_component.py | 200 ++--- 3 files changed, 347 insertions(+), 570 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 6819e85a21..2c02457168 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -10,7 +10,7 @@ from spikeinterface.core.sortingresult import register_result_extension, ResultExtension from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.globals import get_global_tmp_folder +# from spikeinterface.core.globals import get_global_tmp_folder _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"] @@ -40,11 +40,6 @@ class ComputePrincipalComponents(ResultExtension): If True, waveforms are pre-whitened dtype: dtype, default: "float32" Dtype of the pc scores - tmp_folder: str or Path or None, default: None - The temporary folder to use for parallel computation. If you run several `compute_principal_components` - functions in parallel with mode "by_channel_local", you need to specify a different `tmp_folder` for each call, - to avoid overwriting to the same folder - Examples -------- @@ -52,8 +47,6 @@ class ComputePrincipalComponents(ResultExtension): >>> pc = st.compute_principal_components(we, n_components=3, mode='by_channel_local') >>> # get pre-computed projections for unit_id=1 >>> projections = pc.get_projections(unit_id=1) - >>> # get all pre-computed projections and labels - >>> all_projections, all_labels = pc.get_all_projections() >>> # retrieve fitted pca model(s) >>> pca_model = pc.get_pca_model() >>> # compute projections on new waveforms @@ -71,7 +64,7 @@ def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) def _set_params( - self, n_components=5, mode="by_channel_local", whiten=True, dtype="float32", sparsity=None, tmp_folder=None + self, n_components=5, mode="by_channel_local", whiten=True, dtype="float32", sparsity=None, ): assert mode in _possible_modes, "Invalid mode!" @@ -91,15 +84,20 @@ def _set_params( return params def _select_extension_data(self, unit_ids): - raise NotImplementedError - - new_extension_data = dict() - for unit_id in unit_ids: - new_extension_data[f"pca_{unit_id}"] = self._extension_data[f"pca_{unit_id}"] - for k, v in self._extension_data.items(): + + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + spikes = self.sorting_result.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices) + + new_data = dict() + new_data["pca_projection"] = self.data["pca_projection"][keep_spike_mask, :, :] + # one or several model + for k, v in self.data.items(): if "model" in k: - new_extension_data[k] = v - return new_extension_data + new_data[k] = v + return new_data + def get_projections(self, unit_id, sparse=False): """ @@ -118,7 +116,7 @@ def get_projections(self, unit_id, sparse=False): The PCA projections (num_waveforms, num_components, num_channels). In case sparsity is used, only the projections on sparse channels are returned. """ - projections = self._extension_data[f"pca_{unit_id}"] + projections = self.data[f"pca_{unit_id}"] mode = self.params["mode"] if mode in ("by_channel_local", "by_channel_global") and sparse: sparsity = self.get_sparsity() @@ -140,122 +138,76 @@ def get_pca_model(self): if mode == "by_channel_local": pca_models = [] for chan_id in self.sorting_result.channel_ids: - pca_models.append(self._extension_data[f"pca_model_{mode}_{chan_id}"]) + pca_models.append(self.data[f"pca_model_{mode}_{chan_id}"]) else: - pca_models = self._extension_data[f"pca_model_{mode}"] + pca_models = self.data[f"pca_model_{mode}"] return pca_models - def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): - """ - Returns the computed projections for the sampled waveforms of all units. - - Parameters - ---------- - channel_ids : list, default: None - List of channel ids on which projections are computed - unit_ids : list, default: None - List of unit ids to return projections for - outputs: str - * "id": "all_labels" contain unit ids - * "index": "all_labels" contain unit indices - - Returns - ------- - all_labels: np.array - Array with labels (ids or indices based on "outputs") of returned PCA projections - all_projections: np.array - The PCA projections (num_all_waveforms, num_components, num_channels) - """ - if unit_ids is None: - unit_ids = self.sorting_result.sorting.unit_ids - - all_labels = [] #  can be unit_id or unit_index - all_projections = [] - for unit_index, unit_id in enumerate(unit_ids): - proj = self.get_projections(unit_id, sparse=False) - if channel_ids is not None: - chan_inds = self.sorting_result.channel_ids_to_indices(channel_ids) - proj = proj[:, :, chan_inds] - n = proj.shape[0] - if outputs == "id": - labels = np.array([unit_id] * n) - elif outputs == "index": - labels = np.ones(n, dtype="int64") - labels[:] = unit_index - all_labels.append(labels) - all_projections.append(proj) - all_labels = np.concatenate(all_labels, axis=0) - all_projections = np.concatenate(all_projections, axis=0) - - return all_labels, all_projections - - def project_new(self, new_waveforms, unit_id=None, sparse=False): + # def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): + # """ + # Returns the computed projections for the sampled waveforms of all units. + + # Parameters + # ---------- + # channel_ids : list, default: None + # List of channel ids on which projections are computed + # unit_ids : list, default: None + # List of unit ids to return projections for + # outputs: str + # * "id": "all_labels" contain unit ids + # * "index": "all_labels" contain unit indices + + # Returns + # ------- + # all_labels: np.array + # Array with labels (ids or indices based on "outputs") of returned PCA projections + # all_projections: np.array + # The PCA projections (num_all_waveforms, num_components, num_channels) + # """ + # if unit_ids is None: + # unit_ids = self.sorting_result.sorting.unit_ids + + # all_labels = [] #  can be unit_id or unit_index + # all_projections = [] + # for unit_index, unit_id in enumerate(unit_ids): + # proj = self.get_projections(unit_id, sparse=False) + # if channel_ids is not None: + # chan_inds = self.sorting_result.chanpca_projectionnel_ids_to_indices(channel_ids) + # proj = proj[:, :, chan_inds] + # n = proj.shape[0] + # if outputs == "id": + # labels = np.array([unit_id] * n) + # elif outputs == "index": + # labels = np.ones(n, dtype="int64") + # labels[:] = unit_index + # all_labels.append(labels) + # all_projections.append(proj) + # all_labels = np.concatenate(all_labels, axis=0) + # all_projections = np.concatenate(all_projections, axis=0) + + # return all_labels, all_projections + + def project_new(self, new_spikes, new_waveforms, progress_bar=True): """ Projects new waveforms or traces snippets on the PC components. Parameters ---------- + new_spikes: np.array + The spikes vector associated to the waveforms buffer. This is need need to get the sparsity spike per spike. new_waveforms: np.array Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels) - unit_id: int or str - In case PCA is sparse and mode is by_channel_local, the unit_id of "new_waveforms" - sparse: bool, default: False - If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- - projections: np.array + new_projections: np.array Projections of new waveforms on PCA compoents """ - p = self.params - mode = p["mode"] - sparsity = p["sparsity"] - - wfs0 = self.sorting_result.get_waveforms(unit_id=self.sorting_result.sorting.unit_ids[0]) - assert ( - wfs0.shape[1] == new_waveforms.shape[1] - ), "Mismatch in number of samples between waveforms used to fit the pca model and 'new_waveforms'" - num_channels = len(self.sorting_result.channel_ids) - - # check waveform shapes - if sparsity is not None: - assert ( - unit_id is not None - ), "The unit_id of the new_waveforms is needed to apply the waveforms transformation" - channel_inds = sparsity.unit_id_to_channel_indices[unit_id] - if new_waveforms.shape[2] != len(channel_inds): - new_waveforms = new_waveforms.copy()[:, :, channel_inds] - else: - assert ( - wfs0.shape[2] == new_waveforms.shape[2] - ), "Mismatch in number of channels between waveforms used to fit the pca model and 'new_waveforms'" - channel_inds = np.arange(num_channels, dtype=int) - - # get channel ids and pca models pca_model = self.get_pca_model() - projections = None - - if mode == "by_channel_local": - shape = (new_waveforms.shape[0], p["n_components"], num_channels) - projections = np.zeros(shape) - for wf_ind, chan_ind in enumerate(channel_inds): - pca = pca_model[chan_ind] - projections[:, :, chan_ind] = pca.transform(new_waveforms[:, :, wf_ind]) - elif mode == "by_channel_global": - shape = (new_waveforms.shape[0], p["n_components"], num_channels) - projections = np.zeros(shape) - for wf_ind, chan_ind in enumerate(channel_inds): - projections[:, :, chan_ind] = pca_model.transform(new_waveforms[:, :, wf_ind]) - elif mode == "concatenated": - wfs_flat = new_waveforms.reshape(new_waveforms.shape[0], -1) - projections = pca_model.transform(wfs_flat) + new_projections = self._transform_waveforms( new_spikes, new_waveforms, pca_model, progress_bar=progress_bar) + return new_projections - # take care of sparsity (not in case of concatenated) - if mode in ("by_channel_local", "by_channel_global") and sparse: - if sparsity is not None: - projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] - return projections def get_sparsity(self): if self.sorting_result.is_sparse(): @@ -264,194 +216,146 @@ def get_sparsity(self): def _run(self, **job_kwargs): """ - Compute the PCs on waveforms extacted within the WaveformExtarctor. - Projections are computed only on the waveforms sampled by the WaveformExtractor. - - The index of spikes come from the WaveformExtarctor. - This will be cached in the same folder than WaveformExtarctor - in extension subfolder. + Compute the PCs on waveforms extacted within the by ComputeWaveforms. + Projections are computed only on the waveforms sampled by the SortingResult. """ p = self.params - # we = self.sorting_result - # num_chans = we.get_num_channels() + mode = p["mode"] # update job_kwargs with global ones job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] - ext = self.sorting_result.get_extension("waveforms") - waveforms = ext.data["waveforms"] - - - spikes = self.sorting_result.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] - - # prepare buffer - n_components = self.params["n_components"] - if p["mode"] in ("by_channel_local", "by_channel_global"): - shape = (waveforms.shape[0], n_components, waveforms.shape[2]) - elif p["mode"] == "concatenated": - shape = (waveforms.shape[0], n_components) - pca_projection = np.zeros(shape, dtype="float32") - - # fit PCA models - #### - - # prepare memmap files with npy - projection_objects = {} - unit_ids = we.unit_ids - - for unit_id in unit_ids: - n_spike = we.get_waveforms(unit_id).shape[0] - if p["mode"] in ("by_channel_local", "by_channel_global"): - shape = (n_spike, p["n_components"], num_chans) - elif p["mode"] == "concatenated": - shape = (n_spike, p["n_components"]) - proj = np.zeros(shape, dtype=p["dtype"]) - projection_objects[unit_id] = proj - - # run ... - if p["mode"] == "by_channel_local": - self._run_by_channel_local(projection_objects, n_jobs, progress_bar) - elif p["mode"] == "by_channel_global": - self._run_by_channel_global(projection_objects, n_jobs, progress_bar) - elif p["mode"] == "concatenated": - self._run_concatenated(projection_objects, n_jobs, progress_bar) - - # add projections to extension data - for unit_id in unit_ids: - self._extension_data[f"pca_{unit_id}"] = projection_objects[unit_id] - - def get_data(self): - """ - Get computed PCA projections. - - Returns - ------- - all_labels : 1d np.array - Array with all spike labels - all_projections : 3d array - Array with PCA projections (num_spikes, num_components, num_channels) - """ - return self.get_all_projections() - - @staticmethod - def get_extension_function(): - return compute_principal_components - - def run_for_all_spikes(self, file_path=None, **job_kwargs): - """ - Project all spikes from the sorting on the PCA model. - This is a long computation because waveform need to be extracted from each spikes. - - Used mainly for `export_to_phy()` - - PCs are exported to a .npy single file. - - Parameters - ---------- - file_path : str or Path or None - Path to npy file that will store the PCA projections. - If None, output is saved in principal_components/all_pcs.npy - {} - """ - job_kwargs = fix_job_kwargs(job_kwargs) - p = self.params - we = self.sorting_result - sorting = we.sorting - assert ( - we.has_recording() - ), "To compute PCA projections for all spikes, the waveform extractor needs the recording" - recording = we.recording - - assert sorting.get_num_segments() == 1 - assert p["mode"] in ("by_channel_local", "by_channel_global") - - if file_path is None: - file_path = self.extension_folder / "all_pcs.npy" - file_path = Path(file_path) - - # spikes = sorting.to_spike_vector(concatenated=False) - # # This is the first segment only - # spikes = spikes[0] - # spike_times = spikes["sample_index"] - # spike_labels = spikes["unit_index"] - - sparsity = self.get_sparsity() - if sparsity is None: - sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} - max_channels_per_template = we.get_num_channels() - else: - sparse_channels_indices = sparsity.unit_id_to_channel_indices - max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) + # fit model/models + # TODO : make parralel for by_channel_global and concatenated + if mode == "by_channel_local": + pca_models = self._fit_by_channel_local(n_jobs, progress_bar) + for chan_ind, chan_id in enumerate(self.sorting_result.channel_ids): + self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] + pca_model = pca_models + elif mode == "by_channel_global": + pca_model = self._fit_by_channel_global(progress_bar) + self.data[f"pca_model_{mode}"] = pca_model + elif mode == "concatenated": + pca_model = self._fit_concatenated(progress_bar) + self.data[f"pca_model_{mode}"] = pca_model - unit_channels = [sparse_channels_indices[unit_id] for unit_id in sorting.unit_ids] - pca_model = self.get_pca_model() - if p["mode"] in ["by_channel_global", "concatenated"]: - pca_model = [pca_model] * recording.get_num_channels() - - # nSpikes, nFeaturesPerChannel, nPCFeatures - # this comes from phy template-gui - # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets - num_spikes = sorting.to_spike_vector().size - shape = (num_spikes, p["n_components"], max_channels_per_template) - all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) - all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) - - # and run - func = _all_pc_extractor_chunk - init_func = _init_work_all_pc_extractor - init_args = ( - recording, - sorting.to_multiprocessing(job_kwargs["n_jobs"]), - all_pcs_args, - we.nbefore, - we.nafter, - unit_channels, - pca_model, - ) - processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs) - processor.run() + # transform + waveforms_ext = self.sorting_result.get_extension("waveforms") + some_waveforms = waveforms_ext.data["waveforms"] + spikes = self.sorting_result.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + + pca_projection = self._transform_waveforms(some_spikes, some_waveforms, pca_model, progress_bar) + + self.data["pca_projection"] = pca_projection + + + # def get_data(self): + # """ + # Get computed PCA projections. + + # Returns + # ------- + # all_labels : 1d np.array + # Array with all spike labels + # all_projections : 3d array + # Array with PCA projections (num_spikes, num_components, num_channels) + # """ + # return self.get_all_projections() + + # @staticmethod + # def get_extension_function(): + # return compute_principal_components + + # def run_for_all_spikes(self, file_path=None, **job_kwargs): + # """ + # Project all spikes from the sorting on the PCA model. + # This is a long computation because waveform need to be extracted from each spikes. + + # Used mainly for `export_to_phy()` + + # PCs are exported to a .npy single file. + + # Parameters + # ---------- + # file_path : str or Path or None + # Path to npy file that will store the PCA projections. + # If None, output is saved in principal_components/all_pcs.npy + # {} + # """ + + # job_kwargs = fix_job_kwargs(job_kwargs) + # p = self.params + # we = self.sorting_result + # sorting = we.sorting + # assert ( + # we.has_recording() + # ), "To compute PCA projections for all spikes, the waveform extractor needs the recording" + # recording = we.recording + + # assert sorting.get_num_segments() == 1 + # assert p["mode"] in ("by_channel_local", "by_channel_global") + + # if file_path is None: + # file_path = self.extension_folder / "all_pcs.npy" + # file_path = Path(file_path) + + + # sparsity = self.get_sparsity() + # if sparsity is None: + # sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} + # max_channels_per_template = we.get_num_channels() + # else: + # sparse_channels_indices = sparsity.unit_id_to_channel_indices + # max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) + + # unit_channels = [sparse_channels_indices[unit_id] for unit_id in sorting.unit_ids] + + # pca_model = self.get_pca_model() + # if p["mode"] in ["by_channel_global", "concatenated"]: + # pca_model = [pca_model] * recording.get_num_channels() + + # num_spikes = sorting.to_spike_vector().size + # shape = (num_spikes, p["n_components"], max_channels_per_template) + # all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) + # all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) + + # # and run + # func = _all_pc_extractor_chunk + # init_func = _init_work_all_pc_extractor + # init_args = ( + # recording, + # sorting.to_multiprocessing(job_kwargs["n_jobs"]), + # all_pcs_args, + # we.nbefore, + # we.nafter, + # unit_channels, + # pca_model, + # ) + # processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs) + # processor.run() def _fit_by_channel_local(self, n_jobs, progress_bar): from sklearn.decomposition import IncrementalPCA from concurrent.futures import ProcessPoolExecutor - we = self.sorting_result p = self.params - unit_ids = we.unit_ids - channel_ids = we.channel_ids + unit_ids = self.sorting_result.unit_ids + channel_ids = self.sorting_result.channel_ids # there is one PCA per channel for independent fit per channel pca_models = [IncrementalPCA(n_components=p["n_components"], whiten=p["whiten"]) for _ in channel_ids] - mode = p["mode"] - pca_model_files = [] - - tmp_folder = p["tmp_folder"] - if tmp_folder is None: - if n_jobs > 1: - tmp_folder = tempfile.mkdtemp(prefix="pca", dir=get_global_tmp_folder()) - - for chan_ind, chan_id in enumerate(channel_ids): - pca_model = pca_models[chan_ind] - if n_jobs > 1: - tmp_folder = Path(tmp_folder) - tmp_folder.mkdir(exist_ok=True) - pca_model_file = tmp_folder / f"tmp_pca_model_{mode}_{chan_id}.pkl" - with pca_model_file.open("wb") as f: - pickle.dump(pca_model, f) - pca_model_files.append(pca_model_file) - # fit units_loop = enumerate(unit_ids) if progress_bar: units_loop = tqdm(units_loop, desc="Fitting PCA", total=len(unit_ids)) for unit_ind, unit_id in units_loop: - wfs, channel_inds = self._get_sparse_waveforms(unit_id) + wfs, channel_inds, _ = self._get_sparse_waveforms(unit_id) if len(wfs) < p["n_components"]: continue if n_jobs in (0, 1): @@ -460,71 +364,21 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): pca.partial_fit(wfs[:, :, wf_ind]) else: # parallel - items = [(pca_model_files[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)] + items = [(chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)] n_jobs = min(n_jobs, len(items)) with ProcessPoolExecutor(max_workers=n_jobs) as executor: results = executor.map(partial_fit_one_channel, items) - for res in results: - pass - - # reload the models (if n_jobs > 1) - if n_jobs not in (0, 1): - pca_models = [] - for chan_ind, chan_id in enumerate(channel_ids): - pca_model_file = pca_model_files[chan_ind] - with open(pca_model_file, "rb") as fid: - pca_models.append(pickle.load(fid)) - pca_model_file.unlink() - shutil.rmtree(tmp_folder) - - # add models to extension data - for chan_ind, chan_id in enumerate(channel_ids): - pca_model = pca_models[chan_ind] - self._extension_data[f"pca_model_{mode}_{chan_id}"] = pca_model + for chan_ind, pca_model_updated in results: + pca_models[chan_ind] = pca_model_updated return pca_models - def _run_by_channel_local(self, projection_memmap, n_jobs, progress_bar): - """ - In this mode each PCA is "fit" and "transform" by channel. - The output is then (n_spike, n_components, n_channels) - """ - from sklearn.exceptions import NotFittedError - - we = self.sorting_result - unit_ids = we.unit_ids - - pca_model = self._fit_by_channel_local(n_jobs, progress_bar) - - # transform - units_loop = enumerate(unit_ids) - if progress_bar: - units_loop = tqdm(units_loop, desc="Projecting waveforms", total=len(unit_ids)) - - project_on_non_fitted = False - for unit_ind, unit_id in units_loop: - wfs, channel_inds = self._get_sparse_waveforms(unit_id) - if wfs.size == 0: - continue - for wf_ind, chan_ind in enumerate(channel_inds): - pca = pca_model[chan_ind] - try: - proj = pca.transform(wfs[:, :, wf_ind]) - projection_memmap[unit_id][:, :, chan_ind] = proj - except NotFittedError as e: - # this could happen if len(wfs) is less then n_comp for a channel - project_on_non_fitted = True - if project_on_non_fitted: - warnings.warn( - "Projection attempted on unfitted PCA models. This could be due to a small " - "number of waveforms for a particular unit." - ) - def _fit_by_channel_global(self, progress_bar): - we = self.sorting_result + # we = self.sorting_result p = self.params - unit_ids = we.unit_ids + # unit_ids = we.unit_ids + unit_ids = self.sorting_result.unit_ids # there is one unique PCA accross channels from sklearn.decomposition import IncrementalPCA @@ -538,7 +392,7 @@ def _fit_by_channel_global(self, progress_bar): # with 'by_channel_global' we can't parallelize over channels for unit_ind, unit_id in units_loop: - wfs, _ = self._get_sparse_waveforms(unit_id) + wfs, _, _ = self._get_sparse_waveforms(unit_id) shape = wfs.shape if shape[0] * shape[2] < p["n_components"]: continue @@ -546,48 +400,15 @@ def _fit_by_channel_global(self, progress_bar): wfs_concat = wfs.transpose(0, 2, 1).reshape(shape[0] * shape[2], shape[1]) pca_model.partial_fit(wfs_concat) - # save - mode = p["mode"] - self._extension_data[f"pca_model_{mode}"] = pca_model return pca_model - - def _run_by_channel_global(self, projection_objects, n_jobs, progress_bar): - """ - In this mode there is one "fit" for all channels. - The transform is applied by channel. - The output is then (n_spike, n_components, n_channels) - """ - we = self.sorting_result - unit_ids = we.unit_ids - - pca_model = self._fit_by_channel_global(progress_bar) - - # transform - units_loop = enumerate(unit_ids) - if progress_bar: - units_loop = tqdm(units_loop, desc="Projecting waveforms", total=len(unit_ids)) - - # with 'by_channel_global' we can't parallelize over channels - for unit_ind, unit_id in units_loop: - wfs, channel_inds = self._get_sparse_waveforms(unit_id) - if wfs.size == 0: - continue - for wf_ind, chan_ind in enumerate(channel_inds): - proj = pca_model.transform(wfs[:, :, wf_ind]) - projection_objects[unit_id][:, :, chan_ind] = proj - + def _fit_concatenated(self, progress_bar): - we = self.sorting_result + p = self.params - unit_ids = we.unit_ids + unit_ids = self.sorting_result.unit_ids - sparsity = self.get_sparsity() - if sparsity is not None: - sparsity0 = sparsity.unit_id_to_channel_indices[unit_ids[0]] - assert all( - len(chans) == len(sparsity0) for u, chans in sparsity.unit_id_to_channel_indices.items() - ), "When using sparsity in concatenated mode, make sure each unit has the same number of sparse channels" + assert self.sorting_result.sparsity is None, "For mode 'concatenated' waveforms need to be dense" # there is one unique PCA accross channels from sklearn.decomposition import IncrementalPCA @@ -600,59 +421,103 @@ def _fit_concatenated(self, progress_bar): units_loop = tqdm(units_loop, desc="Fitting PCA", total=len(unit_ids)) for unit_ind, unit_id in units_loop: - wfs, _ = self._get_sparse_waveforms(unit_id) + wfs, _, _ = self._get_sparse_waveforms(unit_id) wfs_flat = wfs.reshape(wfs.shape[0], -1) if len(wfs_flat) < p["n_components"]: continue pca_model.partial_fit(wfs_flat) - # save - mode = p["mode"] - self._extension_data[f"pca_model_{mode}"] = pca_model - return pca_model + - def _run_concatenated(self, projection_objects, n_jobs, progress_bar): - """ - In this mode the waveforms are concatenated and there is - a global fit_transform at once. - """ - we = self.sorting_result - p = self.params + def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar): + # transform a waveforms buffer + # used by _run() and project_new() - unit_ids = we.unit_ids + from sklearn.exceptions import NotFittedError - # there is one unique PCA accross channels - pca_model = self._fit_concatenated(progress_bar) + mode = self.params["mode"] + + # prepare buffer + n_components = self.params["n_components"] + if mode in ("by_channel_local", "by_channel_global"): + shape = (waveforms.shape[0], n_components, waveforms.shape[2]) + elif mode == "concatenated": + shape = (waveforms.shape[0], n_components) + pca_projection = np.zeros(shape, dtype="float32") + + unit_ids = self.sorting_result.unit_ids # transform units_loop = enumerate(unit_ids) if progress_bar: units_loop = tqdm(units_loop, desc="Projecting waveforms", total=len(unit_ids)) - for unit_ind, unit_id in units_loop: - wfs, _ = self._get_sparse_waveforms(unit_id) - wfs_flat = wfs.reshape(wfs.shape[0], -1) - proj = pca_model.transform(wfs_flat) - projection_objects[unit_id][:, :] = proj + if mode == "by_channel_local": + # in this case the model is a list of model + pca_models = pca_model + + project_on_non_fitted = False + for unit_ind, unit_id in units_loop: + wfs, channel_inds, spike_mask = self._get_slice_waveforms(unit_id, spikes, waveforms) + if wfs.size == 0: + continue + for wf_ind, chan_ind in enumerate(channel_inds): + pca_model = pca_models[chan_ind] + try: + proj = pca_model.transform(wfs[:, :, wf_ind]) + pca_projection[:, :, wf_ind][spike_mask, : ] = proj + except NotFittedError as e: + # this could happen if len(wfs) is less then n_comp for a channel + project_on_non_fitted = True + if project_on_non_fitted: + warnings.warn( + "Projection attempted on unfitted PCA models. This could be due to a small " + "number of waveforms for a particular unit." + ) + elif mode == "by_channel_global": + # with 'by_channel_global' we can't parallelize over channels + for unit_ind, unit_id in units_loop: + wfs, channel_inds, spike_mask = self._get_slice_waveforms(unit_id, spikes, waveforms) + if wfs.size == 0: + continue + for wf_ind, chan_ind in enumerate(channel_inds): + proj = pca_model.transform(wfs[:, :, wf_ind]) + pca_projection[:, :, wf_ind][spike_mask, : ] = proj + elif mode == "concatenated": + for unit_ind, unit_id in units_loop: + wfs, channel_inds, spike_mask = self._get_slice_waveforms(unit_id, spikes, waveforms) + wfs_flat = wfs.reshape(wfs.shape[0], -1) + proj = pca_model.transform(wfs_flat) + pca_projection[spike_mask, :] = proj + + return pca_projection - def _get_sparse_waveforms(self, unit_id): - # get waveforms : dense or sparse - we = self.sorting_result - sparsity = self.params["sparsity"] - if we.is_sparse(): - # natural sparsity - wfs = we.get_waveforms(unit_id, lazy=False) - channel_inds = we.sparsity.unit_id_to_channel_indices[unit_id] - elif sparsity is not None: - # injected sparsity - wfs = self.sorting_result.get_waveforms(unit_id, sparsity=sparsity, lazy=False) + def _get_slice_waveforms(self, unit_id, spikes, waveforms): + # slice by mask waveforms from one unit + + unit_index = self.sorting_result.sorting.id_to_index(unit_id) + spike_mask = spikes["unit_index"] == unit_index + wfs = waveforms[spike_mask, :, :] + + sparsity = self.sorting_result.sparsity + if sparsity is not None: channel_inds = sparsity.unit_id_to_channel_indices[unit_id] + wfs = wfs[:, :, :channel_inds.size] else: - # dense - wfs = self.sorting_result.get_waveforms(unit_id, sparsity=None, lazy=False) - channel_inds = np.arange(we.channel_ids.size, dtype=int) - return wfs, channel_inds + channel_inds = np.arange(self.sorting_result.channel_ids.size, dtype=int) + + return wfs, channel_inds, spike_mask + + def _get_sparse_waveforms(self, unit_id): + # get waveforms + channel_inds: dense or sparse + waveforms_ext = self.sorting_result.get_extension("waveforms") + some_waveforms = waveforms_ext.data["waveforms"] + + spikes = self.sorting_result.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + + return self._get_slice_waveforms(unit_id, some_spikes, some_waveforms) def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): @@ -738,10 +603,16 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte compute_principal_components = ComputePrincipalComponents.function_factory() +# def partial_fit_one_channel(args): +# pca_file, wf_chan = args +# with open(pca_file, "rb") as fid: +# pca_model = pickle.load(fid) +# pca_model.partial_fit(wf_chan) +# with pca_file.open("wb") as f: +# pickle.dump(pca_model, f) + def partial_fit_one_channel(args): - pca_file, wf_chan = args - with open(pca_file, "rb") as fid: - pca_model = pickle.load(fid) + chan_ind, pca_model, wf_chan = args pca_model.partial_fit(wf_chan) - with pca_file.open("wb") as f: - pickle.dump(pca_model, f) + return chan_ind, pca_model + diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 94a56db7e2..79c7f575b0 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -63,15 +63,14 @@ class ResultExtensionCommonTestSuite: extension_function_kwargs_list = None def setUp(self): - recording, sorting = get_dataset() - # sparsity is computed once for all cases to save processing - sparsity = estimate_sparsity(recording, sorting) + self.recording, self.sorting = get_dataset() + # sparsity is computed once for all cases to save processing time + self.sparsity = estimate_sparsity(self.recording, self.sorting) self.sorting_results = {} for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): - sparsity_ = sparsity if sparse else None - sorting_result = get_sorting_result(recording, sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name) + sorting_result = self._prepare_sorting_result(format, sparse) key = f"sparse{sparse}_{format}" self.sorting_results[key] = sorting_result @@ -86,16 +85,20 @@ def tearDown(self): @property def extension_name(self): return self.extension_class.extension_name - - def _check_one(self, sorting_result): + + def _prepare_sorting_result(self, format, sparse): + # prepare a SortingResult object with depencies already computed + sparsity_ = self.sparsity if sparse else None + sorting_result = get_sorting_result(self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name) sorting_result.select_random_spikes(max_spikes_per_unit=50, seed=2205) - for dependency_name in self.extension_class.depend_on: if "|" in dependency_name: dependency_name = dependency_name.split("|")[0] sorting_result.compute(dependency_name) + return sorting_result + + def _check_one(self, sorting_result): - for kwargs in self.extension_function_kwargs_list: print(' kwargs', kwargs) sorting_result.compute(self.extension_name, **kwargs) @@ -110,7 +113,6 @@ def _check_one(self, sorting_result): def test_extension(self): - for key, sorting_result in self.sorting_results.items(): print() print(self.extension_name, key) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index d40b683a1c..541f7fc4de 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -28,30 +28,32 @@ class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest. extension_function_kwargs_list = [ dict(mode="by_channel_local"), dict(mode="by_channel_global"), - dict(mode="concatenated"), + # mode concatenated cannot be tested here because it do not work with sparse=True ] # TODO : put back theses tests - # def test_shapes(self): - # nchan1 = self.we1.recording.get_num_channels() - # for mode in ("by_channel_local", "by_channel_global"): - # _ = self.extension_class.get_extension_function()(self.we1, mode=mode, n_components=5) - # pc = self.we1.load_extension(self.extension_class.extension_name) - # for unit_id in self.we1.sorting.unit_ids: - # proj = pc.get_projections(unit_id) - # assert proj.shape[1:] == (5, nchan1) - # for mode in ("concatenated",): - # _ = self.extension_class.get_extension_function()(self.we2, mode=mode, n_components=3) - # pc = self.we2.load_extension(self.extension_class.extension_name) - # for unit_id in self.we2.sorting.unit_ids: - # proj = pc.get_projections(unit_id) - # assert proj.shape[1] == 3 + def test_mode_concatenated(self): + # this is tested outside "extension_function_kwargs_list" because it do not support sparsity! + + sorting_result = self._prepare_sorting_result(format="memory", sparse=False) + + n_components = 3 + sorting_result.compute("principal_components", mode="concatenated", n_components=n_components) + ext = sorting_result.get_extension(self.extension_name) + assert ext is not None + assert len(ext.data) > 0 + pca = ext.data["pca_projection"] + assert pca.ndim == 2 + assert pca.shape[1] == n_components # def test_compute_for_all_spikes(self): - # we = self.we1 - # pc = self.extension_class.get_extension_function()(we, load_if_exists=True) - # print(pc) + # sorting_result = self._prepare_sorting_result(format="memory", sparse=False) + + # n_components = 3 + # sorting_result.compute("principal_components", mode="by_channel_local", n_components=n_components) + # ext = sorting_result.get_extension(self.extension_name) + # ext.run_for_all_spikes() # pc_file1 = pc.extension_folder / "all_pc1.npy" # pc.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) @@ -76,139 +78,41 @@ class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest. # pc_unit = all_pc_sparse[all_spikes_seg0["unit_index"] == unit_index] # assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) - # def test_sparse(self): - # we = self.we2 - # unit_ids = we.unit_ids - # num_channels = we.get_num_channels() - # pc = self.extension_class(we) - - # sparsity_radius = compute_sparsity(we, method="radius", radius_um=50) - # sparsity_best = compute_sparsity(we, method="best_channels", num_channels=2) - # sparsities = [sparsity_radius, sparsity_best] - # print(sparsities) - - # for mode in ("by_channel_local", "by_channel_global"): - # for sparsity in sparsities: - # pc.set_params(n_components=5, mode=mode, sparsity=sparsity) - # pc.run() - # for i, unit_id in enumerate(unit_ids): - # proj_sparse = pc.get_projections(unit_id, sparse=True) - # assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) - # proj_dense = pc.get_projections(unit_id, sparse=False) - # assert proj_dense.shape[1:] == (5, num_channels) - - # # test project_new - # unit_id = 3 - # new_wfs = we.get_waveforms(unit_id) - # new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) - # assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) - # new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) - # assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) - - # if DEBUG: - # import matplotlib.pyplot as plt - - # plt.ion() - # cmap = plt.get_cmap("jet", len(unit_ids)) - # fig, axs = plt.subplots(nrows=len(unit_ids), ncols=num_channels) - # for i, unit_id in enumerate(unit_ids): - # comp = pc.get_projections(unit_id) - # print(comp.shape) - # for chan_ind in range(num_channels): - # ax = axs[i, chan_ind] - # ax.scatter(comp[:, 0, chan_ind], comp[:, 1, chan_ind], color=cmap(i)) - # ax.set_title(f"{mode}-{sparsity.unit_id_to_channel_ids[unit_id]}") - # if i == 0: - # ax.set_xlabel(f"Ch{chan_ind}") - # plt.show() - - # for mode in ("concatenated",): - # # concatenated is only compatible with "best" - # pc.set_params(n_components=5, mode=mode, sparsity=sparsity_best) - # print(pc) - # pc.run() - # for i, unit_id in enumerate(unit_ids): - # proj = pc.get_projections(unit_id) - # assert proj.shape[1] == 5 - - # # test project_new - # unit_id = 3 - # new_wfs = we.get_waveforms(unit_id) - # new_proj = pc.project_new(new_wfs, unit_id) - # assert new_proj.shape == (len(new_wfs), 5) - - # def test_project_new(self): - # from sklearn.decomposition import IncrementalPCA - - # we = self.we1 - # if we.has_extension("principal_components"): - # we.delete_extension("principal_components") - # we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp") - - # wfs0 = we.get_waveforms(unit_id=we.unit_ids[0]) - # n_samples = wfs0.shape[1] - # n_channels = wfs0.shape[2] - # n_components = 5 - - # # local - # pc_local = compute_principal_components( - # we, n_components=n_components, load_if_exists=True, mode="by_channel_local" - # ) - # pc_local_par = compute_principal_components( - # we_cp, n_components=n_components, load_if_exists=True, mode="by_channel_local", n_jobs=2, progress_bar=True - # ) - - # all_pca = pc_local.get_pca_model() - # all_pca_par = pc_local_par.get_pca_model() - - # assert len(all_pca) == we.get_num_channels() - # assert len(all_pca_par) == we.get_num_channels() - - # for pc, pc_par in zip(all_pca, all_pca_par): - # assert np.allclose(pc.components_, pc_par.components_) - - # # project - # new_waveforms = np.random.randn(100, n_samples, n_channels) - # new_proj = pc_local.project_new(new_waveforms) - - # assert new_proj.shape == (100, n_components, n_channels) - - # # global - # we.delete_extension("principal_components") - # pc_global = compute_principal_components( - # we, n_components=n_components, load_if_exists=True, mode="by_channel_global" - # ) - - # all_pca = pc_global.get_pca_model() - # assert isinstance(all_pca, IncrementalPCA) - - # # project - # new_waveforms = np.random.randn(100, n_samples, n_channels) - # new_proj = pc_global.project_new(new_waveforms) - - # assert new_proj.shape == (100, n_components, n_channels) - - # # concatenated - # we.delete_extension("principal_components") - # pc_concatenated = compute_principal_components( - # we, n_components=n_components, load_if_exists=True, mode="concatenated" - # ) - - # all_pca = pc_concatenated.get_pca_model() - # assert isinstance(all_pca, IncrementalPCA) - - # # project - # new_waveforms = np.random.randn(100, n_samples, n_channels) - # new_proj = pc_concatenated.project_new(new_waveforms) - - # assert new_proj.shape == (100, n_components) + + def test_project_new(self): + from sklearn.decomposition import IncrementalPCA + + sorting_result = self._prepare_sorting_result(format="memory", sparse=False) + + waveforms = sorting_result.get_extension("waveforms").data["waveforms"] + + n_components = 3 + sorting_result.compute("principal_components", mode="by_channel_local", n_components=n_components) + ext_pca = sorting_result.get_extension(self.extension_name) + + + num_spike = 100 + new_spikes = sorting_result.sorting.to_spike_vector()[:num_spike] + new_waveforms = np.random.randn(num_spike, waveforms.shape[1], waveforms.shape[2]) + new_proj = ext_pca.project_new(new_spikes, new_waveforms) + + assert new_proj.shape[0] == num_spike + assert new_proj.shape[1] == n_components + assert new_proj.shape[2] == ext_pca.data["pca_projection"].shape[2] if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUp() - # test.test_extension() - # test.test_shapes() + test.test_extension() + test.test_mode_concatenated() # test.test_compute_for_all_spikes() - # test.test_sparse() - # test.test_project_new() + test.test_project_new() + + + # ext = test.sorting_results["sparseTrue_memory"].get_extension("principal_components") + # pca = ext.data["pca_projection"] + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.scatter(pca[:, 0, 0], pca[:, 0, 1]) + # plt.show() From ec395ad28be0f7f86d3b8c29efce7f486fa3ea6e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 29 Jan 2024 14:38:40 +0100 Subject: [PATCH 034/192] ResultExtensionCommonTestSuite.setUpClass --- .../tests/common_extension_tests.py | 40 +++++++++---------- .../tests/test_amplitude_scalings.py | 5 ++- .../postprocessing/tests/test_isi.py | 10 ++--- 3 files changed, 25 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 79c7f575b0..0be7ca5bd2 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -61,26 +61,20 @@ class ResultExtensionCommonTestSuite: """ extension_class = None extension_function_kwargs_list = None - def setUp(self): - - self.recording, self.sorting = get_dataset() + + @classmethod + def setUpClass(cls): + cls.recording, cls.sorting = get_dataset() # sparsity is computed once for all cases to save processing time - self.sparsity = estimate_sparsity(self.recording, self.sorting) + cls.sparsity = estimate_sparsity(cls.recording, cls.sorting) - self.sorting_results = {} - for sparse in (True, False): - for format in ("memory", "binary_folder", "zarr"): - sorting_result = self._prepare_sorting_result(format, sparse) - key = f"sparse{sparse}_{format}" - self.sorting_results[key] = sorting_result - - def tearDown(self): - for k in list(self.sorting_results.keys()): - sorting_result = self.sorting_results.pop(k) - if sorting_result.format != "memory": - folder = sorting_result.folder - del sorting_result - shutil.rmtree(folder) + # def tearDown(self): + # for k in list(self.sorting_results.keys()): + # sorting_result = self.sorting_results.pop(k) + # if sorting_result.format != "memory": + # folder = sorting_result.folder + # del sorting_result + # shutil.rmtree(folder) @property def extension_name(self): @@ -113,7 +107,9 @@ def _check_one(self, sorting_result): def test_extension(self): - for key, sorting_result in self.sorting_results.items(): - print() - print(self.extension_name, key) - self._check_one(sorting_result) + for sparse in (True, False): + for format in ("memory", "binary_folder", "zarr"): + print() + print("sparse", sparse, format) + sorting_result = self._prepare_sorting_result(format, sparse) + self._check_one(sorting_result) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index e0abd07edc..4d90f6e6a8 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -16,7 +16,8 @@ class AmplitudeScalingsExtensionTest(ResultExtensionCommonTestSuite, unittest.Te ] def test_scaling_values(self): - sorting_result = self.sorting_results["sparseTrue_memory"] + sorting_result = self._prepare_sorting_result("memory", True) + sorting_result.compute("amplitude_scalings", handle_collisions=False) spikes = sorting_result.sorting.to_spike_vector() @@ -39,4 +40,4 @@ def test_scaling_values(self): test = AmplitudeScalingsExtensionTest() test.setUp() test.test_extension() - test.test_scaling_values() + # test.test_scaling_values() diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 8867de08f3..31d940609f 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -30,11 +30,8 @@ def test_compute_ISI(self): if HAVE_NUMBA: methods.append("numba") - key0 = list(self.sorting_results.keys())[0] - sorting = self.sorting_results[key0].sorting - - _test_ISI(sorting, window_ms=60.0, bin_ms=1.0, methods=methods) - _test_ISI(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) + _test_ISI(self.sorting, window_ms=60.0, bin_ms=1.0, methods=methods) + _test_ISI(self.sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): @@ -51,7 +48,8 @@ def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): if __name__ == "__main__": test = ComputeISIHistogramsTest() - test.setUp() + # test.setUp() + test.setUpClass() test.test_extension() test.test_compute_ISI() From dca13221308fc4283a09f7db1b592b96023db998 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 29 Jan 2024 18:42:10 +0100 Subject: [PATCH 035/192] Handle job_kwargs for extensions --- src/spikeinterface/core/result_core.py | 33 ++++++++++---- src/spikeinterface/core/sortingresult.py | 45 +++++++++++++++---- .../core/tests/test_result_core.py | 23 ++++++---- .../postprocessing/amplitude_scalings.py | 1 + .../postprocessing/correlograms.py | 1 + src/spikeinterface/postprocessing/isi.py | 1 + .../postprocessing/principal_component.py | 1 + .../postprocessing/spike_amplitudes.py | 1 + .../postprocessing/spike_locations.py | 1 + .../postprocessing/template_metrics.py | 1 + .../postprocessing/template_similarity.py | 1 + .../tests/common_extension_tests.py | 16 ++++--- .../tests/test_amplitude_scalings.py | 6 +-- .../postprocessing/tests/test_correlograms.py | 6 +-- .../postprocessing/tests/test_isi.py | 5 +-- .../tests/test_principal_component.py | 6 +-- .../tests/test_spike_amplitudes.py | 4 +- .../tests/test_spike_locations.py | 4 +- .../tests/test_template_metrics.py | 13 +----- .../tests/test_template_similarity.py | 4 +- .../tests/test_unit_localization.py | 4 +- .../postprocessing/unit_localization.py | 3 +- .../tests/test_quality_metric_calculator.py | 2 +- 23 files changed, 116 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 77c3ac4d67..4644a26f6e 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -23,6 +23,7 @@ class ComputeWaveforms(ResultExtension): depend_on = [] need_recording = True use_nodepipeline = False + need_job_kwargs = True @property def nbefore(self): @@ -32,7 +33,7 @@ def nbefore(self): def nafter(self): return int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) - def _run(self, **kwargs): + def _run(self, **job_kwargs): self.data.clear() if self.sorting_result.random_spikes_indices is None: @@ -61,9 +62,6 @@ def _run(self, **kwargs): else: sparsity_mask = self.sparsity.mask - # TODO propagate some job_kwargs - job_kwargs = dict(n_jobs=-1) - all_waveforms = extract_waveforms_to_single_buffer( recording, some_spikes, @@ -140,6 +138,8 @@ def get_waveforms_one_unit(self, unit_id, force_dense: bool = False,): return wfs + def _get_data(self): + return self.data["waveforms"] @@ -160,8 +160,9 @@ class ComputeTemplates(ResultExtension): depend_on = ["waveforms"] need_recording = False use_nodepipeline = False + need_job_kwargs = False - def _run(self, **kwargs): + def _run(self): unit_ids = self.sorting_result.unit_ids channel_ids = self.sorting_result.channel_ids @@ -247,6 +248,14 @@ def _select_extension_data(self, unit_ids): return new_data + def _get_data(self, operator="average", percentile=None): + if operator != "percentile": + key = operator + else: + assert percentile is not None, "You must provide percentile=..." + key = f"pencentile_{percentile}" + return self.data[key] + compute_templates = ComputeTemplates.function_factory() @@ -262,6 +271,7 @@ class ComputeFastTemplates(ResultExtension): depend_on = [] need_recording = True use_nodepipeline = False + need_job_kwargs = True @property def nbefore(self): @@ -271,7 +281,7 @@ def nbefore(self): def nafter(self): return int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) - def _run(self, **kwargs): + def _run(self, **job_kwargs): self.data.clear() if self.sorting_result.random_spikes_indices is None: @@ -288,7 +298,7 @@ def _run(self, **kwargs): return_scaled = self.params["return_scaled"] # TODO jobw_kwargs - self.data["average"] = estimate_templates(recording, some_spikes, unit_ids, self.nbefore, self.nafter, return_scaled=return_scaled) + self.data["average"] = estimate_templates(recording, some_spikes, unit_ids, self.nbefore, self.nafter, return_scaled=return_scaled, **job_kwargs) def _set_params(self, ms_before: float = 1.0, @@ -302,7 +312,8 @@ def _set_params(self, ) return params - + def _get_data(self): + return self.data["average"] def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) @@ -340,6 +351,10 @@ class ComputeNoiseLevels(ResultExtension): noise level vector. """ extension_name = "noise_levels" + depend_on = [] + need_recording = True + use_nodepipeline = False + need_job_kwargs = False def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) @@ -355,7 +370,7 @@ def _select_extension_data(self, unit_ids): def _run(self): self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) - # def get_data(self): + # def _get_data(self): # """ # Get computed noise levels. diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 83d7f4e975..b012679d6f 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -22,6 +22,7 @@ from .recording_tools import check_probe_do_not_overlap, get_rec_attributes from .sorting_tools import random_spikes_selection from .core_tools import check_json +from .job_tools import split_job_kwargs from .numpyextractors import SharedMemorySorting from .sparsity import ChannelSparsity, estimate_sparsity from .sortingfolder import NumpyFolderSorting @@ -738,7 +739,7 @@ def get_sorting_property(self, key) -> np.ndarray: - def compute(self, input, save=True, **params): + def compute(self, input, save=True, **kwargs): """ Compute one extension or several extension. Internally calling compute_one_extension() or compute_several_extensions() depending th input type. @@ -750,12 +751,12 @@ def compute(self, input, save=True, **params): If the input is a dict then compute several extension with compute_several_extensions(extensions=input) """ if isinstance(input, str): - self.compute_one_extension(extension_name=input, save=save, **params) + return self.compute_one_extension(extension_name=input, save=save, **kwargs) elif isinstance(input, dict): - assert len(params) == 0, "Too many arguments for SortingResult.compute_several_extensions()" + assert len(kwargs) == 0, "Too many arguments for SortingResult.compute_several_extensions()" self.compute_several_extensions(extensions=input, save=save) - def compute_one_extension(self, extension_name, save=True, **params): + def compute_one_extension(self, extension_name, save=True, **kwargs): """ Compute one extension @@ -769,8 +770,8 @@ def compute_one_extension(self, extension_name, save=True, **params): If not then the extension will only live in memory as long as the object is deleted. save=False is convinient to try some parameters without changing an already saved extension. - **params: - All other kwargs are transimited to extension.set_params() + **kwargs: + All other kwargs are transimited to extension.set_params() or job_kwargs Returns ------- @@ -785,10 +786,18 @@ def compute_one_extension(self, extension_name, save=True, **params): >>> wfs = extension.data["waveforms"] """ + extension_class = get_extension_class(extension_name) + + if extension_class.need_job_kwargs: + params, job_kwargs = split_job_kwargs(kwargs) + else: + params = kwargs + job_kwargs = {} + # check dependencies if extension_class.need_recording: assert self.has_recording(), f"Extension {extension_name} need the recording" @@ -802,7 +811,7 @@ def compute_one_extension(self, extension_name, save=True, **params): extension_instance = extension_class(self) extension_instance.set_params(save=save, **params) - extension_instance.run(save=save) + extension_instance.run(save=save, **job_kwargs) self.extensions[extension_name] = extension_instance @@ -1045,9 +1054,11 @@ class ResultExtension: * depend_on * need_recording * use_nodepipeline + * need_job_kwargs * _set_params() * _run() * _select_extension_data() + * _get_data() The subclass must also set an `extension_name` class attribute which is not None by default. @@ -1063,6 +1074,7 @@ class ResultExtension: depend_on = [] need_recording = False use_nodepipeline = False + need_job_kwargs = False def __init__(self, sorting_result): self._sorting_result = weakref.ref(sorting_result) @@ -1091,6 +1103,10 @@ def _get_pipeline_nodes(self): # must be implemented in subclass only if use_nodepipeline=True raise NotImplementedError + def _get_data(self): + # must be implemented in subclass + raise NotImplementedError + # ####### @@ -1098,17 +1114,27 @@ def _get_pipeline_nodes(self): def function_factory(cls): # make equivalent # comptute_unit_location(sorting_result, ...) <> sorting_result.compute("unit_location", ...) + # this also make backcompatibility + # comptute_unit_location(we, ...) + class FuncWrapper: def __init__(self, extension_name): self.extension_name = extension_name def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): - # backward compatibility with "load_if_exists" + from .waveforms_extractor_backwards_compatibility import MockWaveformExtractor + + if isinstance(sorting_result, MockWaveformExtractor): + # backward compatibility with WaveformsExtractor + sorting_result = sorting_result.sorting_result + if load_if_exists is not None: + # backward compatibility with "load_if_exists" warnings.warn(f"compute_{cls.extension_name}(..., load_if_exists=True/False) is kept for backward compatibility but should not be used anymore") assert isinstance(load_if_exists, bool) if load_if_exists: ext = sorting_result.get_extension(self.extension_name) return ext + ext = sorting_result.compute(cls.extension_name, *args, **kwargs) # TODO be discussed @@ -1118,7 +1144,6 @@ def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): func = FuncWrapper(cls.extension_name) func.__doc__ = cls.__doc__ - # TODO: add load_if_exists return func @property @@ -1386,3 +1411,5 @@ def get_pipeline_nodes(self): assert self.use_nodepipeline, "ResultExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True" return self._get_pipeline_nodes() + def get_data(self, *args, **kwargs): + return self._get_data(*args, **kwargs) diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index fad5795fe8..a7ee0cc322 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -73,8 +73,9 @@ def _check_result_extension(sortres, extension_name): def test_ComputeWaveforms(format, sparse): sortres = get_sorting_result(format=format, sparse=sparse) + job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) sortres.select_random_spikes(max_spikes_per_unit=50, seed=2205) - ext = sortres.compute("waveforms") + ext = sortres.compute("waveforms", **job_kwargs) wfs = ext.data["waveforms"] _check_result_extension(sortres, "waveforms") @@ -90,7 +91,8 @@ def test_ComputeTemplates(format, sparse): # This require "waveforms first and should trig an error sortres.compute("templates") - sortres.compute("waveforms") + job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) + sortres.compute("waveforms", **job_kwargs) sortres.compute("templates", operators=["average", "std", "median", ("percentile", 5.), ("percentile", 95.),]) @@ -118,18 +120,21 @@ def test_ComputeTemplates(format, sparse): def test_ComputeFastTemplates(format, sparse): sortres = get_sorting_result(format=format, sparse=sparse) + # TODO check this because this is not passing with n_jobs=2 + job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + ms_before=1.0 ms_after=2.5 sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) - sortres.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True) + sortres.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) _check_result_extension(sortres, "fast_templates") # compare ComputeTemplates with dense and ComputeFastTemplates: should give the same on "average" other_sortres = get_sorting_result(format=format, sparse=False) other_sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) - other_sortres.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True) + other_sortres.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) other_sortres.compute("templates", operators=["average",]) templates0 = sortres.get_extension("fast_templates").data["average"] @@ -166,11 +171,11 @@ def test_ComputeNoiseLevels(format, sparse): # test_ComputeWaveforms(format="zarr", sparse=True) # test_ComputeWaveforms(format="zarr", sparse=False) - # test_ComputeTemplates(format="memory", sparse=True) - # test_ComputeTemplates(format="memory", sparse=False) - # test_ComputeTemplates(format="binary_folder", sparse=True) - # test_ComputeTemplates(format="zarr", sparse=True) + test_ComputeTemplates(format="memory", sparse=True) + test_ComputeTemplates(format="memory", sparse=False) + test_ComputeTemplates(format="binary_folder", sparse=True) + test_ComputeTemplates(format="zarr", sparse=True) - # test_ComputeFastTemplates(format="memory", sparse=True) + test_ComputeFastTemplates(format="memory", sparse=True) test_ComputeNoiseLevels(format="memory", sparse=False) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 331c122929..dd84861348 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -60,6 +60,7 @@ class ComputeAmplitudeScalings(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = True + need_job_kwargs = True def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 8b76f0e8b2..5d1401a983 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -41,6 +41,7 @@ class ComputeCorrelograms(ResultExtension): depend_on = [] need_recording = False use_nodepipeline = False + need_job_kwargs = False def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index a4e2c41818..457e464009 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -36,6 +36,7 @@ class ComputeISIHistograms(ResultExtension): depend_on = [] need_recording = False use_nodepipeline = False + need_job_kwargs = False def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 2c02457168..29c7d554b8 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -59,6 +59,7 @@ class ComputePrincipalComponents(ResultExtension): depend_on = ["waveforms", ] need_recording = False use_nodepipeline = False + need_job_kwargs = True def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 63bbc47411..197e7bf917 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -73,6 +73,7 @@ class ComputeSpikeAmplitudes(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = True + need_job_kwargs = True def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 5b88b6a761..dd5c288102 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -55,6 +55,7 @@ class ComputeSpikeLocations(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = True + need_job_kwargs = True def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index fae4273a10..0bf7b5ec3f 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -99,6 +99,7 @@ class ComputeTemplateMetrics(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = False + need_job_kwargs = False min_channels_for_multi_channel_warning = 10 diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 748e342dc6..8a659a5179 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -24,6 +24,7 @@ class ComputeTemplateSimilarity(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = False + need_job_kwargs = False def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 0be7ca5bd2..45133c854d 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -53,14 +53,14 @@ class ResultExtensionCommonTestSuite: """ Common tests with class approach to compute extension on several cases (3 format x 2 sparsity) - This is done a a list of differents parameters (extension_function_kwargs_list). + This is done a a list of differents parameters (extension_function_params_list). This automatically precompute extension dependencies with default params before running computation. This also test the select_units() ability. """ extension_class = None - extension_function_kwargs_list = None + extension_function_params_list = None @classmethod def setUpClass(cls): @@ -92,10 +92,14 @@ def _prepare_sorting_result(self, format, sparse): return sorting_result def _check_one(self, sorting_result): - - for kwargs in self.extension_function_kwargs_list: - print(' kwargs', kwargs) - sorting_result.compute(self.extension_name, **kwargs) + if self.extension_class.need_job_kwargs: + job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) + else: + job_kwargs = dict() + + for params in self.extension_function_params_list: + print(' params', params) + sorting_result.compute(self.extension_name, **params, **job_kwargs) ext = sorting_result.get_extension(self.extension_name) assert ext is not None assert len(ext.data) > 0 diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index 4d90f6e6a8..c43419eb5a 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -10,7 +10,7 @@ class AmplitudeScalingsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeAmplitudeScalings - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(handle_collisions=True), dict(handle_collisions=False), ] @@ -38,6 +38,6 @@ def test_scaling_values(self): if __name__ == "__main__": test = AmplitudeScalingsExtensionTest() - test.setUp() + test.setUpClass() test.test_extension() - # test.test_scaling_values() + test.test_scaling_values() diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 42e3421036..0c17371576 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -18,12 +18,12 @@ class ComputeCorrelogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeCorrelograms - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(method="numpy"), dict(method="auto"), ] if HAVE_NUMBA: - extension_function_kwargs_list.append(dict(method="numba")) + extension_function_params_list.append(dict(method="numba")) @@ -201,5 +201,5 @@ def test_detect_injected_correlation(): # test_detect_injected_correlation() test = ComputeCorrelogramsTest() - test.setUp() + test.setUpClass() test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 31d940609f..618d4a6b06 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -18,12 +18,12 @@ class ComputeISIHistogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeISIHistograms - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(method="numpy"), dict(method="auto"), ] if HAVE_NUMBA: - extension_function_kwargs_list.append(dict(method="numba")) + extension_function_params_list.append(dict(method="numba")) def test_compute_ISI(self): methods = ["numpy", "auto"] @@ -48,7 +48,6 @@ def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): if __name__ == "__main__": test = ComputeISIHistogramsTest() - # test.setUp() test.setUpClass() test.test_extension() test.test_compute_ISI() diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 541f7fc4de..915d08acc7 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -25,7 +25,7 @@ class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputePrincipalComponents - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(mode="by_channel_local"), dict(mode="by_channel_global"), # mode concatenated cannot be tested here because it do not work with sparse=True @@ -34,7 +34,7 @@ class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest. # TODO : put back theses tests def test_mode_concatenated(self): - # this is tested outside "extension_function_kwargs_list" because it do not support sparsity! + # this is tested outside "extension_function_params_list" because it do not support sparsity! sorting_result = self._prepare_sorting_result(format="memory", sparse=False) @@ -103,7 +103,7 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() - test.setUp() + test.setUpClass() test.test_extension() test.test_mode_concatenated() # test.test_compute_for_all_spikes() diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index c4c8bb7974..12b800a8cc 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -7,14 +7,14 @@ class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeSpikeAmplitudes - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(return_scaled=True), dict(return_scaled=False), ] if __name__ == "__main__": test = ComputeSpikeAmplitudesTest() - test.setUp() + test.setUpClass() test.test_extension() # for k, sorting_result in test.sorting_results.items(): diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 98b8d19c2b..b2a5d6c9d5 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -9,7 +9,7 @@ class SpikeLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeSpikeLocations - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), # chunk_size=10000, n_jobs=1, dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), dict(method="center_of_mass", ), @@ -21,5 +21,5 @@ class SpikeLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestC if __name__ == "__main__": test = SpikeLocationsExtensionTest() - test.setUp() + test.setUpClass() test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 430523cf99..e5d5c73b8e 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -8,24 +8,15 @@ class TemplateMetricsTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeTemplateMetrics - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(), dict(upsampling_factor=2), dict(include_multi_channel_metrics=True), ] - # def test_sparse_metrics(self): - # tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) - # print(tm_sparse) - - # def test_multi_channel_metrics(self): - # tm_multi = self.extension_class.get_extension_function()(self.we1, include_multi_channel_metrics=True) - # print(tm_multi) if __name__ == "__main__": test = TemplateMetricsTest() - test.setUp() + test.setUpClass() test.test_extension() - # test.test_extension() - # test.test_multi_channel_metrics() diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 646fccf0fa..b169d5fe49 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -7,7 +7,7 @@ class SimilarityExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeTemplateSimilarity - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(method="cosine_similarity"), ] @@ -26,5 +26,5 @@ class SimilarityExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase) if __name__ == "__main__": test = SimilarityExtensionTest() - test.setUp() + test.setUpClass() test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index 3ce078a195..16d22386f5 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -6,7 +6,7 @@ class UnitLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeUnitLocations - extension_function_kwargs_list = [ + extension_function_params_list = [ dict(method="center_of_mass", radius_um=100), dict(method="center_of_mass", radius_um=100), dict(method="grid_convolution", radius_um=50), @@ -18,6 +18,6 @@ class UnitLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCa if __name__ == "__main__": test = UnitLocationsExtensionTest() - test.setUp() + test.setUpClass() test.test_extension() # test.tearDown() diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index c78e70cd3f..d6a4159eb5 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -50,6 +50,7 @@ class ComputeUnitLocations(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = False + need_job_kwargs = False def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) @@ -63,7 +64,7 @@ def _select_extension_data(self, unit_ids): new_unit_location = self.data["unit_locations"][unit_inds] return dict(unit_locations=new_unit_location) - def _run(self, **job_kwargs): + def _run(self): method = self.params["method"] method_kwargs = self.params["method_kwargs"] diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index b1055a716d..e697c0e762 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -39,7 +39,7 @@ class QualityMetricsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): extension_class = QualityMetricCalculator extension_data_names = ["metrics"] - extension_function_kwargs_list = [dict(), dict(n_jobs=2), dict(metric_names=["snr", "firing_rate"])] + extension_function_params_list = [dict(), dict(n_jobs=2), dict(metric_names=["snr", "firing_rate"])] exact_same_content = False From 49318b1f34ef4215d92dfcf65ec7623f30d6eda8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 29 Jan 2024 20:23:53 +0100 Subject: [PATCH 036/192] ResultExtension._get_data --- src/spikeinterface/core/sortingresult.py | 1 + .../postprocessing/amplitude_scalings.py | 33 +--------- .../postprocessing/correlograms.py | 23 +++---- src/spikeinterface/postprocessing/isi.py | 16 +---- .../postprocessing/principal_component.py | 14 +---- .../postprocessing/spike_amplitudes.py | 63 ++----------------- .../postprocessing/spike_locations.py | 43 ++----------- .../postprocessing/template_metrics.py | 14 +---- .../postprocessing/template_similarity.py | 19 +----- .../tests/common_extension_tests.py | 8 ++- .../postprocessing/unit_localization.py | 25 +------- 11 files changed, 37 insertions(+), 222 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index b012679d6f..ece50e8c41 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -1412,4 +1412,5 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): + assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data" return self._get_data(*args, **kwargs) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index dd84861348..e2c2f91677 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -199,37 +199,8 @@ def _run(self, **job_kwargs): # # Note: collisions are note in _extension_data because they are not pickable. We only store the indices # self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - # def get_data(self, outputs="concatenated"): - # """ - # Get computed spike amplitudes. - # Parameters - # ---------- - # outputs : "concatenated" | "by_unit", default: "concatenated" - # The output format - - # Returns - # ------- - # spike_amplitudes : np.array or dict - # The spike amplitudes as an array (outputs="concatenated") or - # as a dict with units as key and spike amplitudes as values. - # """ - # we = self.sorting_result - # sorting = we.sorting - - # if outputs == "concatenated": - # return self._extension_data[f"amplitude_scalings"] - # elif outputs == "by_unit": - # amplitudes_by_unit = [] - # for segment_index in range(we.get_num_segments()): - # amplitudes_by_unit.append({}) - # segment_mask = self.spikes["segment_index"] == segment_index - # spikes_segment = self.spikes[segment_mask] - # amp_scalings_segment = self._extension_data[f"amplitude_scalings"][segment_mask] - # for unit_index, unit_id in enumerate(sorting.unit_ids): - # unit_mask = spikes_segment["unit_index"] == unit_index - # amp_scalings = amp_scalings_segment[unit_mask] - # amplitudes_by_unit[segment_index][unit_id] = amp_scalings - # return amplitudes_by_unit + def _get_data(self): + return self.data[f"amplitude_scalings"] register_result_extension(ComputeAmplitudeScalings) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 5d1401a983..85d212b96e 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -35,6 +35,13 @@ class ComputeCorrelograms(ResultExtension): ccgs[A, B, :] have to be read as the histogram of spiketimesA - spiketimesB bins : np.array The bin edges in ms + + Returns + ------- + isi_histograms : np.array + 2D array with ISI histograms (num_units, num_bins) + bins : np.array + 1D array with bins in ms """ extension_name = "correlograms" @@ -64,20 +71,8 @@ def _run(self): self.data["ccgs"] = ccgs self.data["bins"] = bins - # def get_data(self): - # """ - # Get the computed ISI histograms. - - # Returns - # ------- - # isi_histograms : np.array - # 2D array with ISI histograms (num_units, num_bins) - # bins : np.array - # 1D array with bins in ms - # """ - # msg = "Crosscorrelograms are not computed. Use the 'run()' function." - # assert self.data["ccgs"] is not None and self.data["bins"] is not None, msg - # return self.data["ccgs"], self.data["bins"] + def _get_data(self): + return self.data["ccgs"], self.data["bins"] register_result_extension(ComputeCorrelograms) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 457e464009..f76f5a441f 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -59,20 +59,8 @@ def _run(self): self.data["isi_histograms"] = isi_histograms self.data["bins"] = bins - # def get_data(self): - # """ - # Get the computed ISI histograms. - - # Returns - # ------- - # isi_histograms : np.array - # 2D array with ISI histograms (num_units, num_bins) - # bins : np.array - # 1D array with bins in ms - # """ - # msg = "ISI histograms are not computed. Use the 'run()' function." - # assert self.data["isi_histograms"] is not None and self.data["bins"] is not None, msg - # return self.data["isi_histograms"], self.data["bins"] + def _get_data(self): + return self.data["isi_histograms"], self.data["bins"] register_result_extension(ComputeISIHistograms) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 29c7d554b8..d5d5e5fc36 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -254,18 +254,8 @@ def _run(self, **job_kwargs): self.data["pca_projection"] = pca_projection - # def get_data(self): - # """ - # Get computed PCA projections. - - # Returns - # ------- - # all_labels : 1d np.array - # Array with all spike labels - # all_projections : 3d array - # Array with PCA projections (num_spikes, num_components, num_channels) - # """ - # return self.get_all_projections() + def _get_data(self): + return self.data["pca_projection"] # @staticmethod # def get_extension_function(): diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 197e7bf917..66a127463b 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -42,32 +42,12 @@ class ComputeSpikeAmplitudes(ResultExtension): Other kwargs depending on the method. outputs : "concatenated" | "by_unit", default: "concatenated" The output format - {} Returns ------- - spike_locations: np.array or list of dict - The spike locations. - - If "concatenated" all locations for all spikes and all units are concatenated - - If "by_unit", locations are returned as a list (for segments) of dictionaries (for units) - 1. Determine the max channel per unit. - 2. Then a "peak_shift" is estimated because for some sorters the spike index is not always at the - peak. - 3. Amplitudes are extracted in chunks (parallel or not) + spike_locations: np.array + All locations for all spikes and all units are concatenated - Parameters - ---------- - sorting_result: SortingResult - The SortingResult object - load_if_exists : bool, default: False - Whether to load precomputed spike amplitudes, if they already exist. - peak_sign: "neg" | "pos" | "both", default: "neg - The sign to compute maximum channel - return_scaled: bool - If True and recording has gain_to_uV/offset_to_uV properties, amplitudes are converted to uV. - outputs: "concatenated" | "by_unit", default: "concatenated" - How the output should be returned - {} """ extension_name = "spike_amplitudes" depend_on = ["fast_templates|templates", ] @@ -136,43 +116,8 @@ def _run(self, **job_kwargs): self.data["amplitudes"] = amps - # def get_data(self, outputs="concatenated"): - # """ - # Get computed spike amplitudes. - - # Parameters - # ---------- - # outputs : "concatenated" | "by_unit", default: "concatenated" - # The output format - - # Returns - # ------- - # spike_amplitudes : np.array or dict - # The spike amplitudes as an array (outputs="concatenated") or - # as a dict with units as key and spike amplitudes as values. - # """ - # sorting_result = self.sorting_result - # sorting = sorting_result.sorting - - # if outputs == "concatenated": - # amplitudes = [] - # for segment_index in range(sorting_result.get_num_segments()): - # amplitudes.append(self._extension_data[f"amplitude_segment_{segment_index}"]) - # return amplitudes - # elif outputs == "by_unit": - # all_spikes = sorting.to_spike_vector(concatenated=False) - - # amplitudes_by_unit = [] - # for segment_index in range(sorting_result.get_num_segments()): - # amplitudes_by_unit.append({}) - # for unit_index, unit_id in enumerate(sorting.unit_ids): - # spike_labels = all_spikes[segment_index]["unit_index"] - # mask = spike_labels == unit_index - # amps = self._extension_data[f"amplitude_segment_{segment_index}"][mask] - # amplitudes_by_unit[segment_index][unit_id] = amps - # return amplitudes_by_unit - - + def _get_data(self, outputs="concatenated"): + return self.data["amplitudes"] register_result_extension(ComputeSpikeAmplitudes) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index dd5c288102..158137e09b 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -46,10 +46,9 @@ class ComputeSpikeLocations(ResultExtension): Returns ------- - spike_locations: np.array or list of dict - The spike locations. - - If "concatenated" all locations for all spikes and all units are concatenated - - If "by_unit", locations are returned as a list (for segments) of dictionaries (for units) """ + spike_locations: np.array + All locations for all spikes + """ extension_name = "spike_locations" depend_on = ["fast_templates|templates", ] @@ -120,40 +119,8 @@ def _run(self, **job_kwargs): ) self.data["spike_locations"] = spike_locations - # def get_data(self, outputs="concatenated"): - # """ - # Get computed spike locations - - # Parameters - # ---------- - # outputs : "concatenated" | "by_unit", default: "concatenated" - # The output format - - # Returns - # ------- - # spike_locations : np.array or dict - # The spike locations as a structured array (outputs="concatenated") or - # as a dict with units as key and spike locations as values. - # """ - # we = self.sorting_result - # sorting = we.sorting - - # if outputs == "concatenated": - # return self._extension_data["spike_locations"] - - # elif outputs == "by_unit": - # locations_by_unit = [] - # for segment_index in range(self.sorting_result.get_num_segments()): - # i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left") - # i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right") - # spikes = self.spikes[i0:i1] - # locations = self._extension_data["spike_locations"][i0:i1] - - # locations_by_unit.append({}) - # for unit_ind, unit_id in enumerate(sorting.unit_ids): - # mask = spikes["unit_index"] == unit_ind - # locations_by_unit[segment_index][unit_id] = locations[mask] - # return locations_by_unit + def _get_data(self, outputs="concatenated"): + return self.data["spike_locations"] ComputeSpikeLocations.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 0bf7b5ec3f..46da504a15 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -262,18 +262,8 @@ def _run(self): template_metrics.at[index, metric_name] = value self.data["metrics"] = template_metrics - # def get_data(self): - # """ - # Get the computed metrics. - - # Returns - # ------- - # metrics : pd.DataFrame - # Dataframe with template metrics - # """ - # msg = "Template metrics are not computed. Use the 'run()' function." - # assert self.data["metrics"] is not None, msg - # return self.data["metrics"] + def _get_data(self): + return self.data["metrics"] register_result_extension(ComputeTemplateMetrics) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 8a659a5179..7aee21e1d1 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -44,22 +44,9 @@ def _run(self): similarity = compute_similarity_with_templates_array(templates_array, templates_array, method=self.params["method"]) self.data["similarity"] = similarity - # def get_data(self): - # """ - # Get the computed similarity. - - # Returns - # ------- - # similarity : 2d np.array - # 2d matrix with computed similarity values. - # """ - # msg = "Template similarity is not computed. Use the 'run()' function." - # assert self._extension_data["similarity"] is not None, msg - # return self._extension_data["similarity"] - - # @staticmethod - # def get_extension_function(): - # return compute_template_similarity + def _get_data(self): + return self.data["similarity"] + register_result_extension(ComputeTemplateSimilarity) compute_template_similarity = ComputeTemplateSimilarity.function_factory() diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 45133c854d..bf3a3ad00f 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -17,7 +17,7 @@ def get_dataset(): recording, sorting = generate_ground_truth_recording( - durations=[30.0, 20.0], sampling_frequency=24000.0, num_channels=10, num_units=5, + durations=[15.0, 5.0], sampling_frequency=24000.0, num_channels=6, num_units=3, generate_sorting_kwargs=dict(firing_rates=3.0, refractory_period_ms=4.0), generate_unit_locations_kwargs=dict( margin_um=5.0, @@ -99,10 +99,12 @@ def _check_one(self, sorting_result): for params in self.extension_function_params_list: print(' params', params) - sorting_result.compute(self.extension_name, **params, **job_kwargs) + ext = sorting_result.compute(self.extension_name, **params, **job_kwargs) + assert len(ext.data) > 0 + main_data = ext.get_data() + ext = sorting_result.get_extension(self.extension_name) assert ext is not None - assert len(ext.data) > 0 some_unit_ids = sorting_result.unit_ids[::2] sliced = sorting_result.select_units(some_unit_ids, format="memory") diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index d6a4159eb5..46e7b9d71b 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -78,29 +78,8 @@ def _run(self): unit_location = compute_monopolar_triangulation(self.sorting_result, **method_kwargs) self.data["unit_locations"] = unit_location - # def get_data(self, outputs="numpy"): - # """ - # Get the computed unit locations. - - # Parameters - # ---------- - # outputs : "numpy" | "by_unit", default: "numpy" - # The output format - - # Returns - # ------- - # unit_locations : np.array or dict - # The unit locations as a Nd array (outputs="numpy") or - # as a dict with units as key and locations as values. - # """ - # if outputs == "numpy": - # return self.data["unit_locations"] - - # elif outputs == "by_unit": - # locations_by_unit = {} - # for unit_ind, unit_id in enumerate(self.sorting_result.sorting.unit_ids): - # locations_by_unit[unit_id] = self.data["unit_locations"][unit_ind] - # return locations_by_unit + def get_data(self, outputs="numpy"): + return self.data["unit_locations"] register_result_extension(ComputeUnitLocations) From a4e355e62bb9dafa899dfc760de66425dcb157f9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 1 Feb 2024 15:57:14 +0100 Subject: [PATCH 037/192] Implement nodepipeline several extensions at once --- src/spikeinterface/core/__init__.py | 6 +- src/spikeinterface/core/sortingresult.py | 66 ++++++++++++++++--- src/spikeinterface/core/sparsity.py | 1 + src/spikeinterface/core/waveform_tools.py | 5 +- ...forms_extractor_backwards_compatibility.py | 3 +- src/spikeinterface/full.py | 11 ++-- .../postprocessing/amplitude_scalings.py | 1 + .../postprocessing/spike_amplitudes.py | 1 + .../postprocessing/spike_locations.py | 1 + 9 files changed, 77 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 1249c76143..965e9135a7 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -111,8 +111,8 @@ from .waveform_extractor import ( WaveformExtractor, BaseWaveformExtractorExtension, - extract_waveforms, - load_waveforms, + # extract_waveforms, + # load_waveforms, precompute_sparsity, ) @@ -154,5 +154,5 @@ # Important not for compatibility!! # This wil be uncommented after 0.100 -# from .waveforms_extractor_backwards_compatibility import extract_waveforms, load_waveforms +from .waveforms_extractor_backwards_compatibility import extract_waveforms, load_waveforms diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index ece50e8c41..239d8d5969 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -27,6 +27,7 @@ from .sparsity import ChannelSparsity, estimate_sparsity from .sortingfolder import NumpyFolderSorting from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor +from .node_pipeline import run_node_pipeline @@ -753,8 +754,9 @@ def compute(self, input, save=True, **kwargs): if isinstance(input, str): return self.compute_one_extension(extension_name=input, save=save, **kwargs) elif isinstance(input, dict): - assert len(kwargs) == 0, "Too many arguments for SortingResult.compute_several_extensions()" - self.compute_several_extensions(extensions=input, save=save) + params_, job_kwargs = split_job_kwargs(kwargs) + assert len(params_) == 0, "Too many arguments for SortingResult.compute_several_extensions()" + self.compute_several_extensions(extensions=input, save=save, **job_kwargs) def compute_one_extension(self, extension_name, save=True, **kwargs): """ @@ -820,7 +822,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): # OR return extension_instance.data - def compute_several_extensions(self, extensions, save=True): + def compute_several_extensions(self, extensions, save=True, **job_kwargs): """ Compute several extensions @@ -846,8 +848,57 @@ def compute_several_extensions(self, extensions, save=True): """ # TODO this is a simple implementation # this will be improved with nodepipeline!!! + + pipeline_mode = True for extension_name, extension_params in extensions.items(): - self.compute_one_extension(self, extension_name, save=save, **extension_params) + extension_class = get_extension_class(extension_name) + if not extension_class.use_nodepipeline: + pipeline_mode = False + break + + if not pipeline_mode: + # simple loop + for extension_name, extension_params in extensions.items(): + extension_class = get_extension_class(extension_name) + if extension_class.need_job_kwargs: + self.compute_one_extension(extension_name, save=save, **extension_params) + else: + self.compute_one_extension(extension_name, save=save, **extension_params) + else: + + all_nodes = [] + result_routage = [] + extension_instances = {} + for extension_name, extension_params in extensions.items(): + extension_class = get_extension_class(extension_name) + assert self.has_recording(), f"Extension {extension_name} need the recording" + + for variable_name in extension_class.nodepipeline_variables: + result_routage.append((extension_name, variable_name)) + + extension_instance = extension_class(self) + extension_instance.set_params(save=save, **extension_params) + extension_instances[extension_name] = extension_instance + + nodes = extension_instance.get_pipeline_nodes() + all_nodes.extend(nodes) + + job_name = "Compute : " + " + ".join(extensions.keys()) + results = run_node_pipeline( + self.recording, all_nodes, job_kwargs=job_kwargs, job_name=job_name, gather_mode="memory" + ) + + for r, result in enumerate(results): + extension_name, variable_name = result_routage[r] + extension_instances[extension_name].data[variable_name] = result + + + for extension_name, extension_instance in extension_instances.items(): + self.extensions[extension_name] = extension_instance + if save: + extension_instance.save() + + def get_saved_extension_names(self): @@ -1054,6 +1105,7 @@ class ResultExtension: * depend_on * need_recording * use_nodepipeline + * nodepipeline_variables only if use_nodepipeline=True * need_job_kwargs * _set_params() * _run() @@ -1074,6 +1126,7 @@ class ResultExtension: depend_on = [] need_recording = False use_nodepipeline = False + nodepipeline_variables = None need_job_kwargs = False def __init__(self, sorting_result): @@ -1137,10 +1190,7 @@ def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): ext = sorting_result.compute(cls.extension_name, *args, **kwargs) - # TODO be discussed - return ext - # return ext.data - # return ext.get_data() + return ext.get_data() func = FuncWrapper(cls.extension_name) func.__doc__ = cls.__doc__ diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 8d3fbb5e5a..cd2d3d3647 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -543,6 +543,7 @@ def estimate_sparsity( nbefore, nafter, return_scaled=False, + job_name="estimate_sparsity", **job_kwargs ) templates = Templates( diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 1fee9a44a1..a30a0ab8f5 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -703,6 +703,7 @@ def estimate_templates( nbefore, nafter, return_scaled=True, + job_name=None, **job_kwargs ): """ @@ -767,7 +768,9 @@ def estimate_templates( ) - processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="estimate_templates", **job_kwargs) + if job_name is None: + job_name = "estimate_templates" + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() # average diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 93c27cded2..ff9f694b52 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -72,7 +72,8 @@ def extract_waveforms( folder = Path(folder) format = "binary_folder" else: - mode = "memory" + folder = None + format = "memory" assert sparsity_temp_folder is None, "sparsity_temp_folder must be None" assert unit_batch_size is None, "unit_batch_size must be None" diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py index 57cad51769..b9034f00f1 100644 --- a/src/spikeinterface/full.py +++ b/src/spikeinterface/full.py @@ -18,9 +18,10 @@ from .sorters import * from .preprocessing import * from .postprocessing import * -from .qualitymetrics import * -from .curation import * -from .comparison import * -from .widgets import * -from .exporters import * +# TODO +#from .qualitymetrics import * +# from .curation import * +# from .comparison import * +# from .widgets import * +# from .exporters import * from .generation import * diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index e2c2f91677..820eb7b51a 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -60,6 +60,7 @@ class ComputeAmplitudeScalings(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = True + nodepipeline_variables = ["amplitude_scalings", "collision_mask"] need_job_kwargs = True def __init__(self, sorting_result): diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 66a127463b..f674688749 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -53,6 +53,7 @@ class ComputeSpikeAmplitudes(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = True + nodepipeline_variables = ["amplitudes"] need_job_kwargs = True def __init__(self, sorting_result): diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 158137e09b..b0f94fe588 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -54,6 +54,7 @@ class ComputeSpikeLocations(ResultExtension): depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = True + nodepipeline_variables = ["spike_locations"] need_job_kwargs = True def __init__(self, sorting_result): From b57963d4648b59cf26a938e45fbc4625b9be43ea Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 7 Feb 2024 18:19:36 +0100 Subject: [PATCH 038/192] port quality metrics on sorting result --- src/spikeinterface/core/result_core.py | 12 +- src/spikeinterface/full.py | 2 +- src/spikeinterface/qualitymetrics/__init__.py | 2 +- .../qualitymetrics/misc_metrics.py | 425 +++++++------ .../qualitymetrics/pca_metrics.py | 177 +++--- .../quality_metric_calculator.py | 161 ++--- .../tests/test_metrics_functions.py | 314 +++++---- .../qualitymetrics/tests/test_pca_metrics.py | 68 ++ .../tests/test_quality_metric_calculator.py | 599 +++++++++--------- 9 files changed, 883 insertions(+), 877 deletions(-) create mode 100644 src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 4644a26f6e..ba31c19b25 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -370,16 +370,8 @@ def _select_extension_data(self, unit_ids): def _run(self): self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) - # def _get_data(self): - # """ - # Get computed noise levels. - - # Returns - # ------- - # noise_levels : np.array - # The noise levels associated to each channel. - # """ - # return self._extension_data["noise_levels"] + def _get_data(self): + return self.data["noise_levels"] register_result_extension(ComputeNoiseLevels) diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py index b9034f00f1..7719319b06 100644 --- a/src/spikeinterface/full.py +++ b/src/spikeinterface/full.py @@ -18,8 +18,8 @@ from .sorters import * from .preprocessing import * from .postprocessing import * +from .qualitymetrics import * # TODO -#from .qualitymetrics import * # from .curation import * # from .comparison import * # from .widgets import * diff --git a/src/spikeinterface/qualitymetrics/__init__.py b/src/spikeinterface/qualitymetrics/__init__.py index 23d3894c03..ce477ce8fb 100644 --- a/src/spikeinterface/qualitymetrics/__init__.py +++ b/src/spikeinterface/qualitymetrics/__init__.py @@ -2,7 +2,7 @@ from .quality_metric_calculator import ( compute_quality_metrics, get_quality_metric_list, - QualityMetricCalculator, + ComputeQualityMetrics, get_default_qm_params, ) from .pca_metrics import get_quality_pca_metric_list diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index fe5f52af6e..41811f84d8 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -13,13 +13,16 @@ import numpy as np import warnings -from ..postprocessing import compute_spike_amplitudes, correlogram_for_one_segment -from ..core import WaveformExtractor, get_noise_levels +from ..postprocessing import correlogram_for_one_segment +from ..core import SortingResult, get_noise_levels from ..core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, + _get_dense_templates_array, ) + + try: import numba @@ -31,13 +34,13 @@ _default_params = dict() -def compute_num_spikes(waveform_extractor, unit_ids=None, **kwargs): +def compute_num_spikes(sorting_result, unit_ids=None, **kwargs): """Compute the number of spike across segments. Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object unit_ids : list or None The list of unit ids to compute the number of spikes. If None, all units are used. @@ -47,7 +50,7 @@ def compute_num_spikes(waveform_extractor, unit_ids=None, **kwargs): The number of spikes, across all segments, for each unit ID. """ - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.unit_ids num_segs = sorting.get_num_segments() @@ -63,13 +66,13 @@ def compute_num_spikes(waveform_extractor, unit_ids=None, **kwargs): return num_spikes -def compute_firing_rates(waveform_extractor, unit_ids=None, **kwargs): +def compute_firing_rates(sorting_result, unit_ids=None, **kwargs): """Compute the firing rate across segments. Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object unit_ids : list or None The list of unit ids to compute the firing rate. If None, all units are used. @@ -79,25 +82,25 @@ def compute_firing_rates(waveform_extractor, unit_ids=None, **kwargs): The firing rate, across all segments, for each unit ID. """ - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.unit_ids - total_duration = waveform_extractor.get_total_duration() + total_duration = sorting_result.get_total_duration() firing_rates = {} - num_spikes = compute_num_spikes(waveform_extractor) + num_spikes = compute_num_spikes(sorting_result) for unit_id in unit_ids: firing_rates[unit_id] = num_spikes[unit_id] / total_duration return firing_rates -def compute_presence_ratios(waveform_extractor, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None, **kwargs): +def compute_presence_ratios(sorting_result, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None, **kwargs): """Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object bin_duration_s : float, default: 60 The duration of each bin in seconds. If the duration is less than this value, presence_ratio is set to NaN @@ -117,15 +120,15 @@ def compute_presence_ratios(waveform_extractor, bin_duration_s=60.0, mean_fr_rat The total duration, across all segments, is divided into "num_bins". To do so, spike trains across segments are concatenated to mimic a continuous segment. """ - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: - unit_ids = sorting.unit_ids - num_segs = sorting.get_num_segments() + unit_ids = sorting_result.unit_ids + num_segs = sorting_result.get_num_segments() - seg_lengths = [waveform_extractor.get_num_samples(i) for i in range(num_segs)] - total_length = waveform_extractor.get_total_samples() - total_duration = waveform_extractor.get_total_duration() - bin_duration_samples = int((bin_duration_s * waveform_extractor.sampling_frequency)) + seg_lengths = [sorting_result.get_num_samples(i) for i in range(num_segs)] + total_length = sorting_result.get_total_samples() + total_duration = sorting_result.get_total_duration() + bin_duration_samples = int((bin_duration_s * sorting_result.sampling_frequency)) num_bin_edges = total_length // bin_duration_samples + 1 bin_edges = np.arange(num_bin_edges) * bin_duration_samples @@ -172,27 +175,23 @@ def compute_presence_ratios(waveform_extractor, bin_duration_s=60.0, mean_fr_rat def compute_snrs( - waveform_extractor, + sorting_result, peak_sign: str = "neg", peak_mode: str = "extremum", - random_chunk_kwargs_dict=None, unit_ids=None, ): """Compute signal to noise ratio. Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the template to compute best channels. peak_mode: "extremum" | "at_index", default: "extremum" How to compute the amplitude. Extremum takes the maxima/minima - At_index takes the value at t=waveform_extractor.nbefore - random_chunk_kwarg_dict: dict or None - Dictionary to control the get_random_data_chunks() function. - If None, default values are used + At_index takes the value at t=sorting_result.nbefore unit_ids : list or None The list of unit ids to compute the SNR. If None, all units are used. @@ -201,25 +200,18 @@ def compute_snrs( snrs : dict Computed signal to noise ratio for each unit. """ - if waveform_extractor.has_extension("noise_levels"): - noise_levels = waveform_extractor.load_extension("noise_levels").get_data() - else: - if random_chunk_kwargs_dict is None: - random_chunk_kwargs_dict = {} - noise_levels = get_noise_levels( - waveform_extractor.recording, return_scaled=waveform_extractor.return_scaled, **random_chunk_kwargs_dict - ) + assert sorting_result.has_extension("noise_levels") + noise_levels = sorting_result.get_extension("noise_levels").get_data() assert peak_sign in ("neg", "pos", "both") assert peak_mode in ("extremum", "at_index") - sorting = waveform_extractor.sorting if unit_ids is None: - unit_ids = sorting.unit_ids - channel_ids = waveform_extractor.channel_ids + unit_ids = sorting_result.unit_ids + channel_ids = sorting_result.channel_ids - extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign, mode=peak_mode) - unit_amplitudes = get_template_extremum_amplitude(waveform_extractor, peak_sign=peak_sign, mode=peak_mode) + extremum_channels_ids = get_template_extremum_channel(sorting_result, peak_sign=peak_sign, mode=peak_mode) + unit_amplitudes = get_template_extremum_amplitude(sorting_result, peak_sign=peak_sign, mode=peak_mode) # make a dict to access by chan_id noise_levels = dict(zip(channel_ids, noise_levels)) @@ -234,10 +226,10 @@ def compute_snrs( return snrs -_default_params["snr"] = dict(peak_sign="neg", peak_mode="extremum", random_chunk_kwargs_dict=None) +_default_params["snr"] = dict(peak_sign="neg", peak_mode="extremum") -def compute_isi_violations(waveform_extractor, isi_threshold_ms=1.5, min_isi_ms=0, unit_ids=None): +def compute_isi_violations(sorting_result, isi_threshold_ms=1.5, min_isi_ms=0, unit_ids=None): """Calculate Inter-Spike Interval (ISI) violations. It computes several metrics related to isi violations: @@ -247,8 +239,8 @@ def compute_isi_violations(waveform_extractor, isi_threshold_ms=1.5, min_isi_ms= Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object + sorting_result : SortingResult + The SortingResult object isi_threshold_ms : float, default: 1.5 Threshold for classifying adjacent spikes as an ISI violation, in ms. This is the biophysical refractory period @@ -281,13 +273,13 @@ def compute_isi_violations(waveform_extractor, isi_threshold_ms=1.5, min_isi_ms= """ res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: - unit_ids = sorting.unit_ids - num_segs = sorting.get_num_segments() + unit_ids = sorting_result.unit_ids + num_segs = sorting_result.get_num_segments() - total_duration_s = waveform_extractor.get_total_duration() - fs = waveform_extractor.sampling_frequency + total_duration_s = sorting_result.get_total_duration() + fs = sorting_result.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 min_isi_s = min_isi_ms / 1000 @@ -319,7 +311,7 @@ def compute_isi_violations(waveform_extractor, isi_threshold_ms=1.5, min_isi_ms= def compute_refrac_period_violations( - waveform_extractor, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, unit_ids=None + sorting_result, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, unit_ids=None ): """Calculates the number of refractory period violations. @@ -330,8 +322,8 @@ def compute_refrac_period_violations( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object + sorting_result : SortingResult + The SortingResult object refractory_period_ms : float, default: 1.0 The period (in ms) where no 2 good spikes can occur. censored_period_ms : float, default: 0.0 @@ -363,17 +355,17 @@ def compute_refrac_period_violations( print("compute_refrac_period_violations cannot run without numba.") return None - sorting = waveform_extractor.sorting - fs = sorting.get_sampling_frequency() - num_units = len(sorting.unit_ids) - num_segments = sorting.get_num_segments() + sorting = sorting_result.sorting + fs = sorting_result.sampling_frequency + num_units = len(sorting_result.unit_ids) + num_segments = sorting_result.get_num_segments() spikes = sorting.to_spike_vector(concatenated=False) if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_result.unit_ids - num_spikes = compute_num_spikes(waveform_extractor) + num_spikes = compute_num_spikes(sorting_result) t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -384,7 +376,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - T = waveform_extractor.get_total_samples() + T = sorting_result.get_total_samples() nb_violations = {} rp_contamination = {} @@ -408,7 +400,7 @@ def compute_refrac_period_violations( def compute_sliding_rp_violations( - waveform_extractor, + sorting_result, min_spikes=0, bin_size_ms=0.25, window_size_s=1, @@ -423,8 +415,8 @@ def compute_sliding_rp_violations( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object min_spikes : int, default: 0 Contamination is set to np.nan if the unit has less than this many spikes across all segments. @@ -452,12 +444,12 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - duration = waveform_extractor.get_total_duration() - sorting = waveform_extractor.sorting + duration = sorting_result.get_total_duration() + sorting = sorting_result.sorting if unit_ids is None: - unit_ids = sorting.unit_ids - num_segs = sorting.get_num_segments() - fs = waveform_extractor.sampling_frequency + unit_ids = sorting_result.unit_ids + num_segs = sorting_result.get_num_segments() + fs = sorting_result.sampling_frequency contamination = {} @@ -502,14 +494,14 @@ def compute_sliding_rp_violations( ) -def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): +def compute_synchrony_metrics(sorting_result, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of "synchrony_size" spikes at the exact same sample index. Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object synchrony_sizes : list or tuple, default: (2, 4, 8) The synchrony sizes to compute. unit_ids : list or None, default: None @@ -527,17 +519,17 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), uni This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" - spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit(outputs="dict") - sorting = waveform_extractor.sorting + spike_counts = sorting_result.sorting.count_num_spikes_per_unit(outputs="dict") + sorting = sorting_result.sorting spikes = sorting.to_spike_vector(concatenated=False) if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_result.unit_ids # Pre-allocate synchrony counts synchrony_counts = {} for synchrony_size in synchrony_sizes: - synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) + synchrony_counts[synchrony_size] = np.zeros(len(sorting_result.unit_ids), dtype=np.int64) all_unit_ids = list(sorting.unit_ids) for segment_index in range(sorting.get_num_segments()): @@ -575,14 +567,14 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), uni _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) -def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): +def compute_firing_ranges(sorting_result, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object bin_size_s : float, default: 5 The size of the bin in seconds. percentiles : tuple, default: (5, 95) @@ -599,16 +591,16 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), ----- Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino. """ - sampling_frequency = waveform_extractor.sampling_frequency + sampling_frequency = sorting_result.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.unit_ids if all( [ - waveform_extractor.get_num_samples(segment_index) < bin_size_samples - for segment_index in range(waveform_extractor.get_num_segments()) + sorting_result.get_num_samples(segment_index) < bin_size_samples + for segment_index in range(sorting_result.get_num_segments()) ] ): warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. Firing ranges are set to NaN.") @@ -616,8 +608,8 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - for segment_index in range(waveform_extractor.get_num_segments()): - num_samples = waveform_extractor.get_num_samples(segment_index) + for segment_index in range(sorting_result.get_num_segments()): + num_samples = sorting_result.get_num_samples(segment_index) edges = np.arange(0, num_samples + 1, bin_size_samples) for unit_id in unit_ids: @@ -640,7 +632,7 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), def compute_amplitude_cv_metrics( - waveform_extractor, + sorting_result, average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, @@ -653,8 +645,8 @@ def compute_amplitude_cv_metrics( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object average_num_spikes_per_bin : int, default: 50 The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is @@ -683,26 +675,23 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" - sorting = waveform_extractor.sorting - total_duration = waveform_extractor.get_total_duration() + sorting = sorting_result.sorting + total_duration = sorting_result.get_total_duration() spikes = sorting.to_spike_vector() num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.has_extension(amplitude_extension): - sac = waveform_extractor.load_extension(amplitude_extension) - amps = sac.get_data(outputs="concatenated") - if amplitude_extension == "spike_amplitudes": - amps = np.concatenate(amps) + if sorting_result.has_extension(amplitude_extension): + amps = sorting_result.get_extension(amplitude_extension).get_data() else: - warnings.warn("") + warnings.warn("compute_amplitude_cv_metrics() need 'spike_amplitudes' or 'amplitude_scalings'") empty_dict = {unit_id: np.nan for unit_id in unit_ids} return empty_dict # precompute segment slice segment_slices = [] - for segment_index in range(waveform_extractor.get_num_segments()): + for segment_index in range(sorting_result.get_num_segments()): i0 = np.searchsorted(spikes["segment_index"], segment_index) i1 = np.searchsorted(spikes["segment_index"], segment_index + 1) segment_slices.append(slice(i0, i1)) @@ -712,14 +701,14 @@ def compute_amplitude_cv_metrics( for unit_id in unit_ids: firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( - (average_num_spikes_per_bin / firing_rate) * waveform_extractor.sampling_frequency + (average_num_spikes_per_bin / firing_rate) * sorting_result.sampling_frequency ) amp_spreads = [] # bins and amplitude means are computed for each segment - for segment_index in range(waveform_extractor.get_num_segments()): + for segment_index in range(sorting_result.get_num_segments()): sample_bin_edges = np.arange( - 0, waveform_extractor.get_num_samples(segment_index) + 1, temporal_bin_size_samples + 0, sorting_result.get_num_samples(segment_index) + 1, temporal_bin_size_samples ) spikes_in_segment = spikes[segment_slices[segment_index]] amps_in_segment = amps[segment_slices[segment_index]] @@ -749,8 +738,35 @@ def compute_amplitude_cv_metrics( ) +def _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign): + # used by compute_amplitude_cutoffs and compute_amplitude_medians + amplitudes_by_units = {} + if sorting_result.has_extension("spike_amplitudes"): + spikes = sorting_result.sorting.to_spike_vector() + ext = sorting_result.get_extension("spike_amplitudes") + all_amplitudes = ext.get_data() + for unit_id in unit_ids: + unit_index = sorting_result.sorting.id_to_index(unit_id) + spike_mask = spikes["unit_index"] ==unit_index + amplitudes_by_units[unit_id] = all_amplitudes[spike_mask] + + elif sorting_result.has_extension("waveforms"): + waveforms_ext = sorting_result.get_extension("waveforms") + before = waveforms_ext.nbefore + extremum_channels_ids = get_template_extremum_channel(sorting_result, peak_sign=peak_sign) + for unit_id in unit_ids: + waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) + chan_id = extremum_channels_ids[unit_id] + if sorting_result.is_sparse(): + chan_ind = np.where(sorting_result.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] + else: + chan_ind = sorting_result.channel_ids_to_indices([chan_id])[0] + amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] + + return amplitudes_by_units + def compute_amplitude_cutoffs( - waveform_extractor, + sorting_result, peak_sign="neg", num_histogram_bins=500, histogram_smoothing_value=3, @@ -761,8 +777,8 @@ def compute_amplitude_cutoffs( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the peaks. num_histogram_bins : int, default: 100 @@ -786,7 +802,7 @@ def compute_amplitude_cutoffs( ----- This approach assumes the amplitude histogram is symmetric (not valid in the presence of drift). If available, amplitudes are extracted from the "spike_amplitude" extension (recommended). - If the "spike_amplitude" extension is not available, the amplitudes are extracted from the waveform extractor, + If the "spike_amplitude" extension is not available, the amplitudes are extracted from the SortingResult, which usually has waveforms for a small subset of spikes (500 by default). References @@ -797,52 +813,37 @@ def compute_amplitude_cutoffs( https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/quality_metrics """ - sorting = waveform_extractor.sorting if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_result.unit_ids - before = waveform_extractor.nbefore - extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign) + + all_fraction_missing = {} + if sorting_result.has_extension("spike_amplitudes") or sorting_result.has_extension("waveforms"): - spike_amplitudes = None - invert_amplitudes = False - if waveform_extractor.has_extension("spike_amplitudes"): - amp_calculator = waveform_extractor.load_extension("spike_amplitudes") - spike_amplitudes = amp_calculator.get_data(outputs="by_unit") - if amp_calculator._params["peak_sign"] == "pos": - invert_amplitudes = True - else: - if peak_sign == "pos": + invert_amplitudes = False + if sorting_result.has_extension("spike_amplitudes") and sorting_result.get_extension("spike_amplitudes").params["peak_sign"] == "pos": invert_amplitudes = True + elif sorting_result.has_extension("waveforms") and peak_sign == "pos": + invert_amplitudes = True - all_fraction_missing = {} - nan_units = [] - for unit_id in unit_ids: - if spike_amplitudes is None: - waveforms = waveform_extractor.get_waveforms(unit_id) - chan_id = extremum_channels_ids[unit_id] - if waveform_extractor.is_sparse(): - chan_ind = np.where(waveform_extractor.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = waveform_extractor.channel_ids_to_indices([chan_id])[0] - amplitudes = waveforms[:, before, chan_ind] - else: - amplitudes = np.concatenate([spike_amps[unit_id] for spike_amps in spike_amplitudes]) + amplitudes_by_units = _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign) - # change amplitudes signs in case peak_sign is pos - if invert_amplitudes: - amplitudes = -amplitudes + for unit_id in unit_ids: + amplitudes = amplitudes_by_units[unit_id] + if invert_amplitudes: + amplitudes = -amplitudes - fraction_missing = amplitude_cutoff( - amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio - ) - if np.isnan(fraction_missing): - nan_units.append(unit_id) + all_fraction_missing[unit_id] = amplitude_cutoff( + amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio + ) - all_fraction_missing[unit_id] = fraction_missing + if np.any(np.isnan(list(all_fraction_missing.values()))): + warnings.warn(f"Some units have too few spikes : amplitude_cutoff is set to NaN") - if len(nan_units) > 0: - warnings.warn(f"Units {nan_units} have too few spikes and " "amplitude_cutoff is set to NaN") + else: + warnings.warn("compute_amplitude_cutoffs need 'spike_amplitudes' or 'waveforms' extension") + for unit_id in unit_ids: + all_fraction_missing[unit_id] = np.nan return all_fraction_missing @@ -852,13 +853,13 @@ def compute_amplitude_cutoffs( ) -def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None): +def compute_amplitude_medians(sorting_result, peak_sign="neg", unit_ids=None): """Compute median of the amplitude distributions (in absolute value). Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the peaks. unit_ids : list or None @@ -875,44 +876,27 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None This code is ported from: https://github.com/int-brain-lab/ibllib/blob/master/brainbox/metrics/single_units.py """ - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: - unit_ids = sorting.unit_ids - - before = waveform_extractor.nbefore - - extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign) - - spike_amplitudes = None - if waveform_extractor.has_extension("spike_amplitudes"): - amp_calculator = waveform_extractor.load_extension("spike_amplitudes") - spike_amplitudes = amp_calculator.get_data(outputs="by_unit") + unit_ids = sorting_result.unit_ids all_amplitude_medians = {} - for unit_id in unit_ids: - if spike_amplitudes is None: - waveforms = waveform_extractor.get_waveforms(unit_id) - chan_id = extremum_channels_ids[unit_id] - if waveform_extractor.is_sparse(): - chan_ind = np.where(waveform_extractor.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = waveform_extractor.channel_ids_to_indices([chan_id])[0] - amplitudes = waveforms[:, before, chan_ind] - else: - amplitudes = np.concatenate([spike_amps[unit_id] for spike_amps in spike_amplitudes]) - - # change amplitudes signs in case peak_sign is pos - abs_amplitudes = np.abs(amplitudes) - all_amplitude_medians[unit_id] = np.median(abs_amplitudes) + if sorting_result.has_extension("spike_amplitudes") or sorting_result.has_extension("waveforms"): + amplitudes_by_units = _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign) + for unit_id in unit_ids: + all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) + else: + warnings.warn("compute_amplitude_medians need 'spike_amplitudes' or 'waveforms' extension") + for unit_id in unit_ids: + all_amplitude_medians[unit_id] = np.nan return all_amplitude_medians - _default_params["amplitude_median"] = dict(peak_sign="neg") def compute_drift_metrics( - waveform_extractor, + sorting_result, interval_s=60, min_spikes_per_interval=100, direction="y", @@ -936,8 +920,8 @@ def compute_drift_metrics( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object interval_s : int, default: 60 Interval length is seconds for computing spike depth min_spikes_per_interval : int, default: 100 @@ -973,14 +957,23 @@ def compute_drift_metrics( there are large displacements in between segments, the resulting metric values will be very high. """ res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.has_extension("spike_locations"): - locs_calculator = waveform_extractor.load_extension("spike_locations") - spike_locations = locs_calculator.get_data(outputs="concatenated") - spike_locations_by_unit = locs_calculator.get_data(outputs="by_unit") + if sorting_result.has_extension("spike_locations"): + spike_locations_ext = sorting_result.get_extension("spike_locations") + spike_locations = spike_locations_ext.get_data() + # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") + spikes = sorting.to_spike_vector() + spike_locations_by_unit = {} + for unit_id in unit_ids: + unit_index = sorting.id_to_index(unit_id) + spike_mask = spikes["unit_index"] ==unit_index + spike_locations_by_unit[unit_id] = spike_locations[spike_mask] + + + else: warnings.warn( "The drift metrics require the `spike_locations` waveform extension. " @@ -993,11 +986,11 @@ def compute_drift_metrics( else: return res(empty_dict, empty_dict, empty_dict) - interval_samples = int(interval_s * waveform_extractor.sampling_frequency) + interval_samples = int(interval_s * sorting_result.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - total_duration = waveform_extractor.get_total_duration() + total_duration = sorting_result.get_total_duration() if total_duration < min_num_bins * interval_s: warnings.warn( "The recording is too short given the specified 'interval_s' and " @@ -1017,15 +1010,12 @@ def compute_drift_metrics( # reference positions are the medians across segments reference_positions = np.zeros(len(unit_ids)) for unit_ind, unit_id in enumerate(unit_ids): - locs = [] - for segment_index in range(waveform_extractor.get_num_segments()): - locs.append(spike_locations_by_unit[segment_index][unit_id][direction]) - reference_positions[unit_ind] = np.median(np.concatenate(locs)) + reference_positions[unit_ind] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments median_position_segments = None - for segment_index in range(waveform_extractor.get_num_segments()): - seg_length = waveform_extractor.get_num_samples(segment_index) + for segment_index in range(sorting_result.get_num_segments()): + seg_length = sorting_result.get_num_samples(segment_index) num_bin_edges = seg_length // interval_samples + 1 bins = np.arange(num_bin_edges) * interval_samples spike_vector = sorting.to_spike_vector() @@ -1371,7 +1361,7 @@ def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters, def compute_sd_ratio( - wvf_extractor: WaveformExtractor, + sorting_result: SortingResult, censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, @@ -1385,8 +1375,8 @@ def compute_sd_ratio( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. + sorting_result: SortingResult + A SortingResult object censored_period_ms : float, default: 4.0 The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. correct_for_drift: bool, default: True @@ -1408,20 +1398,23 @@ def compute_sd_ratio( import numba from ..curation.curation_tools import _find_duplicated_spikes_keep_first_iterative - censored_period = int(round(censored_period_ms * 1e-3 * wvf_extractor.sampling_frequency)) + sorting = sorting_result.sorting + + censored_period = int(round(censored_period_ms * 1e-3 * sorting_result.sampling_frequency)) if unit_ids is None: - unit_ids = wvf_extractor.unit_ids + unit_ids = sorting_result.unit_ids - if not wvf_extractor.has_recording(): + if not sorting_result.has_recording(): warnings.warn( - "The `sd_ratio` metric cannot work with a recordless WaveformExtractor object" + "The `sd_ratio` metric cannot work with a recordless SortingResult object" "SD ratio metric will be set to NaN" ) return {unit_id: np.nan for unit_id in unit_ids} - if wvf_extractor.has_extension("spike_amplitudes"): - amplitudes_ext = wvf_extractor.load_extension("spike_amplitudes") - spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit") + if sorting_result.has_extension("spike_amplitudes"): + amplitudes_ext = sorting_result.get_extension("spike_amplitudes") + # spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit") + spike_amplitudes = amplitudes_ext.get_data() else: warnings.warn( "The `sd_ratio` metric require the `spike_amplitudes` waveform extension. " @@ -1431,24 +1424,36 @@ def compute_sd_ratio( return {unit_id: np.nan for unit_id in unit_ids} noise_levels = get_noise_levels( - wvf_extractor.recording, return_scaled=amplitudes_ext._params["return_scaled"], method="std" + sorting_result.recording, return_scaled=amplitudes_ext.params["return_scaled"], method="std" ) - best_channels = get_template_extremum_channel(wvf_extractor, outputs="index", **kwargs) - n_spikes = wvf_extractor.sorting.count_num_spikes_per_unit() + best_channels = get_template_extremum_channel(sorting_result, outputs="index", **kwargs) + n_spikes = sorting.count_num_spikes_per_unit() + + if correct_for_template_itself: + tamplates_array = _get_dense_templates_array(sorting_result, return_scaled=True) + spikes = sorting.to_spike_vector() sd_ratio = {} for unit_id in unit_ids: + unit_index = sorting_result.sorting.id_to_index(unit_id) + spk_amp = [] - for segment_index in range(wvf_extractor.get_num_segments()): - spike_train = wvf_extractor.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype( - np.int64, copy=False - ) + for segment_index in range(sorting_result.get_num_segments()): + # spike_train = sorting_result.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype( + # np.int64, copy=False + # ) + spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index) + spike_train = spikes[spike_mask]["sample_index"].astype(np.int64, copy=False) + amplitudes = spike_amplitudes[spike_mask] + censored_indices = _find_duplicated_spikes_keep_first_iterative( spike_train, censored_period, ) - spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices)) + # spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices)) + spk_amp.append(np.delete(amplitudes, censored_indices)) + spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))]) if correct_for_drift: @@ -1460,11 +1465,15 @@ def compute_sd_ratio( std_noise = noise_levels[best_channel] if correct_for_template_itself: - template = wvf_extractor.get_template(unit_id, force_dense=True)[:, best_channel] + # template = sorting_result.get_template(unit_id, force_dense=True)[:, best_channel] + + template = tamplates_array[unit_index, :, :][:, best_channel] + nsamples = template.shape[0] + # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template. # TODO: Take into account that templates for different segments might differ. - p = wvf_extractor.nsamples * n_spikes[unit_id] / wvf_extractor.get_total_samples() + p = nsamples * n_spikes[unit_id] / sorting_result.get_total_samples() total_variance = p * np.mean(template**2) - p**2 * np.mean(template) std_noise = np.sqrt(std_noise**2 - total_variance) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index a8e6d90e6a..91d3d47f2e 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -15,17 +15,12 @@ except: pass -from ..core import get_random_data_chunks, compute_sparsity, WaveformExtractor -from ..core.template_tools import get_template_extremum_channel - -from ..postprocessing import WaveformPrincipalComponent import warnings from .misc_metrics import compute_num_spikes, compute_firing_rates -from ..core import get_random_data_chunks, load_waveforms, compute_sparsity, WaveformExtractor +from ..core import get_random_data_chunks, compute_sparsity from ..core.template_tools import get_template_extremum_channel -from ..postprocessing import WaveformPrincipalComponent _possible_pc_metric_names = [ @@ -60,21 +55,17 @@ def get_quality_pca_metric_list(): def calculate_pc_metrics( - pca, metric_names=None, sparsity=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False + sorting_result, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False ): """Calculate principal component derived metrics. Parameters ---------- - pca : WaveformPrincipalComponent - Waveform object with principal components computed. + sorting_result: SortingResult + A SortingResult object metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. - sparsity: ChannelSparsity or None, default: None - The sparsity object. This is used also to identify neighbor - units and speed up computations. If None all channels and all units are used - for each unit. qm_params : dict or None Dictionary with parameters for each PC metric function. unit_ids : list of int or None @@ -91,18 +82,22 @@ def calculate_pc_metrics( pc_metrics : dict The computed PC metrics. """ + pca_ext = sorting_result.get_extension("principal_components") + assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" + + + sorting = sorting_result.sorting + if metric_names is None: metric_names = _possible_pc_metric_names if qm_params is None: qm_params = _default_params - assert isinstance(pca, WaveformPrincipalComponent) - we = pca.waveform_extractor - extremum_channels = get_template_extremum_channel(we) + extremum_channels = get_template_extremum_channel(sorting_result) if unit_ids is None: - unit_ids = we.unit_ids - channel_ids = we.channel_ids + unit_ids = sorting_result.unit_ids + channel_ids = sorting_result.channel_ids # create output dict of dict pc_metrics['metric_name'][unit_id] pc_metrics = {k: {} for k in metric_names} @@ -116,8 +111,8 @@ def calculate_pc_metrics( # Compute nspikes and firing rate outside of main loop for speed if any([n in metric_names for n in ["nn_isolation", "nn_noise_overlap"]]): - n_spikes_all_units = compute_num_spikes(we, unit_ids=unit_ids) - fr_all_units = compute_firing_rates(we, unit_ids=unit_ids) + n_spikes_all_units = compute_num_spikes(sorting_result, unit_ids=unit_ids) + fr_all_units = compute_firing_rates(sorting_result, unit_ids=unit_ids) else: n_spikes_all_units = None fr_all_units = None @@ -131,24 +126,31 @@ def calculate_pc_metrics( if run_in_parallel: parallel_functions = [] - all_labels, all_pcs = pca.get_all_projections() + + # all_labels, all_pcs = pca.get_all_projections() + # TODO: this is wring all_pcs used to be dense even when the waveform extractor was sparse + all_pcs = pca_ext.data["pca_projection"] + spikes = sorting.to_spike_vector() + some_spikes = spikes[sorting_result.random_spikes_indices] + all_labels = sorting.unit_ids[some_spikes["unit_index"]] items = [] for unit_id in unit_ids: - if we.is_sparse(): - neighbor_channel_ids = we.sparsity.unit_id_to_channel_ids[unit_id] - neighbor_unit_ids = [ - other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids - ] - elif sparsity is not None: - neighbor_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] + print(sorting_result.is_sparse()) + if sorting_result.is_sparse(): + neighbor_channel_ids = sorting_result.sparsity.unit_id_to_channel_ids[unit_id] neighbor_unit_ids = [ other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids ] + # elif sparsity is not None: + # neighbor_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] + # neighbor_unit_ids = [ + # other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids + # ] else: neighbor_channel_ids = channel_ids neighbor_unit_ids = unit_ids - neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) + neighbor_channel_indices = sorting_result.channel_ids_to_indices(neighbor_channel_ids) labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] @@ -162,7 +164,7 @@ def calculate_pc_metrics( unit_ids, qm_params, seed, - we.folder, + # we.folder, n_spikes_all_units, fr_all_units, ) @@ -174,15 +176,16 @@ def calculate_pc_metrics( for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric else: - with ProcessPoolExecutor(n_jobs) as executor: - results = executor.map(pca_metrics_one_unit, items) - if progress_bar: - results = tqdm(results, total=len(unit_ids)) + raise NotImplementedError + # with ProcessPoolExecutor(n_jobs) as executor: + # results = executor.map(pca_metrics_one_unit, items) + # if progress_bar: + # results = tqdm(results, total=len(unit_ids)) - for ui, pca_metrics_unit in enumerate(results): - unit_id = unit_ids[ui] - for metric_name, metric in pca_metrics_unit.items(): - pc_metrics[metric_name][unit_id] = metric + # for ui, pca_metrics_unit in enumerate(results): + # unit_id = unit_ids[ui] + # for metric_name, metric in pca_metrics_unit.items(): + # pc_metrics[metric_name][unit_id] = metric return pc_metrics @@ -358,8 +361,8 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_n def nearest_neighbors_isolation( - waveform_extractor: WaveformExtractor, - this_unit_id: int, + sorting_result, + this_unit_id: int | str, n_spikes_all_units: dict = None, fr_all_units: dict = None, max_spikes: int = 1000, @@ -376,9 +379,9 @@ def nearest_neighbors_isolation( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. - this_unit_id : int + sorting_result: SortingResult + A SortingResult object + this_unit_id : int | str The ID for the unit to calculate these metrics for. n_spikes_all_units: dict, default: None Dictionary of the form ``{: }`` for the waveform extractor. @@ -403,10 +406,10 @@ def nearest_neighbors_isolation( radius_um : float, default: 100 The radius, in um, that channels need to be within the peak channel to be included. peak_sign: "neg" | "pos" | "both", default: "neg" - The peak_sign used to compute sparsity and neighbor units. Used if waveform_extractor + The peak_sign used to compute sparsity and neighbor units. Used if sorting_result is not sparse already. min_spatial_overlap : float, default: 100 - In case waveform_extractor is sparse, other units are selected if they share at least + In case sorting_result is sparse, other units are selected if they share at least `min_spatial_overlap` times `n_target_unit_channels` with the target unit seed : int, default: None Seed for random subsampling of spikes. @@ -451,12 +454,15 @@ def nearest_neighbors_isolation( """ rng = np.random.default_rng(seed=seed) - sorting = waveform_extractor.sorting + waveforms_ext = sorting_result.get_extension("waveforms") + assert waveforms_ext is not None, "nearest_neighbors_isolation() need extension 'waveforms'" + + sorting = sorting_result.sorting all_units_ids = sorting.get_unit_ids() if n_spikes_all_units is None: - n_spikes_all_units = compute_num_spikes(waveform_extractor) + n_spikes_all_units = compute_num_spikes(sorting_result) if fr_all_units is None: - fr_all_units = compute_firing_rates(waveform_extractor) + fr_all_units = compute_firing_rates(sorting_result) # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN if n_spikes_all_units[this_unit_id] < min_spikes: @@ -486,15 +492,17 @@ def nearest_neighbors_isolation( other_units_ids = np.setdiff1d(all_units_ids, this_unit_id) # get waveforms of target unit - waveforms_target_unit = waveform_extractor.get_waveforms(unit_id=this_unit_id) + # waveforms_target_unit = sorting_result.get_waveforms(unit_id=this_unit_id) + waveforms_target_unit = waveforms_ext.get_waveforms_one_unit(unit_id=this_unit_id, force_dense=False) + n_spikes_target_unit = waveforms_target_unit.shape[0] # find units whose signal channels (i.e. channels inside some radius around # the channel with largest amplitude) overlap with signal channels of the target unit - if waveform_extractor.is_sparse(): - sparsity = waveform_extractor.sparsity + if sorting_result.is_sparse(): + sparsity = sorting_result.sparsity else: - sparsity = compute_sparsity(waveform_extractor, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_result, method="radius", peak_sign=peak_sign, radius_um=radius_um) closest_chans_target_unit = sparsity.unit_id_to_channel_indices[this_unit_id] n_channels_target_unit = len(closest_chans_target_unit) # select other units that have a minimum spatial overlap with target unit @@ -515,7 +523,9 @@ def nearest_neighbors_isolation( len(other_units_ids), ) for other_unit_id in other_units_ids: - waveforms_other_unit = waveform_extractor.get_waveforms(unit_id=other_unit_id) + # waveforms_other_unit = sorting_result.get_waveforms(unit_id=other_unit_id) + waveforms_other_unit = waveforms_ext.get_waveforms_one_unit(unit_id=other_unit_id, force_dense=False) + n_spikes_other_unit = waveforms_other_unit.shape[0] closest_chans_other_unit = sparsity.unit_id_to_channel_indices[other_unit_id] n_snippets = np.min([n_spikes_target_unit, n_spikes_other_unit, max_spikes]) @@ -528,7 +538,7 @@ def nearest_neighbors_isolation( # project this unit and other unit waveforms on common subspace common_channel_idxs = np.intersect1d(closest_chans_target_unit, closest_chans_other_unit) - if waveform_extractor.is_sparse(): + if sorting_result.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ :, :, np.isin(closest_chans_target_unit, common_channel_idxs) @@ -565,8 +575,8 @@ def nearest_neighbors_isolation( def nearest_neighbors_noise_overlap( - waveform_extractor: WaveformExtractor, - this_unit_id: int, + sorting_result, + this_unit_id: int | str, n_spikes_all_units: dict = None, fr_all_units: dict = None, max_spikes: int = 1000, @@ -582,9 +592,9 @@ def nearest_neighbors_noise_overlap( Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object. - this_unit_id : int + sorting_result: SortingResult + A SortingResult object + this_unit_id : int | str The ID of the unit to calculate this metric on. n_spikes_all_units: dict, default: None Dictionary of the form ``{: }`` for the waveform extractor. @@ -607,7 +617,7 @@ def nearest_neighbors_noise_overlap( radius_um : float, default: 100 The radius, in um, that channels need to be within the peak channel to be included. peak_sign: "neg" | "pos" | "both", default: "neg" - The peak_sign used to compute sparsity and neighbor units. Used if waveform_extractor + The peak_sign used to compute sparsity and neighbor units. Used if sorting_result is not sparse already. seed : int, default: 0 Random seed for subsampling spikes. @@ -638,10 +648,17 @@ def nearest_neighbors_noise_overlap( """ rng = np.random.default_rng(seed=seed) + waveforms_ext = sorting_result.get_extension("waveforms") + assert waveforms_ext is not None, "nearest_neighbors_isolation() need extension 'waveforms'" + + templates_ext = sorting_result.get_extension("templates") + assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'" + + if n_spikes_all_units is None: - n_spikes_all_units = compute_num_spikes(waveform_extractor) + n_spikes_all_units = compute_num_spikes(sorting_result) if fr_all_units is None: - fr_all_units = compute_firing_rates(waveform_extractor) + fr_all_units = compute_firing_rates(sorting_result) # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN if n_spikes_all_units[this_unit_id] < min_spikes: @@ -658,18 +675,20 @@ def nearest_neighbors_noise_overlap( return np.nan else: # get random snippets from the recording to create a noise cluster - recording = waveform_extractor.recording + nsamples = waveforms_ext.nbefore + waveforms_ext.nafter + recording = sorting_result.recording noise_cluster = get_random_data_chunks( recording, - return_scaled=waveform_extractor.return_scaled, + return_scaled=waveforms_ext.params["return_scaled"], num_chunks_per_segment=max_spikes, - chunk_size=waveform_extractor.nsamples, + chunk_size=nsamples, seed=seed, ) - noise_cluster = np.reshape(noise_cluster, (max_spikes, waveform_extractor.nsamples, -1)) + noise_cluster = np.reshape(noise_cluster, (max_spikes, nsamples, -1)) # get waveforms for target cluster - waveforms = waveform_extractor.get_waveforms(unit_id=this_unit_id).copy() + # waveforms = sorting_result.get_waveforms(unit_id=this_unit_id).copy() + waveforms = waveforms_ext.get_waveforms_one_unit(unit_id=this_unit_id, force_dense=False).copy() # adjust the size of the target and noise clusters to be equal if waveforms.shape[0] > max_spikes: @@ -684,17 +703,21 @@ def nearest_neighbors_noise_overlap( n_snippets = max_spikes # restrict to channels with significant signal - if waveform_extractor.is_sparse(): - sparsity = waveform_extractor.sparsity + if sorting_result.is_sparse(): + sparsity = sorting_result.sparsity else: - sparsity = compute_sparsity(waveform_extractor, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_result, method="radius", peak_sign=peak_sign, radius_um=radius_um) noise_cluster = noise_cluster[:, :, sparsity.unit_id_to_channel_indices[this_unit_id]] # compute weighted noise snippet (Z) - median_waveform = waveform_extractor.get_template(unit_id=this_unit_id, mode="median") - - # in case waveform_extractor is sparse, waveforms and templates are already sparse - if not waveform_extractor.is_sparse(): + # median_waveform = sorting_result.get_template(unit_id=this_unit_id, mode="median") + all_templates = templates_ext.get_data(operator="median") + this_unit_index = sorting_result.sorting.id_to_index(this_unit_id) + median_waveform = all_templates[this_unit_index, :, :] + + # in case sorting_result is sparse, waveforms and templates are already sparse + if not sorting_result.is_sparse(): + # @alessio : this next line is suspicious because the waveforms is already sparse no ? Am i wrong ? waveforms = waveforms[:, :, sparsity.unit_id_to_channel_indices[this_unit_id]] median_waveform = median_waveform[:, sparsity.unit_id_to_channel_indices[this_unit_id]] @@ -894,13 +917,13 @@ def pca_metrics_one_unit(args): unit_ids, qm_params, seed, - we_folder, + # we_folder, n_spikes_all_units, fr_all_units, ) = args - if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names: - we = load_waveforms(we_folder) + # if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names: + # we = load_waveforms(we_folder) pc_metrics = {} # metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 54b1027305..4026850165 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -5,20 +5,35 @@ import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension + from .quality_metric_list import calculate_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params +# TODO : -class QualityMetricCalculator(BaseWaveformExtractorExtension): - """Class to compute quality metrics of spike sorting output. +class ComputeQualityMetrics(ResultExtension): + """ + Compute quality metrics on sorting_. Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor object + sorting_result: SortingResult + A SortingResult object + metric_names : list or None + List of quality metrics to compute. + qm_params : dict or None + Dictionary with parameters for quality metrics calculation. + Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` + skip_pc_metrics : bool + If True, PC metrics computation is skipped. + + Returns + ------- + metrics: pandas.DataFrame + Data frame with the computed metrics Notes ----- @@ -26,30 +41,25 @@ class QualityMetricCalculator(BaseWaveformExtractorExtension): """ extension_name = "quality_metrics" - - def __init__(self, waveform_extractor): - BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - - if waveform_extractor.has_recording(): - self.recording = waveform_extractor.recording - else: - self.recording = None - self.sorting = waveform_extractor.sorting + depend_on = ["waveforms", "templates", "noise_levels"] + need_recording = False + use_nodepipeline = False + need_job_kwargs = True def _set_params( - self, metric_names=None, qm_params=None, peak_sign=None, seed=None, sparsity=None, skip_pc_metrics=False + self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False ): if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list - if self.waveform_extractor.has_extension("principal_components"): + if self.sorting_result.has_extension("principal_components") and not skip_pc_metrics: # by default 'nearest_neightbor' is removed because too slow pc_metrics = _possible_pc_metric_names.copy() pc_metrics.remove("nn_isolation") pc_metrics.remove("nn_noise_overlap") metric_names += pc_metrics # if spike_locations are not available, drift is removed from the list - if not self.waveform_extractor.has_extension("spike_locations"): + if not self.sorting_result.has_extension("spike_locations"): if "drift" in metric_names: metric_names.remove("drift") @@ -62,7 +72,6 @@ def _set_params( params = dict( metric_names=[str(name) for name in np.unique(metric_names)], - sparsity=sparsity, peak_sign=peak_sign, seed=seed, qm_params=qm_params_, @@ -72,26 +81,27 @@ def _set_params( return params def _select_extension_data(self, unit_ids): - # filter metrics dataframe - new_metrics = self._extension_data["metrics"].loc[np.array(unit_ids)] - return dict(metrics=new_metrics) + new_metrics = self.data["metrics"].loc[np.array(unit_ids)] + new_data = dict(metrics=new_metrics) + return new_data - def _run(self, verbose, **job_kwargs): + def _run(self, verbose=False, **job_kwargs): """ Compute quality metrics. """ - metric_names = self._params["metric_names"] - qm_params = self._params["qm_params"] - sparsity = self._params["sparsity"] - seed = self._params["seed"] + metric_names = self.params["metric_names"] + qm_params = self.params["qm_params"] + # sparsity = self.params["sparsity"] + seed = self.params["seed"] # update job_kwargs with global ones job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] - unit_ids = self.sorting.unit_ids - non_empty_unit_ids = self.sorting.get_non_empty_unit_ids() + sorting = self.sorting_result.sorting + unit_ids = sorting.unit_ids + non_empty_unit_ids = sorting.get_non_empty_unit_ids() empty_unit_ids = unit_ids[~np.isin(unit_ids, non_empty_unit_ids)] if len(empty_unit_ids) > 0: warnings.warn( @@ -115,7 +125,7 @@ def _run(self, verbose, **job_kwargs): func = _misc_metric_name_to_func[metric_name] params = qm_params[metric_name] if metric_name in qm_params else {} - res = func(self.waveform_extractor, unit_ids=non_empty_unit_ids, **params) + res = func(self.sorting_result, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: if isinstance(res, dict): @@ -129,15 +139,14 @@ def _run(self, verbose, **job_kwargs): # metrics based on PCs pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] - if len(pc_metric_names) > 0 and not self._params["skip_pc_metrics"]: - if not self.waveform_extractor.has_extension("principal_components"): + if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: + if not self.sorting_result.has_extension("principal_components"): raise ValueError("waveform_principal_component must be provied") - pc_extension = self.waveform_extractor.load_extension("principal_components") pc_metrics = calculate_pc_metrics( - pc_extension, + self.sorting_result, unit_ids=non_empty_unit_ids, metric_names=pc_metric_names, - sparsity=sparsity, + # sparsity=sparsity, progress_bar=progress_bar, n_jobs=n_jobs, qm_params=qm_params, @@ -150,89 +159,15 @@ def _run(self, verbose, **job_kwargs): if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan - self._extension_data["metrics"] = metrics - - def get_data(self): - """ - Get the computed metrics. - - Returns - ------- - metrics : pd.DataFrame - Dataframe with quality metrics - """ - msg = "Quality metrics are not computed. Use the 'run()' function." - assert self._extension_data["metrics"] is not None, msg - return self._extension_data["metrics"] - - @staticmethod - def get_extension_function(): - return compute_quality_metrics - + self.data["metrics"] = metrics -WaveformExtractor.register_extension(QualityMetricCalculator) + def _get_data(self): + return self.data["metrics"] -def compute_quality_metrics( - waveform_extractor, - load_if_exists=False, - metric_names=None, - qm_params=None, - peak_sign=None, - seed=None, - sparsity=None, - skip_pc_metrics=False, - verbose=False, - **job_kwargs, -): - """Compute quality metrics on waveform extractor. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor to compute metrics on. - load_if_exists : bool, default: False - Whether to load precomputed quality metrics, if they already exist. - metric_names : list or None - List of quality metrics to compute. - qm_params : dict or None - Dictionary with parameters for quality metrics calculation. - Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` - sparsity : dict or None, default: None - If given, the sparse channel_ids for each unit in PCA metrics computation. - This is used also to identify neighbor units and speed up computations. - If None all channels and all units are used for each unit. - skip_pc_metrics : bool - If True, PC metrics computation is skipped. - n_jobs : int - Number of jobs (used for PCA metrics) - verbose : bool - If True, output is verbose. - progress_bar : bool - If True, progress bar is shown. - - Returns - ------- - metrics: pandas.DataFrame - Data frame with the computed metrics - """ - if load_if_exists and waveform_extractor.has_extension(QualityMetricCalculator.extension_name): - qmc = waveform_extractor.load_extension(QualityMetricCalculator.extension_name) - else: - qmc = QualityMetricCalculator(waveform_extractor) - qmc.set_params( - metric_names=metric_names, - qm_params=qm_params, - peak_sign=peak_sign, - seed=seed, - sparsity=sparsity, - skip_pc_metrics=skip_pc_metrics, - ) - qmc.run(verbose=verbose, **job_kwargs) - - metrics = qmc.get_data() +register_result_extension(ComputeQualityMetrics) +compute_quality_metrics = ComputeQualityMetrics.function_factory() - return metrics def get_quality_metric_list(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 555db030e7..2e89428fa8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -2,18 +2,14 @@ import shutil from pathlib import Path import numpy as np -from spikeinterface import extract_waveforms -from spikeinterface.core import NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting -from spikeinterface.extractors.toy_example import toy_example +from spikeinterface.core import ( + NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting, generate_ground_truth_recording, start_sorting_result +) + +# from spikeinterface.extractors.toy_example import toy_example from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions from spikeinterface.qualitymetrics import calculate_pc_metrics -from spikeinterface.postprocessing import ( - compute_principal_components, - compute_spike_locations, - compute_spike_amplitudes, - compute_amplitude_scalings, -) from spikeinterface.qualitymetrics import ( mahalanobis_metrics, @@ -38,13 +34,40 @@ ) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "qualitymetrics" -else: - cache_folder = Path("cache_folder") / "qualitymetrics" +# if hasattr(pytest, "global_test_folder"): +# cache_folder = pytest.global_test_folder / "qualitymetrics" +# else: +# cache_folder = Path("cache_folder") / "qualitymetrics" + + +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + + +def _sorting_result_simple(): + recording, sorting = generate_ground_truth_recording( + durations=[50.0,], sampling_frequency=30_000.0, num_channels=6, num_units=10, + generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + + sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + + sorting_result.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_result.compute("noise_levels") + sorting_result.compute("waveforms", **job_kwargs) + sorting_result.compute("templates") + sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + sorting_result.compute("spike_amplitudes", **job_kwargs) + + return sorting_result +@pytest.fixture(scope="module") +def sorting_result_simple(): + return _sorting_result_simple() -def _simulated_data(): +def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 trains = [ @@ -56,100 +79,46 @@ def _simulated_data(): labels = [np.ones((len(trains[i]),), dtype="int") * i for i in range(len(trains))] spike_times = np.concatenate(trains) - spike_clusters = np.concatenate(labels) + spike_labels = np.concatenate(labels) order = np.argsort(spike_times) max_num_samples = np.floor(max_time * sampling_frequency) - 1 indexes = np.arange(0, max_time + 1, 1 / sampling_frequency) spike_times = np.searchsorted(indexes, spike_times[order], side="left") - spike_clusters = spike_clusters[order] + spike_labels = spike_labels[order] mask = spike_times < max_num_samples spike_times = spike_times[mask] - spike_clusters = spike_clusters[mask] - - return {"duration": max_time, "times": spike_times, "labels": spike_clusters} - - -def _waveform_extractor_simple(): - for name in ("rec1", "sort1", "waveform_folder1"): - if (cache_folder / name).exists(): - shutil.rmtree(cache_folder / name) - - recording, sorting = toy_example(duration=50, seed=10, firing_rate=6.0) - - recording = recording.save(folder=cache_folder / "rec1") - sorting = sorting.save(folder=cache_folder / "sort1") - folder = cache_folder / "waveform_folder1" - we = extract_waveforms( - recording, - sorting, - folder, - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=1000, - n_jobs=1, - chunk_size=30000, - overwrite=True, - ) - _ = compute_principal_components(we, n_components=5, mode="by_channel_local") - _ = compute_spike_amplitudes(we, return_scaled=True) - return we - - -def _waveform_extractor_violations(data): - for name in ("rec2", "sort2", "waveform_folder2"): - if (cache_folder / name).exists(): - shutil.rmtree(cache_folder / name) - - recording, sorting = toy_example( - duration=[data["duration"]], - spike_times=[data["times"]], - spike_labels=[data["labels"]], - num_segments=1, - num_units=4, - # score_detection=score_detection, - seed=10, - ) - recording = recording.save(folder=cache_folder / "rec2") - sorting = sorting.save(folder=cache_folder / "sort2") - folder = cache_folder / "waveform_folder2" - we = extract_waveforms( - recording, - sorting, - folder, - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=1000, - n_jobs=1, - chunk_size=30000, - overwrite=True, - ) - return we + spike_labels = spike_labels[mask] + unit_ids = ["a", "b", "c"] + sorting = NumpySorting.from_times_labels(spike_times, spike_labels, sampling_frequency, unit_ids=unit_ids) -@pytest.fixture(scope="module") -def simulated_data(): - return _simulated_data() + return sorting -@pytest.fixture(scope="module") -def waveform_extractor_violations(simulated_data): - return _waveform_extractor_violations(simulated_data) +def _sorting_result_violations(): + + sorting = _sorting_violation() + duration = (sorting.to_spike_vector()["sample_index"][-1] + 1) / sorting.sampling_frequency + + recording, sorting = generate_ground_truth_recording( + durations=[duration], sampling_frequency=sorting.sampling_frequency, num_channels=6, + sorting=sorting, + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + # this used only for ISI metrics so no need to compute heavy extensions + return sorting_result @pytest.fixture(scope="module") -def waveform_extractor_simple(): - return _waveform_extractor_simple() +def sorting_result_violations(): + return _sorting_result_violations() + -def test_calculate_pc_metrics(waveform_extractor_simple): - we = waveform_extractor_simple - print(we) - pca = we.load_extension("principal_components") - print(pca) - res = calculate_pc_metrics(pca) - print(res) def test_mahalanobis_metrics(): @@ -214,10 +183,10 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): - we = waveform_extractor_simple - firing_rates = compute_firing_rates(we) - num_spikes = compute_num_spikes(we) +def test_calculate_firing_rate_num_spikes(sorting_result_simple): + sorting_result = sorting_result_simple + firing_rates = compute_firing_rates(sorting_result) + num_spikes = compute_num_spikes(sorting_result) # testing method accuracy with magic number is not a good pratcice, I remove this. # firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} @@ -226,47 +195,48 @@ def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) -def test_calculate_firing_range(waveform_extractor_simple): - we = waveform_extractor_simple - firing_ranges = compute_firing_ranges(we) +def test_calculate_firing_range(sorting_result_simple): + sorting_result = sorting_result_simple + firing_ranges = compute_firing_ranges(sorting_result) print(firing_ranges) with pytest.warns(UserWarning) as w: - firing_ranges_nan = compute_firing_ranges(we, bin_size_s=we.get_total_duration() + 1) + firing_ranges_nan = compute_firing_ranges(sorting_result, bin_size_s=sorting_result.get_total_duration() + 1) assert np.all([np.isnan(f) for f in firing_ranges_nan.values()]) -def test_calculate_amplitude_cutoff(waveform_extractor_simple): - we = waveform_extractor_simple - spike_amps = we.load_extension("spike_amplitudes").get_data() - amp_cuts = compute_amplitude_cutoffs(we, num_histogram_bins=10) - print(amp_cuts) +def test_calculate_amplitude_cutoff(sorting_result_simple): + sorting_result = sorting_result_simple + # spike_amps = sorting_result.get_extension("spike_amplitudes").get_data() + amp_cuts = compute_amplitude_cutoffs(sorting_result, num_histogram_bins=10) + # print(amp_cuts) # testing method accuracy with magic number is not a good pratcice, I remove this. # amp_cuts_gt = {0: 0.33067210050787543, 1: 0.43482247296942045, 2: 0.43482247296942045} # assert np.allclose(list(amp_cuts_gt.values()), list(amp_cuts.values()), rtol=0.05) -def test_calculate_amplitude_median(waveform_extractor_simple): - we = waveform_extractor_simple - spike_amps = we.load_extension("spike_amplitudes").get_data() - amp_medians = compute_amplitude_medians(we) - print(spike_amps, amp_medians) +def test_calculate_amplitude_median(sorting_result_simple): + sorting_result = sorting_result_simple + # spike_amps = sorting_result.get_extension("spike_amplitudes").get_data() + amp_medians = compute_amplitude_medians(sorting_result) + # print(amp_medians) # testing method accuracy with magic number is not a good pratcice, I remove this. # amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) -def test_calculate_amplitude_cv_metrics(waveform_extractor_simple): - we = waveform_extractor_simple - amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20) +def test_calculate_amplitude_cv_metrics(sorting_result_simple): + sorting_result = sorting_result_simple + amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_result, average_num_spikes_per_bin=20) print(amp_cv_median) print(amp_cv_range) - amps_scalings = compute_amplitude_scalings(we) + # amps_scalings = compute_amplitude_scalings(sorting_result) + sorting_result.compute("amplitude_scalings", **job_kwargs) amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( - we, + sorting_result, average_num_spikes_per_bin=20, amplitude_extension="amplitude_scalings", min_num_bins=5, @@ -275,9 +245,9 @@ def test_calculate_amplitude_cv_metrics(waveform_extractor_simple): print(amp_cv_range_scalings) -def test_calculate_snrs(waveform_extractor_simple): - we = waveform_extractor_simple - snrs = compute_snrs(we) +def test_calculate_snrs(sorting_result_simple): + sorting_result = sorting_result_simple + snrs = compute_snrs(sorting_result) print(snrs) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -285,9 +255,9 @@ def test_calculate_snrs(waveform_extractor_simple): # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05) -def test_calculate_presence_ratio(waveform_extractor_simple): - we = waveform_extractor_simple - ratios = compute_presence_ratios(we, bin_duration_s=10) +def test_calculate_presence_ratio(sorting_result_simple): + sorting_result = sorting_result_simple + ratios = compute_presence_ratios(sorting_result, bin_duration_s=10) print(ratios) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -295,9 +265,9 @@ def test_calculate_presence_ratio(waveform_extractor_simple): # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) -def test_calculate_isi_violations(waveform_extractor_violations): - we = waveform_extractor_violations - isi_viol, counts = compute_isi_violations(we, isi_threshold_ms=1, min_isi_ms=0.0) +def test_calculate_isi_violations(sorting_result_violations): + sorting_result = sorting_result_violations + isi_viol, counts = compute_isi_violations(sorting_result, isi_threshold_ms=1, min_isi_ms=0.0) print(isi_viol) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -307,9 +277,9 @@ def test_calculate_isi_violations(waveform_extractor_violations): # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) -def test_calculate_sliding_rp_violations(waveform_extractor_violations): - we = waveform_extractor_violations - contaminations = compute_sliding_rp_violations(we, bin_size_ms=0.25, window_size_s=1) +def test_calculate_sliding_rp_violations(sorting_result_violations): + sorting_result = sorting_result_violations + contaminations = compute_sliding_rp_violations(sorting_result, bin_size_ms=0.25, window_size_s=1) print(contaminations) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -317,9 +287,9 @@ def test_calculate_sliding_rp_violations(waveform_extractor_violations): # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) -def test_calculate_rp_violations(waveform_extractor_violations): - we = waveform_extractor_violations - rp_contamination, counts = compute_refrac_period_violations(we, refractory_period_ms=1, censored_period_ms=0.0) +def test_calculate_rp_violations(sorting_result_violations): + sorting_result = sorting_result_violations + rp_contamination, counts = compute_refrac_period_violations(sorting_result, refractory_period_ms=1, censored_period_ms=0.0) print(rp_contamination, counts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -331,17 +301,18 @@ def test_calculate_rp_violations(waveform_extractor_violations): sorting = NumpySorting.from_unit_dict( {0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000 ) - we.sorting = sorting + # we.sorting = sorting + sorting_result2 = start_sorting_result(sorting, sorting_result.recording, format="memory", sparse=False) - rp_contamination, counts = compute_refrac_period_violations(we, refractory_period_ms=1, censored_period_ms=0.0) + rp_contamination, counts = compute_refrac_period_violations(sorting_result2, refractory_period_ms=1, censored_period_ms=0.0) assert np.isnan(rp_contamination[1]) -def test_synchrony_metrics(waveform_extractor_simple): - we = waveform_extractor_simple - sorting = we.sorting +def test_synchrony_metrics(sorting_result_simple): + sorting_result = sorting_result_simple + sorting = sorting_result.sorting synchrony_sizes = (2, 3, 4) - synchrony_metrics = compute_synchrony_metrics(we, synchrony_sizes=synchrony_sizes) + synchrony_metrics = compute_synchrony_metrics(sorting_result, synchrony_sizes=synchrony_sizes) print(synchrony_metrics) # check returns @@ -350,14 +321,17 @@ def test_synchrony_metrics(waveform_extractor_simple): # here we test that increasing added synchrony is captured by syncrhony metrics added_synchrony_levels = (0.2, 0.5, 0.8) - previous_waveform_extractor = we + previous_sorting_result = sorting_result for sync_level in added_synchrony_levels: sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) - waveform_extractor_sync = extract_waveforms(previous_waveform_extractor.recording, sorting_sync, mode="memory") + #waveform_extractor_sync = extract_waveforms(previous_waveform_extractor.recording, sorting_sync, mode="memory") + sorting_result_sync = start_sorting_result(sorting_sync, sorting_result.recording, format="memory") + + previous_synchrony_metrics = compute_synchrony_metrics( - previous_waveform_extractor, synchrony_sizes=synchrony_sizes + previous_sorting_result, synchrony_sizes=synchrony_sizes ) - current_synchrony_metrics = compute_synchrony_metrics(waveform_extractor_sync, synchrony_sizes=synchrony_sizes) + current_synchrony_metrics = compute_synchrony_metrics(sorting_result_sync, synchrony_sizes=synchrony_sizes) print(current_synchrony_metrics) # check that all values increased for i, col in enumerate(previous_synchrony_metrics._fields): @@ -369,16 +343,17 @@ def test_synchrony_metrics(waveform_extractor_simple): ) # set new previous waveform extractor - previous_waveform_extractor = waveform_extractor_sync + previous_sorting_result = sorting_result_sync @pytest.mark.sortingcomponents -def test_calculate_drift_metrics(waveform_extractor_simple): - we = waveform_extractor_simple - spike_locs = compute_spike_locations(we) - drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics(we, interval_s=10, min_spikes_per_interval=10) +def test_calculate_drift_metrics(sorting_result_simple): + sorting_result = sorting_result_simple + sorting_result.compute("spike_locations", **job_kwargs) + + drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics(sorting_result, interval_s=10, min_spikes_per_interval=10) - print(drifts_ptps, drifts_stds, drift_mads) + # print(drifts_ptps, drifts_stds, drift_mads) # testing method accuracy with magic number is not a good pratcice, I remove this. # drift_ptps_gt = {0: 0.7155675636836349, 1: 0.8163672125409391, 2: 1.0224792180505773} @@ -389,29 +364,40 @@ def test_calculate_drift_metrics(waveform_extractor_simple): # assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05) -def test_calculate_sd_ratio(waveform_extractor_simple): +def test_calculate_sd_ratio(sorting_result_simple): sd_ratio = compute_sd_ratio( - waveform_extractor_simple, + sorting_result_simple, ) - assert np.all(list(sd_ratio.keys()) == waveform_extractor_simple.unit_ids) + assert np.all(list(sd_ratio.keys()) == sorting_result_simple.unit_ids) # assert np.allclose(list(sd_ratio.values()), 1, atol=0.2, rtol=0) if __name__ == "__main__": - sim_data = _simulated_data() - we = _waveform_extractor_simple() - - we_violations = _waveform_extractor_violations(sim_data) - test_calculate_amplitude_cutoff(we) - test_calculate_presence_ratio(we) - test_calculate_amplitude_median(we) - test_calculate_isi_violations(we) - test_calculate_sliding_rp_violations(we) - test_calculate_drift_metrics(we) - test_synchrony_metrics(we) - test_calculate_firing_range(we) - test_calculate_amplitude_cv_metrics(we) - - # for windows we need an explicit del for closing the recording files - del we, we_violations + + sorting_result = _sorting_result_simple() + print(sorting_result) + + # test_calculate_firing_rate_num_spikes(sorting_result) + # test_calculate_snrs(sorting_result) + test_calculate_amplitude_cutoff(sorting_result) + # test_calculate_presence_ratio(sorting_result) + # test_calculate_amplitude_median(sorting_result) + # test_calculate_sliding_rp_violations(sorting_result) + # test_calculate_drift_metrics(sorting_result) + # test_synchrony_metrics(sorting_result) + # test_calculate_firing_range(sorting_result) + # test_calculate_amplitude_cv_metrics(sorting_result) + # test_calculate_sd_ratio(sorting_result) + + + + # sorting_result_violations = _sorting_result_violations() + # print(sorting_result_violations) + # test_calculate_isi_violations(sorting_result_violations) + # test_calculate_sliding_rp_violations(sorting_result_violations) + # test_calculate_rp_violations(sorting_result_violations) + + + + diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py new file mode 100644 index 0000000000..80e398822e --- /dev/null +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -0,0 +1,68 @@ +import pytest +import shutil +from pathlib import Path +import numpy as np +import pandas as pd +from spikeinterface.core import ( + NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting, generate_ground_truth_recording, start_sorting_result +) + +# from spikeinterface.extractors.toy_example import toy_example +from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions + +from spikeinterface.qualitymetrics import ( + calculate_pc_metrics, + nearest_neighbors_isolation, + nearest_neighbors_noise_overlap +) + + + +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +def _sorting_result_simple(): + recording, sorting = generate_ground_truth_recording( + durations=[50.0,], sampling_frequency=30_000.0, num_channels=6, num_units=10, + generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + + sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + + sorting_result.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_result.compute("noise_levels") + sorting_result.compute("waveforms", **job_kwargs) + sorting_result.compute("templates", operators=["average", "std", "median"]) + sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + sorting_result.compute("spike_amplitudes", **job_kwargs) + + return sorting_result + +@pytest.fixture(scope="module") +def sorting_result_simple(): + return _sorting_result_simple() + + +def test_calculate_pc_metrics(sorting_result_simple): + sorting_result = sorting_result_simple + res = calculate_pc_metrics(sorting_result) + print(pd.DataFrame(res)) + +def test_nearest_neighbors_isolation(sorting_result_simple): + sorting_result = sorting_result_simple + this_unit_id = sorting_result.unit_ids[0] + nearest_neighbors_isolation(sorting_result, this_unit_id) + + +def test_nearest_neighbors_noise_overlap(sorting_result_simple): + sorting_result = sorting_result_simple + this_unit_id = sorting_result.unit_ids[0] + nearest_neighbors_noise_overlap(sorting_result, this_unit_id) + +if __name__ == "__main__": + sorting_result = _sorting_result_simple() + test_calculate_pc_metrics(sorting_result) + test_nearest_neighbors_isolation(sorting_result) + test_nearest_neighbors_noise_overlap(sorting_result) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index e697c0e762..21c7285455 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -5,29 +5,41 @@ import numpy as np import shutil -from spikeinterface import ( - WaveformExtractor, +from spikeinterface.core import ( + generate_ground_truth_recording, + start_sorting_result, NumpySorting, - compute_sparsity, - load_extractor, - extract_waveforms, - split_recording, - select_segment_sorting, - load_waveforms, aggregate_units, ) -from spikeinterface.extractors import toy_example -from spikeinterface.postprocessing import ( - compute_principal_components, - compute_spike_amplitudes, - compute_spike_locations, - compute_noise_levels, + +from spikeinterface.qualitymetrics import ( + compute_quality_metrics, ) -from spikeinterface.preprocessing import scale -from spikeinterface.qualitymetrics import QualityMetricCalculator, get_default_qm_params -from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +# generate_ground_truth_recording +# WaveformExtractor, +# NumpySorting, +# compute_sparsity, +# load_extractor, +# extract_waveforms, +# split_recording, +# select_segment_sorting, +# load_waveforms, +# aggregate_units, +# ) +# from spikeinterface.extractors import toy_example + +# from spikeinterface.postprocessing import ( +# compute_principal_components, +# compute_spike_amplitudes, +# compute_spike_locations, +# compute_noise_levels, +# ) +# from spikeinterface.preprocessing import scale +# from spikeinterface.qualitymetrics import QualityMetricCalculator, get_default_qm_params + +# from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite if hasattr(pytest, "global_test_folder"): @@ -36,294 +48,275 @@ cache_folder = Path("cache_folder") / "qualitymetrics" -class QualityMetricsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): - extension_class = QualityMetricCalculator - extension_data_names = ["metrics"] - extension_function_params_list = [dict(), dict(n_jobs=2), dict(metric_names=["snr", "firing_rate"])] - - exact_same_content = False - - def _clean_folders_metrics(self): - for name in ( - "toy_rec_long", - "toy_sorting_long", - "toy_waveforms_long", - "toy_waveforms_short", - "toy_waveforms_inv", - ): - if (cache_folder / name).is_dir(): - shutil.rmtree(cache_folder / name) - - def setUp(self): - super().setUp() - self._clean_folders_metrics() - - recording, sorting = toy_example(num_segments=2, num_units=10, duration=120, seed=42) - recording = recording.save(folder=cache_folder / "toy_rec_long") - sorting = sorting.save(folder=cache_folder / "toy_sorting_long") - we_long = extract_waveforms( - recording, - sorting, - cache_folder / "toy_waveforms_long", - max_spikes_per_unit=500, - overwrite=True, - seed=0, - ) - # make a short we for testing amp cutoff - recording_one = split_recording(recording)[0] - sorting_one = select_segment_sorting(sorting, [0]) - - nsec_short = 30 - recording_short = recording_one.frame_slice( - start_frame=0, end_frame=int(nsec_short * recording.sampling_frequency) - ) - sorting_short = sorting_one.frame_slice(start_frame=0, end_frame=int(nsec_short * recording.sampling_frequency)) - we_short = extract_waveforms( - recording_short, - sorting_short, - cache_folder / "toy_waveforms_short", - max_spikes_per_unit=500, - overwrite=True, - seed=0, - ) - self.sparsity_long = compute_sparsity(we_long, method="radius", radius_um=50) - self.we_long = we_long - self.we_short = we_short - - def tearDown(self): - super().tearDown() - # delete object to release memmap - del self.we_long, self.we_short - self._clean_folders_metrics() - - def test_metrics(self): - we = self.we_long - - # avoid NaNs - if we.has_extension("spike_amplitudes"): - we.delete_extension("spike_amplitudes") - - # without PC - metrics = self.extension_class.get_extension_function()(we, metric_names=["snr"]) - assert "snr" in metrics.columns - assert "isolation_distance" not in metrics.columns - metrics = self.extension_class.get_extension_function()( - we, metric_names=["snr"], qm_params=dict(isi_violation=dict(isi_threshold_ms=2)) - ) - # check that parameters are correctly set - qm = we.load_extension("quality_metrics") - assert qm._params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 - assert "snr" in metrics.columns - assert "isolation_distance" not in metrics.columns - # print(metrics) - - # with PCs - # print("Computing PCA") - _ = compute_principal_components(we, n_components=5, mode="by_channel_local") - metrics = self.extension_class.get_extension_function()(we, seed=0) - assert "isolation_distance" in metrics.columns - - # with PC - parallel - metrics_par = self.extension_class.get_extension_function()( - we, n_jobs=2, verbose=True, progress_bar=True, seed=0 - ) - # print(metrics) - # print(metrics_par) - for metric_name in metrics.columns: - # skip NaNs - metric_values = metrics[metric_name].values[~np.isnan(metrics[metric_name].values)] - metric_par_values = metrics_par[metric_name].values[~np.isnan(metrics_par[metric_name].values)] - assert np.allclose(metric_values, metric_par_values) - # print(metrics) - - # with sparsity - metrics_sparse = self.extension_class.get_extension_function()(we, sparsity=self.sparsity_long, n_jobs=1) - assert "isolation_distance" in metrics_sparse.columns - # for metric_name in metrics.columns: - # assert np.allclose(metrics[metric_name], metrics_par[metric_name]) - # print(metrics_sparse) - - def test_amplitude_cutoff(self): - we = self.we_short - _ = compute_spike_amplitudes(we, peak_sign="neg") - - # If too few spikes, should raise a warning and set amplitude cutoffs to nans - with pytest.warns(UserWarning) as w: - metrics = self.extension_class.get_extension_function()( - we, metric_names=["amplitude_cutoff"], peak_sign="neg" - ) - assert all(np.isnan(cutoff) for cutoff in metrics["amplitude_cutoff"].values) - - # now we decrease the number of bins and check that amplitude cutoffs are correctly computed - qm_params = dict(amplitude_cutoff=dict(num_histogram_bins=5)) - with warnings.catch_warnings(): - warnings.simplefilter("error") - metrics = self.extension_class.get_extension_function()( - we, metric_names=["amplitude_cutoff"], peak_sign="neg", qm_params=qm_params +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + +def get_sorting_result(seed=2205): + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[120.0,], sampling_frequency=30_000.0, num_channels=6, num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params_range=dict( + alpha=(9_000.0, 12_000.0), ) - assert all(not np.isnan(cutoff) for cutoff in metrics["amplitude_cutoff"].values) + ), + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=seed, + ) + + sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + + sorting_result.select_random_spikes(max_spikes_per_unit=300, seed=seed) + sorting_result.compute("noise_levels") + sorting_result.compute("waveforms", **job_kwargs) + sorting_result.compute("templates") + sorting_result.compute("spike_amplitudes", **job_kwargs) + + return sorting_result + + +@pytest.fixture(scope="module") +def sorting_result_simple(): + sorting_result = get_sorting_result(seed=2205) + return sorting_result + + +def test_compute_quality_metrics(sorting_result_simple): + sorting_result = sorting_result_simple + print(sorting_result) + + # without PCs + metrics = compute_quality_metrics(sorting_result, + metric_names=["snr"], + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=True, + seed=2205 + ) + # print(metrics) + + qm = sorting_result.get_extension("quality_metrics") + assert qm.params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 + assert "snr" in metrics.columns + assert "isolation_distance" not in metrics.columns + + # with PCs + sorting_result.compute("principal_components") + metrics = compute_quality_metrics(sorting_result, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205 + ) + print(metrics.columns) + assert "isolation_distance" in metrics.columns + +def test_compute_quality_metrics_recordingless(sorting_result_simple): + + sorting_result = sorting_result_simple + metrics = compute_quality_metrics(sorting_result, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205 + ) + + + # make a copy and make it recordingless + sorting_result_norec = sorting_result.save_as(format="memory") + sorting_result_norec._recording = None + assert not sorting_result_norec.has_recording() + + print(sorting_result_norec) + + metrics_norec = compute_quality_metrics(sorting_result_norec, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205 + ) + + for metric_name in metrics.columns: + if metric_name == "sd_ratio": + # this one need recording!!! + continue + assert np.allclose(metrics[metric_name].values, metrics_norec[metric_name].values, rtol=1e-02) + + +def test_empty_units(sorting_result_simple): + sorting_result = sorting_result_simple + + empty_spike_train = np.array([], dtype="int64") + empty_sorting = NumpySorting.from_unit_dict( + {100: empty_spike_train, 200: empty_spike_train, 300: empty_spike_train}, + sampling_frequency=sorting_result.sampling_frequency, + ) + sorting_empty = aggregate_units([sorting_result.sorting, empty_sorting]) + assert len(sorting_empty.get_empty_unit_ids()) == 3 + + sorting_result_empty = start_sorting_result(sorting_empty, sorting_result.recording, format="memory") + sorting_result_empty.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_result_empty.compute("noise_levels") + sorting_result_empty.compute("waveforms", **job_kwargs) + sorting_result_empty.compute("templates") + sorting_result_empty.compute("spike_amplitudes", **job_kwargs) + + metrics_empty = compute_quality_metrics(sorting_result_empty, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=True, + seed=2205 + ) + + + for empty_unit_id in sorting_empty.get_empty_unit_ids(): + assert np.all(np.isnan(metrics_empty.loc[empty_unit_id])) + + + +# @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() + +# def test_amplitude_cutoff(self): +# we = self.we_short +# _ = compute_spike_amplitudes(we, peak_sign="neg") + +# # If too few spikes, should raise a warning and set amplitude cutoffs to nans +# with pytest.warns(UserWarning) as w: +# metrics = self.extension_class.get_extension_function()( +# we, metric_names=["amplitude_cutoff"], peak_sign="neg" +# ) +# assert all(np.isnan(cutoff) for cutoff in metrics["amplitude_cutoff"].values) + +# # now we decrease the number of bins and check that amplitude cutoffs are correctly computed +# qm_params = dict(amplitude_cutoff=dict(num_histogram_bins=5)) +# with warnings.catch_warnings(): +# warnings.simplefilter("error") +# metrics = self.extension_class.get_extension_function()( +# we, metric_names=["amplitude_cutoff"], peak_sign="neg", qm_params=qm_params +# ) +# assert all(not np.isnan(cutoff) for cutoff in metrics["amplitude_cutoff"].values) + +# def test_presence_ratio(self): +# we = self.we_long + +# total_duration = we.get_total_duration() +# # If bin_duration_s is larger than total duration, should raise a warning and set presence ratios to nans +# qm_params = dict(presence_ratio=dict(bin_duration_s=total_duration + 1)) +# with pytest.warns(UserWarning) as w: +# metrics = self.extension_class.get_extension_function()( +# we, metric_names=["presence_ratio"], qm_params=qm_params +# ) +# assert all(np.isnan(ratio) for ratio in metrics["presence_ratio"].values) + +# # now we decrease the bin_duration_s and check that presence ratios are correctly computed +# qm_params = dict(presence_ratio=dict(bin_duration_s=total_duration // 10)) +# with warnings.catch_warnings(): +# warnings.simplefilter("error") +# metrics = self.extension_class.get_extension_function()( +# we, metric_names=["presence_ratio"], qm_params=qm_params +# ) +# assert all(not np.isnan(ratio) for ratio in metrics["presence_ratio"].values) + +# def test_drift_metrics(self): +# we = self.we_long # is also multi-segment + +# # if spike_locations is not an extension, raise a warning and set values to NaN +# with pytest.warns(UserWarning) as w: +# metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"]) +# assert all(np.isnan(metric) for metric in metrics["drift_ptp"].values) +# assert all(np.isnan(metric) for metric in metrics["drift_std"].values) +# assert all(np.isnan(metric) for metric in metrics["drift_mad"].values) + +# # now we compute spike locations, but use an interval_s larger than half the total duration +# _ = compute_spike_locations(we) +# total_duration = we.get_total_duration() +# qm_params = dict(drift=dict(interval_s=total_duration // 2 + 1, min_spikes_per_interval=10, min_num_bins=2)) +# with pytest.warns(UserWarning) as w: +# metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"], qm_params=qm_params) +# assert all(np.isnan(metric) for metric in metrics["drift_ptp"].values) +# assert all(np.isnan(metric) for metric in metrics["drift_std"].values) +# assert all(np.isnan(metric) for metric in metrics["drift_mad"].values) + +# # finally let's use an interval compatible with segment durations +# qm_params = dict(drift=dict(interval_s=total_duration // 10, min_spikes_per_interval=10)) +# with warnings.catch_warnings(): +# warnings.simplefilter("error") +# metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"], qm_params=qm_params) +# # print(metrics) +# assert all(not np.isnan(metric) for metric in metrics["drift_ptp"].values) +# assert all(not np.isnan(metric) for metric in metrics["drift_std"].values) +# assert all(not np.isnan(metric) for metric in metrics["drift_mad"].values) + +# def test_peak_sign(self): +# we = self.we_long +# rec = we.recording +# sort = we.sorting + +# # invert recording +# rec_inv = scale(rec, gain=-1.0) + +# we_inv = extract_waveforms(rec_inv, sort, cache_folder / "toy_waveforms_inv", seed=0) + +# # compute amplitudes +# _ = compute_spike_amplitudes(we, peak_sign="neg") +# _ = compute_spike_amplitudes(we_inv, peak_sign="pos") + +# # without PC +# metrics = self.extension_class.get_extension_function()( +# we, metric_names=["snr", "amplitude_cutoff"], peak_sign="neg" +# ) +# metrics_inv = self.extension_class.get_extension_function()( +# we_inv, metric_names=["snr", "amplitude_cutoff"], peak_sign="pos" +# ) +# # print(metrics) +# # print(metrics_inv) +# # for SNR we allow a 5% tollerance because of waveform sub-sampling +# assert np.allclose(metrics["snr"].values, metrics_inv["snr"].values, rtol=0.05) +# # for amplitude_cutoff, since spike amplitudes are computed, values should be exactly the same +# assert np.allclose(metrics["amplitude_cutoff"].values, metrics_inv["amplitude_cutoff"].values, atol=1e-3) + +# def test_nn_metrics(self): +# we_dense = self.we1 +# we_sparse = self.we_sparse +# sparsity = self.sparsity1 +# # print(sparsity) + +# metric_names = ["nearest_neighbor", "nn_isolation", "nn_noise_overlap"] + +# # with external sparsity on dense waveforms +# _ = compute_principal_components(we_dense, n_components=5, mode="by_channel_local") +# metrics = self.extension_class.get_extension_function()( +# we_dense, metric_names=metric_names, sparsity=sparsity, seed=0 +# ) +# # print(metrics) + +# # with sparse waveforms +# _ = compute_principal_components(we_sparse, n_components=5, mode="by_channel_local") +# metrics = self.extension_class.get_extension_function()( +# we_sparse, metric_names=metric_names, sparsity=None, seed=0 +# ) +# # print(metrics) + +# # with 2 jobs +# # with sparse waveforms +# _ = compute_principal_components(we_sparse, n_components=5, mode="by_channel_local") +# metrics_par = self.extension_class.get_extension_function()( +# we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 +# ) +# for metric_name in metrics.columns: +# # NaNs are skipped +# assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) - def test_presence_ratio(self): - we = self.we_long +if __name__ == "__main__": - total_duration = we.get_total_duration() - # If bin_duration_s is larger than total duration, should raise a warning and set presence ratios to nans - qm_params = dict(presence_ratio=dict(bin_duration_s=total_duration + 1)) - with pytest.warns(UserWarning) as w: - metrics = self.extension_class.get_extension_function()( - we, metric_names=["presence_ratio"], qm_params=qm_params - ) - assert all(np.isnan(ratio) for ratio in metrics["presence_ratio"].values) - - # now we decrease the bin_duration_s and check that presence ratios are correctly computed - qm_params = dict(presence_ratio=dict(bin_duration_s=total_duration // 10)) - with warnings.catch_warnings(): - warnings.simplefilter("error") - metrics = self.extension_class.get_extension_function()( - we, metric_names=["presence_ratio"], qm_params=qm_params - ) - assert all(not np.isnan(ratio) for ratio in metrics["presence_ratio"].values) - - def test_drift_metrics(self): - we = self.we_long # is also multi-segment - - # if spike_locations is not an extension, raise a warning and set values to NaN - with pytest.warns(UserWarning) as w: - metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"]) - assert all(np.isnan(metric) for metric in metrics["drift_ptp"].values) - assert all(np.isnan(metric) for metric in metrics["drift_std"].values) - assert all(np.isnan(metric) for metric in metrics["drift_mad"].values) - - # now we compute spike locations, but use an interval_s larger than half the total duration - _ = compute_spike_locations(we) - total_duration = we.get_total_duration() - qm_params = dict(drift=dict(interval_s=total_duration // 2 + 1, min_spikes_per_interval=10, min_num_bins=2)) - with pytest.warns(UserWarning) as w: - metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"], qm_params=qm_params) - assert all(np.isnan(metric) for metric in metrics["drift_ptp"].values) - assert all(np.isnan(metric) for metric in metrics["drift_std"].values) - assert all(np.isnan(metric) for metric in metrics["drift_mad"].values) - - # finally let's use an interval compatible with segment durations - qm_params = dict(drift=dict(interval_s=total_duration // 10, min_spikes_per_interval=10)) - with warnings.catch_warnings(): - warnings.simplefilter("error") - metrics = self.extension_class.get_extension_function()(we, metric_names=["drift"], qm_params=qm_params) - # print(metrics) - assert all(not np.isnan(metric) for metric in metrics["drift_ptp"].values) - assert all(not np.isnan(metric) for metric in metrics["drift_std"].values) - assert all(not np.isnan(metric) for metric in metrics["drift_mad"].values) - - def test_peak_sign(self): - we = self.we_long - rec = we.recording - sort = we.sorting - - # invert recording - rec_inv = scale(rec, gain=-1.0) - - we_inv = extract_waveforms(rec_inv, sort, cache_folder / "toy_waveforms_inv", seed=0) - - # compute amplitudes - _ = compute_spike_amplitudes(we, peak_sign="neg") - _ = compute_spike_amplitudes(we_inv, peak_sign="pos") - - # without PC - metrics = self.extension_class.get_extension_function()( - we, metric_names=["snr", "amplitude_cutoff"], peak_sign="neg" - ) - metrics_inv = self.extension_class.get_extension_function()( - we_inv, metric_names=["snr", "amplitude_cutoff"], peak_sign="pos" - ) - # print(metrics) - # print(metrics_inv) - # for SNR we allow a 5% tollerance because of waveform sub-sampling - assert np.allclose(metrics["snr"].values, metrics_inv["snr"].values, rtol=0.05) - # for amplitude_cutoff, since spike amplitudes are computed, values should be exactly the same - assert np.allclose(metrics["amplitude_cutoff"].values, metrics_inv["amplitude_cutoff"].values, atol=1e-3) - - def test_nn_metrics(self): - we_dense = self.we1 - we_sparse = self.we_sparse - sparsity = self.sparsity1 - # print(sparsity) - - metric_names = ["nearest_neighbor", "nn_isolation", "nn_noise_overlap"] - - # with external sparsity on dense waveforms - _ = compute_principal_components(we_dense, n_components=5, mode="by_channel_local") - metrics = self.extension_class.get_extension_function()( - we_dense, metric_names=metric_names, sparsity=sparsity, seed=0 - ) - # print(metrics) - - # with sparse waveforms - _ = compute_principal_components(we_sparse, n_components=5, mode="by_channel_local") - metrics = self.extension_class.get_extension_function()( - we_sparse, metric_names=metric_names, sparsity=None, seed=0 - ) - # print(metrics) - - # with 2 jobs - # with sparse waveforms - _ = compute_principal_components(we_sparse, n_components=5, mode="by_channel_local") - metrics_par = self.extension_class.get_extension_function()( - we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 - ) - for metric_name in metrics.columns: - # NaNs are skipped - assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) - - def test_recordingless(self): - we = self.we_long - # pre-compute needed extensions - _ = compute_noise_levels(we) - _ = compute_spike_amplitudes(we) - _ = compute_spike_locations(we) - - # load in recordingless mode - we_no_rec = load_waveforms(we.folder, with_recording=False) - qm_rec = self.extension_class.get_extension_function()(we) - qm_no_rec = self.extension_class.get_extension_function()(we_no_rec) - - # print(qm_rec) - # print(qm_no_rec) - - # check metrics are the same - for metric_name in qm_rec.columns: - if metric_name == "sd_ratio": - continue - - # rtol is addedd for sliding_rp_violation, for a reason I do not have to explore now. Sam. - assert np.allclose(qm_rec[metric_name].values, qm_no_rec[metric_name].values, rtol=1e-02) - - def test_empty_units(self): - we = self.we1 - empty_spike_train = np.array([], dtype="int64") - empty_sorting = NumpySorting.from_unit_dict( - {100: empty_spike_train, 200: empty_spike_train, 300: empty_spike_train}, - sampling_frequency=we.sampling_frequency, - ) - sorting_w_empty = aggregate_units([we.sorting, empty_sorting]) - assert len(sorting_w_empty.get_empty_unit_ids()) == 3 - - we_empty = extract_waveforms(we.recording, sorting_w_empty, folder=None, mode="memory") - qm_empty = self.extension_class.get_extension_function()(we_empty) - - for empty_unit in sorting_w_empty.get_empty_unit_ids(): - assert np.all(np.isnan(qm_empty.loc[empty_unit])) + sorting_result = get_sorting_result() + print(sorting_result) + test_compute_quality_metrics(sorting_result) + test_compute_quality_metrics_recordingless(sorting_result) + test_empty_units(sorting_result) -if __name__ == "__main__": - test = QualityMetricsExtensionTest() - test.setUp() - test.test_extension() - test.test_metrics() - test.test_amplitude_cutoff() - test.test_presence_ratio() - test.test_drift_metrics() - test.test_peak_sign() - test.test_nn_metrics() - test.test_recordingless() - test.test_empty_units() - test.tearDown() From 4c17930fb49fd084320b25516ed66f1569b15c4a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 8 Feb 2024 12:05:57 +0100 Subject: [PATCH 039/192] remove extract_waveforms() from some places --- .../comparison/tests/test_groundtruthstudy.py | 2 +- .../tests/test_templatecomparison.py | 39 ++++---- .../core/tests/test_node_pipeline.py | 9 +- .../core/tests/test_result_core.py | 12 +-- .../core/tests/test_sortingresult.py | 7 +- .../core/tests/test_template_tools.py | 94 +++++++++---------- ...forms_extractor_backwards_compatibility.py | 4 + src/spikeinterface/postprocessing/__init__.py | 1 + .../postprocessing/template_similarity.py | 11 ++- 9 files changed, 101 insertions(+), 78 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 91c8c640e0..ef79299795 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -86,7 +86,7 @@ def test_GroundTruthStudy(): study.run_comparisons() print(study.comparisons) - study.extract_waveforms_gt(n_jobs=-1) + study.start_sorting_result_gt(n_jobs=-1) study.compute_metrics() diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index 14e9ebe1e6..90f35e4dbf 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -3,24 +3,24 @@ from pathlib import Path import numpy as np -from spikeinterface.core import extract_waveforms +from spikeinterface.core import start_sorting_result from spikeinterface.extractors import toy_example from spikeinterface.comparison import compare_templates, compare_multiple_templates -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "comparison" -else: - cache_folder = Path("cache_folder") / "comparison" +# if hasattr(pytest, "global_test_folder"): +# cache_folder = pytest.global_test_folder / "comparison" +# else: +# cache_folder = Path("cache_folder") / "comparison" -test_dir = cache_folder / "temp_comp_test" +# test_dir = cache_folder / "temp_comp_test" -def setup_module(): - if test_dir.is_dir(): - shutil.rmtree(test_dir) - test_dir.mkdir(exist_ok=True) +# def setup_module(): +# if test_dir.is_dir(): +# shutil.rmtree(test_dir) +# test_dir.mkdir(exist_ok=True) def test_compare_multiple_templates(): @@ -28,8 +28,8 @@ def test_compare_multiple_templates(): num_channels = 8 rec, sort = toy_example(duration=duration, num_segments=1, num_channels=num_channels) - rec = rec.save(folder=test_dir / "rec") - sort = sort.save(folder=test_dir / "sort") + # rec = rec.save(folder=test_dir / "rec") + # sort = sort.save(folder=test_dir / "sort") # split recording in 3 equal slices fs = rec.get_sampling_frequency() @@ -39,13 +39,18 @@ def test_compare_multiple_templates(): sort1 = sort.frame_slice(start_frame=0 * fs, end_frame=duration / 3 * fs) sort2 = sort.frame_slice(start_frame=duration / 3 * fs, end_frame=2 / 3 * duration * fs) sort3 = sort.frame_slice(start_frame=2 / 3 * duration * fs, end_frame=duration * fs) + # compute waveforms - we1 = extract_waveforms(rec1, sort1, test_dir / "wf1", n_jobs=1) - we2 = extract_waveforms(rec2, sort2, test_dir / "wf2", n_jobs=1) - we3 = extract_waveforms(rec3, sort3, test_dir / "wf3", n_jobs=1) + sorting_result_1 = start_sorting_result(sort1, rec1, format="memory") + sorting_result_2 = start_sorting_result(sort2, rec2, format="memory") + sorting_result_3 = start_sorting_result(sort3, rec3, format="memory") + + for sorting_result in (sorting_result_1, sorting_result_2, sorting_result_3): + sorting_result.select_random_spikes() + sorting_result.compute("fast_templates") # paired comparison - temp_cmp = compare_templates(we1, we2) + temp_cmp = compare_templates(sorting_result_1, sorting_result_2) for u1 in temp_cmp.hungarian_match_12.index.values: u2 = temp_cmp.hungarian_match_12[u1] @@ -53,7 +58,7 @@ def test_compare_multiple_templates(): assert u1 == u2 # multi-comparison - temp_mcmp = compare_multiple_templates([we1, we2, we3]) + temp_mcmp = compare_multiple_templates([sorting_result_1, sorting_result_2, sorting_result_3]) # assert unit ids are the same across sessions (because of initial slicing) for unit_dict in temp_mcmp.units.values(): unit_ids = unit_dict["unit_ids"].values() diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index e3a793b2d5..2e25aa618e 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface import extract_waveforms, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface import start_sorting_result, get_template_extremum_channel, generate_ground_truth_recording # from spikeinterface.sortingcomponents.peak_detection import detect_peaks @@ -77,8 +77,11 @@ def test_run_node_pipeline(): spikes = sorting.to_spike_vector() # create peaks from spikes - we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") + sorting_result = start_sorting_result(sorting, recording, format="memory") + sorting_result.select_random_spikes() + sorting_result.compute("fast_templates") + extremum_channel_inds = get_template_extremum_channel(sorting_result, peak_sign="neg", outputs="index") + peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) peak_retriever = PeakRetriever(recording, peaks) diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index a7ee0cc322..809a5587d8 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -152,7 +152,7 @@ def test_ComputeFastTemplates(format, sparse): # plt.show() @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) -# @pytest.mark.parametrize("sparse", [True, False]) +@pytest.mark.parametrize("sparse", [True, False]) def test_ComputeNoiseLevels(format, sparse): sortres = get_sorting_result(format=format, sparse=sparse) @@ -171,11 +171,11 @@ def test_ComputeNoiseLevels(format, sparse): # test_ComputeWaveforms(format="zarr", sparse=True) # test_ComputeWaveforms(format="zarr", sparse=False) - test_ComputeTemplates(format="memory", sparse=True) - test_ComputeTemplates(format="memory", sparse=False) - test_ComputeTemplates(format="binary_folder", sparse=True) - test_ComputeTemplates(format="zarr", sparse=True) + # test_ComputeTemplates(format="memory", sparse=True) + # test_ComputeTemplates(format="memory", sparse=False) + # test_ComputeTemplates(format="binary_folder", sparse=True) + # test_ComputeTemplates(format="zarr", sparse=True) - test_ComputeFastTemplates(format="memory", sparse=True) + # test_ComputeFastTemplates(format="memory", sparse=True) test_ComputeNoiseLevels(format="memory", sparse=False) diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 111ccd6cd7..bde0210e37 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -175,6 +175,9 @@ def _select_extension_data(self, unit_ids): new_data["result_two"] = self.data["result_two"][keep_spike_mask] return new_data + + def _get_data(self): + return self.data["result_one"] compute_dummy = DummyResultExtension.function_factory() @@ -194,7 +197,7 @@ def test_extension(): if __name__ == "__main__": - # test_SortingResult_memory() + test_SortingResult_memory() test_SortingResult_binary_folder() test_SortingResult_zarr() - # test_extension() + test_extension() diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index eaa7712fcb..712f87f1e7 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -1,8 +1,8 @@ import pytest -import shutil -from pathlib import Path -from spikeinterface import load_extractor, extract_waveforms, load_waveforms, generate_recording, generate_sorting +from spikeinterface.core import generate_ground_truth_recording, start_sorting_result + + from spikeinterface import Templates from spikeinterface.core import ( get_template_amplitudes, @@ -12,65 +12,63 @@ ) -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" +def get_sorting_result(): + recording, sorting = generate_ground_truth_recording( + durations=[10.0, 5.0], sampling_frequency=10_000.0, num_channels=4, num_units=10, + noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + recording.annotate(is_filtered=True) + recording.set_channel_groups([0, 0, 1, 1]) + sorting.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=False) + sorting_result.select_random_spikes() + sorting_result.compute("fast_templates") -def setup_module(): - for folder_name in ("toy_rec", "toy_sort", "toy_waveforms", "toy_waveforms_1"): - if (cache_folder / folder_name).is_dir(): - shutil.rmtree(cache_folder / folder_name) + return sorting_result - durations = [10.0, 5.0] - recording = generate_recording(durations=durations, num_channels=4) - sorting = generate_sorting(durations=durations, num_units=10) +@pytest.fixture(scope="module") +def sorting_result(): + return get_sorting_result() - recording.annotate(is_filtered=True) - recording.set_channel_groups([0, 0, 1, 1]) - recording = recording.save(folder=cache_folder / "toy_rec") - sorting.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) - sorting = sorting.save(folder=cache_folder / "toy_sort") - we = extract_waveforms(recording, sorting, cache_folder / "toy_waveforms") -def _get_templates_object_from_waveform_extractor(we): +def _get_templates_object_from_sorting_result(sorting_result): + ext = sorting_result.get_extension("fast_templates") templates = Templates( - templates_array=we.get_all_templates(mode="average"), - sampling_frequency=we.sampling_frequency, - nbefore=we.nbefore, + templates_array=ext.data["average"], + sampling_frequency=sorting_result.sampling_frequency, + nbefore=ext.nbefore, + # this is dense sparsity_mask=None, - channel_ids=we.channel_ids, - unit_ids=we.unit_ids, + channel_ids=sorting_result.channel_ids, + unit_ids=sorting_result.unit_ids, ) return templates -def test_get_template_amplitudes(): - we = load_waveforms(cache_folder / "toy_waveforms") - peak_values = get_template_amplitudes(we) +def test_get_template_amplitudes(sorting_result): + peak_values = get_template_amplitudes(sorting_result) print(peak_values) - templates = _get_templates_object_from_waveform_extractor(we) + templates = _get_templates_object_from_sorting_result(sorting_result) peak_values = get_template_amplitudes(templates) print(peak_values) -def test_get_template_extremum_channel(): - we = load_waveforms(cache_folder / "toy_waveforms") - extremum_channels_ids = get_template_extremum_channel(we, peak_sign="both") +def test_get_template_extremum_channel(sorting_result): + extremum_channels_ids = get_template_extremum_channel(sorting_result, peak_sign="both") print(extremum_channels_ids) - templates = _get_templates_object_from_waveform_extractor(we) + templates = _get_templates_object_from_sorting_result(sorting_result) extremum_channels_ids = get_template_extremum_channel(templates, peak_sign="both") print(extremum_channels_ids) -def test_get_template_extremum_channel_peak_shift(): - we = load_waveforms(cache_folder / "toy_waveforms") - shifts = get_template_extremum_channel_peak_shift(we, peak_sign="neg") +def test_get_template_extremum_channel_peak_shift(sorting_result): + shifts = get_template_extremum_channel_peak_shift(sorting_result, peak_sign="neg") print(shifts) - templates = _get_templates_object_from_waveform_extractor(we) + templates = _get_templates_object_from_sorting_result(sorting_result) shifts = get_template_extremum_channel_peak_shift(templates, peak_sign="neg") # DEBUG @@ -89,20 +87,22 @@ def test_get_template_extremum_channel_peak_shift(): # plt.show() -def test_get_template_extremum_amplitude(): - we = load_waveforms(cache_folder / "toy_waveforms") +def test_get_template_extremum_amplitude(sorting_result): - extremum_channels_ids = get_template_extremum_amplitude(we, peak_sign="both") + extremum_channels_ids = get_template_extremum_amplitude(sorting_result, peak_sign="both") print(extremum_channels_ids) - templates = _get_templates_object_from_waveform_extractor(we) + templates = _get_templates_object_from_sorting_result(sorting_result) extremum_channels_ids = get_template_extremum_amplitude(templates, peak_sign="both") if __name__ == "__main__": - setup_module() + # setup_module() + + sorting_result = get_sorting_result() + print(sorting_result) - test_get_template_amplitudes() - test_get_template_extremum_channel() - test_get_template_extremum_channel_peak_shift() - test_get_template_extremum_amplitude() + test_get_template_amplitudes(sorting_result) + test_get_template_extremum_channel(sorting_result) + test_get_template_extremum_channel_peak_shift(sorting_result) + test_get_template_extremum_amplitude(sorting_result) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index ff9f694b52..a81a141486 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -213,6 +213,10 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting_result.get_sorting_property(key) + @property + def sparsity(self): + return self.sorting_result.sparsity + def has_extension(self, extension_name: str) -> bool: return self.sorting_result.has_extension(extension_name) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index a56457e34c..528f2d3761 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -7,6 +7,7 @@ from .template_similarity import ( ComputeTemplateSimilarity, compute_template_similarity, + compute_template_similarity_by_pair, check_equal_template_with_distribution_overlap, ) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index d85dd8dad9..1d3f259b1d 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -49,7 +49,7 @@ def _run(self): def _get_data(self): return self.data["similarity"] - +# @alessio: compute_template_similarity() is now one inner SortingResult only register_result_extension(ComputeTemplateSimilarity) compute_template_similarity = ComputeTemplateSimilarity.function_factory() @@ -68,7 +68,14 @@ def compute_similarity_with_templates_array(templates_array, other_templates_arr return similarity -# TODO port the waveform_extractor_other concept that compare 2 SortingResult +def compute_template_similarity_by_pair(sorting_result_1, sorting_result_2, method="cosine_similarity"): + templates_array_1 = _get_dense_templates_array(sorting_result_1, return_scaled=True) + templates_array_2 = _get_dense_templates_array(sorting_result_2, return_scaled=True) + similmarity = compute_similarity_with_templates_array(templates_array_1, templates_array_2, method) + return similmarity + + + # def _compute_template_similarity( From a6d8973d94c209f2ddca3cd432d7bda4229f0250 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 8 Feb 2024 12:06:15 +0100 Subject: [PATCH 040/192] remove extract_waveforms() for comparisons module --- .../comparison/groundtruthstudy.py | 23 +++++--- .../comparison/multicomparisons.py | 4 +- .../comparison/paircomparisons.py | 56 +++++++++---------- 3 files changed, 44 insertions(+), 39 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 448ac3b361..178273f90d 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -8,12 +8,12 @@ import numpy as np -from spikeinterface.core import load_extractor, extract_waveforms, load_waveforms +from spikeinterface.core import load_extractor, start_sorting_result, load_sorting_result from spikeinterface.core.core_tools import SIJsonEncoder +from spikeinterface.core.job_tools import split_job_kwargs from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder -from spikeinterface import WaveformExtractor from spikeinterface.qualitymetrics import compute_quality_metrics from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison @@ -284,28 +284,33 @@ def get_run_times(self, case_keys=None): return pd.Series(run_times, name="run_time") - def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): + def start_sorting_result_gt(self, case_keys=None, **kwargs): if case_keys is None: case_keys = self.cases.keys() - base_folder = self.folder / "waveforms" + select_params, job_kwargs = split_job_kwargs(kwargs) + + base_folder = self.folder / "sorting_result" base_folder.mkdir(exist_ok=True) dataset_keys = [self.cases[key]["dataset"] for key in case_keys] dataset_keys = set(dataset_keys) for dataset_key in dataset_keys: # the waveforms depend on the dataset key - wf_folder = base_folder / self.key_to_str(dataset_key) + folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] - we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs) + sorting_result = start_sorting_result(gt_sorting, recording, format="binray_folder", folder=folder) + sorting_result.select_random_spikes(**select_params) + sorting_result.compute("fast_templates", **job_kwargs) + def get_waveform_extractor(self, case_key=None, dataset_key=None): if case_key is not None: dataset_key = self.cases[case_key]["dataset"] - wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) - we = load_waveforms(wf_folder, with_recording=True) - return we + folder = self.folder / "sorting_result" / self.key_to_str(dataset_key) + sorting_result = load_sorting_result(folder) + return sorting_result def get_templates(self, key, mode="average"): we = self.get_waveform_extractor(case_key=key) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index f0114bd5a3..77adcaa8ca 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -356,8 +356,8 @@ def _compare_ij(self, i, j): comp = TemplateComparison( self.object_list[i], self.object_list[j], - we1_name=self.name_list[i], - we2_name=self.name_list[j], + name1=self.name_list[i], + name2=self.name_list[j], match_score=self.match_score, verbose=False, ) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 64c13d60e4..cba47a6b67 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -13,7 +13,7 @@ do_count_score, compute_performance, ) -from ..postprocessing import compute_template_similarity +from ..postprocessing import compute_template_similarity_by_pair class BasePairSorterComparison(BasePairComparison, MixinSpikeTrainComparison): @@ -696,14 +696,14 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): Parameters ---------- - we1 : WaveformExtractor - The first waveform extractor to get templates to compare - we2 : WaveformExtractor - The second waveform extractor to get templates to compare + sorting_result_1 : SortingResult + The first SortingResult to get templates to compare + sorting_result_2 : SortingResult + The second SortingResult to get templates to compare unit_ids1 : list, default: None - List of units from we1 to compare + List of units from sorting_result_1 to compare unit_ids2 : list, default: None - List of units from we2 to compare + List of units from sorting_result_2 to compare similarity_method : str, default: "cosine_similarity" Method for the similaroty matrix sparsity_dict : dict, default: None @@ -719,10 +719,10 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): def __init__( self, - we1, - we2, - we1_name=None, - we2_name=None, + sorting_result_1, + sorting_result_2, + name1=None, + name2=None, unit_ids1=None, unit_ids2=None, match_score=0.7, @@ -731,29 +731,29 @@ def __init__( sparsity_dict=None, verbose=False, ): - if we1_name is None: - we1_name = "sess1" - if we2_name is None: - we2_name = "sess2" + if name1 is None: + name1 = "sess1" + if name2 is None: + name2 = "sess2" BasePairComparison.__init__( self, - object1=we1, - object2=we2, - name1=we1_name, - name2=we2_name, + object1=sorting_result_1, + object2=sorting_result_2, + name1=name1, + name2=name2, match_score=match_score, chance_score=chance_score, verbose=verbose, ) MixinTemplateComparison.__init__(self, similarity_method=similarity_method, sparsity_dict=sparsity_dict) - self.we1 = we1 - self.we2 = we2 - channel_ids1 = we1.recording.get_channel_ids() - channel_ids2 = we2.recording.get_channel_ids() + self.sorting_result_1 = sorting_result_1 + self.sorting_result_2 = sorting_result_2 + channel_ids1 = sorting_result_1.recording.get_channel_ids() + channel_ids2 = sorting_result_2.recording.get_channel_ids() # two options: all channels are shared or partial channels are shared - if we1.recording.get_num_channels() != we2.recording.get_num_channels(): + if sorting_result_1.recording.get_num_channels() != sorting_result_2.recording.get_num_channels(): raise NotImplementedError if np.any([ch1 != ch2 for (ch1, ch2) in zip(channel_ids1, channel_ids2)]): # TODO: here we can check location and run it on the union. Might be useful for reconfigurable probes @@ -762,10 +762,10 @@ def __init__( self.matches = dict() if unit_ids1 is None: - unit_ids1 = we1.sorting.get_unit_ids() + unit_ids1 = sorting_result_1.sorting.get_unit_ids() if unit_ids2 is None: - unit_ids2 = we2.sorting.get_unit_ids() + unit_ids2 = sorting_result_2.sorting.get_unit_ids() self.unit_ids = [unit_ids1, unit_ids2] if sparsity_dict is not None: @@ -780,8 +780,8 @@ def _do_agreement(self): if self._verbose: print("Agreement scores...") - agreement_scores = compute_template_similarity( - self.we1, waveform_extractor_other=self.we2, method=self.similarity_method + agreement_scores = compute_template_similarity_by_pair( + self.sorting_result_1, self.sorting_result_2, method=self.similarity_method ) import pandas as pd From e4ebdc84e9798c1bbfe5a52f9ff399c2158f6617 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 8 Feb 2024 13:15:41 +0100 Subject: [PATCH 041/192] Improve MockWaveformExtractor --- src/spikeinterface/comparison/hybrid.py | 11 +++++--- .../comparison/tests/test_hybrid.py | 11 +++++--- src/spikeinterface/core/sortingresult.py | 8 ++++-- src/spikeinterface/core/sparsity.py | 5 ++++ ...forms_extractor_backwards_compatibility.py | 26 +++++++++++++++++-- 5 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index e77ff584e7..3a87ab1832 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -6,8 +6,8 @@ from spikeinterface.core import ( BaseRecording, BaseSorting, - WaveformExtractor, - NumpySorting, + load_waveforms, + ) from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.core.generate import ( @@ -17,6 +17,9 @@ generate_sorting_to_inject, ) +# TODO aurelien : this is still using the WaveformExtractor!!! can you change it to use SortingResult ? + + class HybridUnitsRecording(InjectTemplatesRecording): """ @@ -155,7 +158,7 @@ class HybridSpikesRecording(InjectTemplatesRecording): def __init__( self, - wvf_extractor: Union[WaveformExtractor, Path], + wvf_extractor, injected_sorting: Union[BaseSorting, None] = None, unit_ids: Union[List[int], None] = None, max_injected_per_unit: int = 1000, @@ -164,7 +167,7 @@ def __init__( injected_sorting_folder: Union[str, Path, None] = None, ) -> None: if isinstance(wvf_extractor, (Path, str)): - wvf_extractor = WaveformExtractor.load(wvf_extractor) + wvf_extractor = load_waveforms(wvf_extractor) target_recording = wvf_extractor.recording target_sorting = wvf_extractor.sorting diff --git a/src/spikeinterface/comparison/tests/test_hybrid.py b/src/spikeinterface/comparison/tests/test_hybrid.py index 144e7aacd0..ab371a38bc 100644 --- a/src/spikeinterface/comparison/tests/test_hybrid.py +++ b/src/spikeinterface/comparison/tests/test_hybrid.py @@ -1,7 +1,7 @@ import pytest import shutil from pathlib import Path -from spikeinterface.core import WaveformExtractor, extract_waveforms, load_extractor +from spikeinterface.core import extract_waveforms, load_waveforms,load_extractor from spikeinterface.core.testing import check_recordings_equal from spikeinterface.comparison import ( create_hybrid_units_recording, @@ -34,7 +34,11 @@ def setup_module(): def test_hybrid_units_recording(): - wvf_extractor = WaveformExtractor.load(cache_folder / "wvf_extractor") + wvf_extractor = load_waveforms(cache_folder / "wvf_extractor") + print(wvf_extractor) + print(wvf_extractor.sorting_result) + + recording = wvf_extractor.recording templates = wvf_extractor.get_all_templates() templates[:, 0, :] = 0 @@ -61,7 +65,7 @@ def test_hybrid_units_recording(): def test_hybrid_spikes_recording(): - wvf_extractor = WaveformExtractor.load_from_folder(cache_folder / "wvf_extractor") + wvf_extractor = load_waveforms(cache_folder / "wvf_extractor") recording = wvf_extractor.recording sorting = wvf_extractor.sorting hybrid_spikes_recording = create_hybrid_spikes_recording( @@ -90,6 +94,5 @@ def test_hybrid_spikes_recording(): if __name__ == "__main__": setup_module() - test_generate_sorting_to_inject() test_hybrid_units_recording() test_hybrid_spikes_recording() diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 239d8d5969..abe284d5b0 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -918,8 +918,10 @@ def get_saved_extension_names(self): saved_extension_names = [] for extension_class in _possible_extensions: extension_name = extension_class.extension_name + if self.format == "binary_folder": - is_saved = (self.folder / extension_name).is_dir() and (self.folder / extension_name / "params.json").is_file() + extension_folder = self.folder / "extensions" /extension_name + is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file() elif self.format == "zarr": if extension_group is not None: is_saved = extension_name in extension_group.keys() and "params" in extension_group[extension_name].attrs.keys() @@ -970,6 +972,8 @@ def load_extension(self, extension_name: str): extension_instance.load_params() extension_instance.load_data() + self.extensions[extension_name] = extension_instance + return extension_instance def load_all_saved_extension(self): @@ -1222,7 +1226,7 @@ def folder(self): return self.sorting_result.folder def _get_binary_extension_folder(self): - extension_folder = self.folder / "saved_extensions" /self.extension_name + extension_folder = self.folder / "extensions" /self.extension_name return extension_folder diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index b980b31e27..c4e703d911 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -437,8 +437,13 @@ def compute_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates from .waveform_extractor import WaveformExtractor + from .waveforms_extractor_backwards_compatibility import MockWaveformExtractor from .sortingresult import SortingResult + if isinstance(templates_or_sorting_result, MockWaveformExtractor): + # to keep backward compatibility + templates_or_sorting_result = templates_or_sorting_result.sorting_result + if method in ("best_channels", "radius"): assert isinstance(templates_or_sorting_result, (Templates, WaveformExtractor, SortingResult)), "compute_sparsity() need Templates or WaveformExtractor or SortingResult" else: diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index a81a141486..8433bf63a8 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -20,7 +20,7 @@ from .sortingresult import start_sorting_result from .job_tools import split_job_kwargs from .sparsity import ChannelSparsity -from .sortingresult import SortingResult +from .sortingresult import SortingResult, load_sorting_result from .base import load_extractor from .result_core import ComputeWaveforms, ComputeTemplates @@ -100,11 +100,15 @@ def extract_waveforms( sorting_result.select_random_spikes(max_spikes_per_unit=max_spikes_per_unit, seed=seed) waveforms_params = dict(ms_before=ms_before, ms_after=ms_after, return_scaled=return_scaled, dtype=dtype) - sorting_result.compute("waveforms", **waveforms_params) + sorting_result.compute("waveforms", **waveforms_params, **job_kwargs) templates_params = dict(operators=list(precompute_template)) sorting_result.compute("templates", **templates_params) + # this also done because some metrics need it + sorting_result.compute("noise_levels") + + we = MockWaveformExtractor(sorting_result) return we @@ -217,6 +221,11 @@ def get_sorting_property(self, key) -> np.ndarray: def sparsity(self): return self.sorting_result.sparsity + @property + def folder(self): + if self.sorting_result.format != "memory": + return self.sorting_result.folder + def has_extension(self, extension_name: str) -> bool: return self.sorting_result.has_extension(extension_name) @@ -304,10 +313,23 @@ def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSo """ This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingResult or MockWaveformExtractor. + It also mimic the old load_waveforms by opening a Sortingresult folder and return a MockWaveformExtractor. + """ folder = Path(folder) assert folder.is_dir(), "Waveform folder does not exists" + + if (folder / "spikeinterface_info.json").exists: + with open(folder / "spikeinterface_info.json", mode="r") as f: + info = json.load(f) + if info.get("object", None) == "SortingResult": + # in this case the folder is already a sorting result from version >= 0.101.0 but create with the MockWaveformExtractor + sorting_result = load_sorting_result(folder) + sorting_result.load_all_saved_extension() + we = MockWaveformExtractor(sorting_result) + return we + if folder.suffix == ".zarr": raise NotImplementedError # Alessio this is for you From 08183c1d181bb21594de7d353b6e8ce25497dbe7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 8 Feb 2024 15:43:13 +0100 Subject: [PATCH 042/192] Start porting SortingResult on widgets --- ...forms_extractor_backwards_compatibility.py | 9 +- .../postprocessing/spike_amplitudes.py | 3 +- .../postprocessing/template_similarity.py | 2 +- .../widgets/all_amplitudes_distributions.py | 20 +- src/spikeinterface/widgets/base.py | 6 +- .../widgets/tests/test_widgets.py | 227 +++++++++--------- 6 files changed, 137 insertions(+), 130 deletions(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 8433bf63a8..d4c6d0e01f 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -309,11 +309,18 @@ def get_template( -def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="SortingResult", ): +def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="MockWaveformExtractor", ): """ This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingResult or MockWaveformExtractor. It also mimic the old load_waveforms by opening a Sortingresult folder and return a MockWaveformExtractor. + This later behavior is usefull to no break old code like this in versio >=0.101 + + >>> # In this example we is a MockWaveformExtractor that behave the same as before + >>> we = extract_waveforms(..., folder="/my_we") + >>> we = load_waveforms("/my_we") + >>> templates = we.get_all_templates() + """ diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 4d344e7df6..b9000ddc06 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -118,8 +118,7 @@ def _run(self, **job_kwargs): ) self.data["amplitudes"] = amps - - def _get_data(self, outputs="concatenated"): + def _get_data(self): return self.data["amplitudes"] register_result_extension(ComputeSpikeAmplitudes) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 1d3f259b1d..b86ec42c9a 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -22,7 +22,7 @@ class ComputeTemplateSimilarity(ResultExtension): The similarity matrix """ - extension_name = "similarity" + extension_name = "template_similarity" depend_on = ["fast_templates|templates", ] need_recording = True use_nodepipeline = False diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 1e34f4fdb1..dcaa8653fd 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_some_colors -from ..core.waveform_extractor import WaveformExtractor +from ..core import SortingResult class AllAmplitudesDistributionsWidget(BaseWidget): @@ -15,8 +15,8 @@ class AllAmplitudesDistributionsWidget(BaseWidget): Parameters ---------- - waveform_extractor: WaveformExtractor - The input waveform extractor + sorting_result: SortingResult + The SortingResult unit_ids: list List of unit ids, default None unit_colors: None or dict @@ -24,20 +24,20 @@ class AllAmplitudesDistributionsWidget(BaseWidget): """ def __init__( - self, waveform_extractor: WaveformExtractor, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs + self, sorting_result: SortingResult, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs ): - we = waveform_extractor - self.check_extensions(we, "spike_amplitudes") - amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit") + self.check_extensions(sorting_result, "spike_amplitudes") + + amplitudes = sorting_result.get_extension("spike_amplitudes").get_data() - num_segments = we.get_num_segments() + num_segments = sorting_result.get_num_segments() if unit_ids is None: - unit_ids = we.unit_ids + unit_ids = sorting_result.unit_ids if unit_colors is None: - unit_colors = get_some_colors(we.unit_ids) + unit_colors = get_some_colors(sorting_result.unit_ids) plot_data = dict( unit_ids=unit_ids, diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index db43c004d5..c207fca1f2 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -103,17 +103,17 @@ def do_plot(self): func(self.data_plot, **self.backend_kwargs) @staticmethod - def check_extensions(waveform_extractor, extensions): + def check_extensions(sorting_result, extensions): if isinstance(extensions, str): extensions = [extensions] error_msg = "" raise_error = False for extension in extensions: - if not waveform_extractor.has_extension(extension): + if not sorting_result.has_extension(extension): raise_error = True error_msg += ( f"The {extension} waveform extension is required for this widget. " - f"Run the `compute_{extension}` to compute it.\n" + f"Run the `sorting_result.compute('{extension}', ...)` to compute it.\n" ) if raise_error: raise Exception(error_msg) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 9c32f772e3..b93466cc4e 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,12 +13,10 @@ from spikeinterface import ( - load_extractor, - extract_waveforms, - load_waveforms, - download_dataset, compute_sparsity, generate_ground_truth_recording, + start_sorting_result, + load_sorting_result, ) import spikeinterface.extractors as se @@ -47,74 +45,77 @@ class TestWidgets(unittest.TestCase): - @classmethod - def _delete_widget_folders(cls): - for name in ( - "recording", - "sorting", - "we_dense", - "we_sparse", - ): - if (cache_folder / name).is_dir(): - shutil.rmtree(cache_folder / name) + # @classmethod + # def _delete_widget_folders(cls): + # for name in ( + # "recording", + # "sorting", + # "we_dense", + # "we_sparse", + # ): + # if (cache_folder / name).is_dir(): + # shutil.rmtree(cache_folder / name) @classmethod def setUpClass(cls): - cls._delete_widget_folders() - - if (cache_folder / "recording").is_dir() and (cache_folder / "sorting").is_dir(): - cls.recording = load_extractor(cache_folder / "recording") - cls.sorting = load_extractor(cache_folder / "sorting") - else: - recording, sorting = generate_ground_truth_recording( - durations=[30.0], - sampling_frequency=28000.0, - num_channels=32, - num_units=10, - generate_probe_kwargs=dict( - num_columns=2, - xpitch=20, - ypitch=20, - contact_shapes="circle", - contact_shape_params={"radius": 6}, - ), - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), - seed=2205, - ) - # cls.recording = recording.save(folder=cache_folder / "recording") - # cls.sorting = sorting.save(folder=cache_folder / "sorting") - cls.recording = recording - cls.sorting = sorting + # cls._delete_widget_folders() + + recording, sorting = generate_ground_truth_recording( + durations=[30.0], + sampling_frequency=28000.0, + num_channels=32, + num_units=10, + generate_probe_kwargs=dict( + num_columns=2, + xpitch=20, + ypitch=20, + contact_shapes="circle", + contact_shape_params={"radius": 6}, + ), + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), + seed=2205, + ) + # cls.recording = recording.save(folder=cache_folder / "recording") + # cls.sorting = sorting.save(folder=cache_folder / "sorting") + cls.recording = recording + cls.sorting = sorting cls.num_units = len(cls.sorting.get_unit_ids()) - if (cache_folder / "we_dense").is_dir(): - cls.we_dense = load_waveforms(cache_folder / "we_dense") - else: - cls.we_dense = extract_waveforms( - recording=cls.recording, sorting=cls.sorting, folder=None, mode="memory", sparse=False - ) - metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we_dense) - _ = compute_unit_locations(cls.we_dense) - _ = compute_spike_locations(cls.we_dense) - _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) - _ = compute_template_metrics(cls.we_dense) - _ = compute_correlograms(cls.we_dense) - _ = compute_template_similarity(cls.we_dense) + + extensions_to_compute = dict( + waveforms=dict(), + templates=dict(), + noise_levels=dict(), + spike_amplitudes=dict(), + unit_locations=dict(), + spike_locations=dict(), + quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), + template_metrics=dict(), + correlograms=dict(), + template_similarity=dict(), + ) + job_kwargs = dict(n_jobs=-1) + + # create dense + cls.sorting_result_dense = start_sorting_result(cls.sorting, cls.recording, format="memory", sparse=False) + cls.sorting_result_dense.select_random_spikes() + cls.sorting_result_dense.compute(extensions_to_compute, **job_kwargs) sw.set_default_plotter_backend("matplotlib") # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) - cls.sparsity_strict = compute_sparsity(cls.we_dense, method="radius", radius_um=20) - cls.sparsity_large = compute_sparsity(cls.we_dense, method="radius", radius_um=80) - cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) - if (cache_folder / "we_sparse").is_dir(): - cls.we_sparse = load_waveforms(cache_folder / "we_sparse") - else: - cls.we_sparse = cls.we_dense.save(folder=cache_folder / "we_sparse", sparsity=cls.sparsity_radius) + cls.sparsity_radius = compute_sparsity(cls.sorting_result_dense, method="radius", radius_um=50) + cls.sparsity_strict = compute_sparsity(cls.sorting_result_dense, method="radius", radius_um=20) + cls.sparsity_large = compute_sparsity(cls.sorting_result_dense, method="radius", radius_um=80) + cls.sparsity_best = compute_sparsity(cls.sorting_result_dense, method="best_channels", num_channels=5) + + # create sparse + cls.sorting_result_sparse = start_sorting_result(cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius) + cls.sorting_result_sparse.select_random_spikes() + cls.sorting_result_sparse.compute(extensions_to_compute, **job_kwargs) + cls.skip_backends = ["ipywidgets", "ephyviewer"] @@ -133,7 +134,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - del cls.recording, cls.sorting, cls.peaks, cls.gt_comp, cls.we_sparse, cls.we_dense + del cls.recording, cls.sorting, cls.peaks, cls.gt_comp, cls.sorting_result_sparse, cls.sorting_result_dense # cls._delete_widget_folders() def test_plot_traces(self): @@ -177,28 +178,28 @@ def test_plot_unit_waveforms(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_waveforms(self.we_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_waveforms(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_waveforms( - self.we_dense, + self.sorting_result_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.we_dense, + self.sorting_result_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) # extra sparsity sw.plot_unit_waveforms( - self.we_sparse, + self.sorting_result_sparse, sparsity=self.sparsity_strict, unit_ids=unit_ids, backend=backend, @@ -207,7 +208,7 @@ def test_plot_unit_waveforms(self): # test "larger" sparsity with self.assertRaises(AssertionError): sw.plot_unit_waveforms( - self.we_sparse, + self.sorting_result_sparse, sparsity=self.sparsity_large, unit_ids=unit_ids, backend=backend, @@ -220,11 +221,11 @@ def test_plot_unit_templates(self): if backend not in self.skip_backends: print(f"Testing backend {backend}") print("Dense") - sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_templates(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] print("Dense + radius") sw.plot_unit_templates( - self.we_dense, + self.sorting_result_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, @@ -232,7 +233,7 @@ def test_plot_unit_templates(self): ) print("Dense + best") sw.plot_unit_templates( - self.we_dense, + self.sorting_result_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, @@ -241,7 +242,7 @@ def test_plot_unit_templates(self): # test different shadings print("Sparse") sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, unit_ids=unit_ids, templates_percentile_shading=None, backend=backend, @@ -249,7 +250,7 @@ def test_plot_unit_templates(self): ) print("Sparse2") sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, unit_ids=unit_ids, # templates_percentile_shading=None, scale=10, @@ -259,7 +260,7 @@ def test_plot_unit_templates(self): # test different shadings print("Sparse3") sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, templates_percentile_shading=None, @@ -268,7 +269,7 @@ def test_plot_unit_templates(self): ) print("Sparse4") sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, unit_ids=unit_ids, templates_percentile_shading=0.1, backend=backend, @@ -276,7 +277,7 @@ def test_plot_unit_templates(self): ) print("Extra sparsity") sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, sparsity=self.sparsity_strict, unit_ids=unit_ids, templates_percentile_shading=[1, 10, 90, 99], @@ -286,7 +287,7 @@ def test_plot_unit_templates(self): # test "larger" sparsity with self.assertRaises(AssertionError): sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, sparsity=self.sparsity_large, unit_ids=unit_ids, backend=backend, @@ -294,7 +295,7 @@ def test_plot_unit_templates(self): ) if backend != "sortingview": sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, unit_ids=unit_ids, templates_percentile_shading=[1, 5, 25, 75, 95, 99], backend=backend, @@ -304,7 +305,7 @@ def test_plot_unit_templates(self): # sortingview doesn't support more than 2 shadings with self.assertRaises(AssertionError): sw.plot_unit_templates( - self.we_sparse, + self.sorting_result_sparse, unit_ids=unit_ids, templates_percentile_shading=[1, 5, 25, 75, 95, 99], backend=backend, @@ -317,7 +318,7 @@ def test_plot_unit_waveforms_density_map(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): @@ -326,7 +327,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we_dense, + self.sorting_result_dense, sparsity=self.sparsity_radius, same_axis=False, unit_ids=unit_ids, @@ -340,7 +341,7 @@ def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we_sparse, + self.sorting_result_sparse, sparsity=None, same_axis=True, unit_ids=unit_ids, @@ -383,12 +384,12 @@ def test_plot_crosscorrelogram(self): **self.backend_kwargs[backend], ) sw.plot_crosscorrelograms( - self.we_sparse, + self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend], ) sw.plot_crosscorrelograms( - self.we_sparse, + self.sorting_result_sparse, min_similarity_for_correlograms=0.6, backend=backend, **self.backend_kwargs[backend], @@ -412,18 +413,18 @@ def test_plot_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_amplitudes(self.we_dense, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.we_dense.unit_ids[:4] - sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) + sw.plot_amplitudes(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + unit_ids = self.sorting_result_dense.unit_ids[:4] + sw.plot_amplitudes(self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we_dense, + self.sorting_result_dense, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend], ) sw.plot_amplitudes( - self.we_sparse, + self.sorting_result_sparse, unit_ids=unit_ids, plot_histograms=True, backend=backend, @@ -434,12 +435,12 @@ def test_plot_all_amplitudes_distributions(self): possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - unit_ids = self.we_dense.unit_ids[:4] + unit_ids = self.sorting_result_dense.unit_ids[:4] sw.plot_all_amplitudes_distributions( - self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_all_amplitudes_distributions( - self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_locations(self): @@ -447,10 +448,10 @@ def test_plot_unit_locations(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_locations( - self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_locations( - self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) def test_plot_spike_locations(self): @@ -458,59 +459,59 @@ def test_plot_spike_locations(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_spike_locations( - self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) sw.plot_spike_locations( - self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) def test_plot_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_depths(self.we_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): possible_backends = list(sw.UnitSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.we_dense, self.we_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, self.sorting_result_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_summary( - self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, self.sorting_result_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.we_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary( - self.we_sparse, sparsity=self.sparsity_strict, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, sparsity=self.sparsity_strict, backend=backend, **self.backend_kwargs[backend] ) def test_plot_agreement_matrix(self): @@ -541,7 +542,7 @@ def test_plot_unit_probe_map(self): possible_backends = list(sw.UnitProbeMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_probe_map(self.we_dense) + sw.plot_unit_probe_map(self.sorting_result_dense) def test_plot_unit_presence(self): possible_backends = list(sw.UnitPresenceWidget.get_possible_backends()) @@ -581,11 +582,11 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() - # mytest.test_plot_all_amplitudes_distributions() + mytest.test_plot_all_amplitudes_distributions() # mytest.test_plot_traces() # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() - mytest.test_plot_unit_waveforms() + # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_depths() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() From 33c68472e6efe7f522d374910b7d5bd781b68c7a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 8 Feb 2024 21:40:01 +0100 Subject: [PATCH 043/192] wip sorting result and widgets --- src/spikeinterface/core/result_core.py | 113 +++++++++++++----- src/spikeinterface/core/sortingresult.py | 2 +- src/spikeinterface/core/template_tools.py | 1 + .../core/tests/test_result_core.py | 97 ++++++++------- .../widgets/tests/test_widgets.py | 34 +++--- src/spikeinterface/widgets/unit_depths.py | 19 ++- src/spikeinterface/widgets/unit_waveforms.py | 87 ++++++++------ 7 files changed, 219 insertions(+), 134 deletions(-) diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index ba31c19b25..5eaaff3f26 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -127,13 +127,14 @@ def get_waveforms_one_unit(self, unit_id, force_dense: bool = False,): some_spikes = spikes[self.sorting_result.random_spikes_indices] spike_mask = some_spikes["unit_index"] == unit_index wfs = self.data["waveforms"][spike_mask, :, :] - - if force_dense: - if self.sorting_result.sparsity is not None: + + if self.sorting_result.sparsity is not None: + chan_inds = self.sorting_result.sparsity.unit_id_to_channel_indices[unit_id] + wfs = wfs[:, :, :chan_inds.size] + if force_dense: num_channels = self.get_num_channels() dense_wfs = np.zeros((wfs.shape[0], wfs.shape[1], num_channels), dtype=wfs.dtype) - unit_sparsity = self.sorting_result.sparsity.mask[unit_index] - dense_wfs[:, :, unit_sparsity] = wfs + dense_wfs[:, :, chan_inds] = wfs wfs = dense_wfs return wfs @@ -155,6 +156,11 @@ class ComputeTemplates(ResultExtension): This must be run after "waveforms" extension (`SortingResult.compute("waveforms")`) Note that when "waveforms" is already done, then the recording is not needed anymore for this extension. + + Note: by default only the average is computed. Other operator (std, median, percentile) can be computed on demand + after the SortingResult.compute("templates") and then the data dict is updated on demand. + + """ extension_name = "templates" depend_on = ["waveforms"] @@ -162,8 +168,31 @@ class ComputeTemplates(ResultExtension): use_nodepipeline = False need_job_kwargs = False + def _set_params(self, operators = ["average", "std"]): + assert isinstance(operators, list) + for operator in operators: + if isinstance(operator, str): + assert operator in ("average", "std", "median", "mad") + else: + assert isinstance(operator, (list, tuple)) + assert len(operator) == 2 + assert operator[0] == "percentile" + + waveforms_extension = self.sorting_result.get_extension("waveforms") + + params = dict( + operators=operators, + nbefore=waveforms_extension.nbefore, + nafter=waveforms_extension.nafter, + return_scaled=waveforms_extension.params["return_scaled"], + ) + return params + def _run(self): - + self._compute_and_append(self.params["operators"]) + + + def _compute_and_append(self, operators): unit_ids = self.sorting_result.unit_ids channel_ids = self.sorting_result.channel_ids waveforms_extension = self.sorting_result.get_extension("waveforms") @@ -171,7 +200,7 @@ def _run(self): num_samples = waveforms.shape[1] - for operator in self.params["operators"]: + for operator in operators: if isinstance(operator, str) and operator in ("average", "std", "median"): key = operator elif isinstance(operator, (list, tuple)): @@ -190,7 +219,7 @@ def _run(self): if wfs.shape[0] == 0: continue - for operator in self.params["operators"]: + for operator in operators: if operator == "average": arr = np.average(wfs, axis=0) key = operator @@ -211,26 +240,6 @@ def _run(self): channel_indices = self.sparsity.unit_id_to_channel_indices[unit_id] self.data[key][unit_index, :, :][:, channel_indices] = arr[:, :channel_indices.size] - def _set_params(self, operators = ["average", "std"]): - assert isinstance(operators, list) - for operator in operators: - if isinstance(operator, str): - assert operator in ("average", "std", "median", "mad") - else: - assert isinstance(operator, (list, tuple)) - assert len(operator) == 2 - assert operator[0] == "percentile" - - waveforms_extension = self.sorting_result.get_extension("waveforms") - - params = dict( - operators=operators, - nbefore=waveforms_extension.nbefore, - nafter=waveforms_extension.nafter, - return_scaled=waveforms_extension.params["return_scaled"], - ) - return params - @property def nbefore(self): return self.params["nbefore"] @@ -256,6 +265,54 @@ def _get_data(self, operator="average", percentile=None): key = f"pencentile_{percentile}" return self.data[key] + def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True): + """ + Return templates (average, std, median or percentil) for multiple units. + + I not computed yet then this is computed on demand and optionally saved. + + Parameters + ---------- + unit_ids: list or None + Unit ids to retrieve waveforms for + mode: "average" | "median" | "std" | "percentile", default: "average" + The mode to compute the templates + percentile: float, default: None + Percentile to use for mode="percentile" + save: bool, default True + In case, the operator is not computed yet it can be saved to folder or zarr. + + Returns + ------- + templates: np.array + The returned templates (num_units, num_samples, num_channels) + """ + if operator != "percentile": + key = operator + else: + assert percentile is not None, "You must provide percentile=..." + key = f"pencentile_{percentile}" + + if key in self.data: + templates = self.data[key] + else: + if operator != "percentile": + self._compute_and_append([operator]) + self.params["operators"] += [operator] + else: + self._compute_and_append([(operator, percentile)]) + self.params["operators"] += [(operator, percentile)] + templates = self.data[key] + + if save: + self.save() + + if unit_ids is not None: + unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + templates = templates[unit_indices, :, :] + + return np.array(templates) + compute_templates = ComputeTemplates.function_factory() diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index abe284d5b0..d0085925ae 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -964,7 +964,7 @@ def load_extension(self, extension_name: str): The loaded instance of the extension """ - assert self.format != "memory" + assert self.format != "memory", "SortingResult.load_extension() do not work for format='memory' use SortingResult.get_extension()instead" extension_class = get_extension_class(extension_name) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 3b844ad845..360db97c8e 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -9,6 +9,7 @@ from .sortingresult import SortingResult +# TODO make this function a non private function def _get_dense_templates_array(one_object, return_scaled=True): if isinstance(one_object, Templates): templates_array = one_object.get_dense_templates() diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index 809a5587d8..c19ff6c4e8 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -40,12 +40,12 @@ def get_sorting_result(format="memory", sparse=True): if folder and folder.exists(): shutil.rmtree(folder) - sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=None) + sorting_result = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=None) - return sortres + return sorting_result -def _check_result_extension(sortres, extension_name): +def _check_result_extension(sorting_result, extension_name): # select unit_ids to several format for format in ("memory", "binary_folder", "zarr"): # for format in ("memory", ): @@ -60,10 +60,10 @@ def _check_result_extension(sortres, extension_name): folder = None # check unit slice - keep_unit_ids = sortres.sorting.unit_ids[::2] - sortres2 = sortres.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) + keep_unit_ids = sorting_result.sorting.unit_ids[::2] + sorting_result2 = sorting_result.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) - data = sortres2.get_extension(extension_name).data + data = sorting_result2.get_extension(extension_name).data # for k, arr in data.items(): # print(k, arr.shape) @@ -71,54 +71,63 @@ def _check_result_extension(sortres, extension_name): @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeWaveforms(format, sparse): - sortres = get_sorting_result(format=format, sparse=sparse) + sorting_result = get_sorting_result(format=format, sparse=sparse) job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) - sortres.select_random_spikes(max_spikes_per_unit=50, seed=2205) - ext = sortres.compute("waveforms", **job_kwargs) + sorting_result.select_random_spikes(max_spikes_per_unit=50, seed=2205) + ext = sorting_result.compute("waveforms", **job_kwargs) wfs = ext.data["waveforms"] - _check_result_extension(sortres, "waveforms") + _check_result_extension(sorting_result, "waveforms") @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeTemplates(format, sparse): - sortres = get_sorting_result(format=format, sparse=sparse) + sorting_result = get_sorting_result(format=format, sparse=sparse) - sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) + sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) with pytest.raises(AssertionError): # This require "waveforms first and should trig an error - sortres.compute("templates") + sorting_result.compute("templates") job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) - sortres.compute("waveforms", **job_kwargs) - sortres.compute("templates", operators=["average", "std", "median", ("percentile", 5.), ("percentile", 95.),]) + sorting_result.compute("waveforms", **job_kwargs) + # compute some operators + sorting_result.compute("templates", operators=["average", "std", ("percentile", 95.),]) - data = sortres.get_extension("templates").data + # ask for more operator later + ext = sorting_result.get_extension("templates") + templated_median = ext.get_templates(operator="median") + templated_per_5 = ext.get_templates(operator="percentile", percentile=5.) + + # they all should be in data + data = sorting_result.get_extension("templates").data for k in ['average', 'std', 'median', 'pencentile_5.0', 'pencentile_95.0']: assert k in data.keys() - assert data[k].shape[0] == sortres.unit_ids.size - assert data[k].shape[2] == sortres.channel_ids.size + assert data[k].shape[0] == sorting_result.unit_ids.size + assert data[k].shape[2] == sorting_result.channel_ids.size assert np.any(data[k] > 0) - import matplotlib.pyplot as plt - for unit_index, unit_id in enumerate(sortres.unit_ids): - fig, ax = plt.subplots() - for k in data.keys(): - wf0 = data[k][unit_index, :, :] - ax.plot(wf0.T.flatten(), label=k) - ax.legend() + + + # import matplotlib.pyplot as plt + # for unit_index, unit_id in enumerate(sorting_result.unit_ids): + # fig, ax = plt.subplots() + # for k in data.keys(): + # wf0 = data[k][unit_index, :, :] + # ax.plot(wf0.T.flatten(), label=k) + # ax.legend() # plt.show() - _check_result_extension(sortres, "templates") + _check_result_extension(sorting_result, "templates") @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeFastTemplates(format, sparse): - sortres = get_sorting_result(format=format, sparse=sparse) + sorting_result = get_sorting_result(format=format, sparse=sparse) # TODO check this because this is not passing with n_jobs=2 job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) @@ -126,24 +135,24 @@ def test_ComputeFastTemplates(format, sparse): ms_before=1.0 ms_after=2.5 - sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) - sortres.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) + sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) + sorting_result.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) - _check_result_extension(sortres, "fast_templates") + _check_result_extension(sorting_result, "fast_templates") # compare ComputeTemplates with dense and ComputeFastTemplates: should give the same on "average" - other_sortres = get_sorting_result(format=format, sparse=False) - other_sortres.select_random_spikes(max_spikes_per_unit=20, seed=2205) - other_sortres.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) - other_sortres.compute("templates", operators=["average",]) + other_sorting_result = get_sorting_result(format=format, sparse=False) + other_sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) + other_sorting_result.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) + other_sorting_result.compute("templates", operators=["average",]) - templates0 = sortres.get_extension("fast_templates").data["average"] - templates1 = other_sortres.get_extension("templates").data["average"] + templates0 = sorting_result.get_extension("fast_templates").data["average"] + templates1 = other_sorting_result.get_extension("templates").data["average"] np.testing.assert_almost_equal(templates0, templates1) # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # for unit_index, unit_id in enumerate(sortres.unit_ids): + # for unit_index, unit_id in enumerate(sorting_result.unit_ids): # wf0 = templates0[unit_index, :, :] # ax.plot(wf0.T.flatten(), label=f"{unit_id}") # wf1 = templates1[unit_index, :, :] @@ -154,13 +163,13 @@ def test_ComputeFastTemplates(format, sparse): @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeNoiseLevels(format, sparse): - sortres = get_sorting_result(format=format, sparse=sparse) + sorting_result = get_sorting_result(format=format, sparse=sparse) - sortres.compute("noise_levels", return_scaled=True) - print(sortres) + sorting_result.compute("noise_levels", return_scaled=True) + print(sorting_result) - noise_levels = sortres.get_extension("noise_levels").data["noise_levels"] - assert noise_levels.shape[0] == sortres.channel_ids.size + noise_levels = sorting_result.get_extension("noise_levels").data["noise_levels"] + assert noise_levels.shape[0] == sorting_result.channel_ids.size if __name__ == '__main__': @@ -171,11 +180,11 @@ def test_ComputeNoiseLevels(format, sparse): # test_ComputeWaveforms(format="zarr", sparse=True) # test_ComputeWaveforms(format="zarr", sparse=False) - # test_ComputeTemplates(format="memory", sparse=True) + test_ComputeTemplates(format="memory", sparse=True) # test_ComputeTemplates(format="memory", sparse=False) # test_ComputeTemplates(format="binary_folder", sparse=True) # test_ComputeTemplates(format="zarr", sparse=True) # test_ComputeFastTemplates(format="memory", sparse=True) - test_ComputeNoiseLevels(format="memory", sparse=False) + # test_ComputeNoiseLevels(format="memory", sparse=False) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index b93466cc4e..827b764d28 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -87,14 +87,14 @@ def setUpClass(cls): extensions_to_compute = dict( waveforms=dict(), templates=dict(), - noise_levels=dict(), - spike_amplitudes=dict(), - unit_locations=dict(), - spike_locations=dict(), - quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), - template_metrics=dict(), - correlograms=dict(), - template_similarity=dict(), + # noise_levels=dict(), + # spike_amplitudes=dict(), + # unit_locations=dict(), + # spike_locations=dict(), + # quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), + # template_metrics=dict(), + # correlograms=dict(), + # template_similarity=dict(), ) job_kwargs = dict(n_jobs=-1) @@ -117,7 +117,9 @@ def setUpClass(cls): cls.sorting_result_sparse.compute(extensions_to_compute, **job_kwargs) - cls.skip_backends = ["ipywidgets", "ephyviewer"] + # cls.skip_backends = ["ipywidgets", "ephyviewer"] + # TODO : delete this after debug + cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] if ON_GITHUB and not KACHERY_CLOUD_SET: cls.skip_backends.append("sortingview") @@ -126,11 +128,11 @@ def setUpClass(cls): cls.backend_kwargs = {"matplotlib": {}, "sortingview": {}, "ipywidgets": {"display": False}} - cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) + # cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) - from spikeinterface.sortingcomponents.peak_detection import detect_peaks + # from spikeinterface.sortingcomponents.peak_detection import detect_peaks - cls.peaks = detect_peaks(cls.recording, method="locally_exclusive") + # cls.peaks = detect_peaks(cls.recording, method="locally_exclusive", **job_kwargs) @classmethod def tearDownClass(cls): @@ -582,11 +584,10 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() - mytest.test_plot_all_amplitudes_distributions() + # mytest.test_plot_all_amplitudes_distributions() # mytest.test_plot_traces() # mytest.test_plot_unit_waveforms() - # mytest.test_plot_unit_templates() - # mytest.test_plot_unit_waveforms() + mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() @@ -603,7 +604,8 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_probe_map() # mytest.test_plot_unit_presence() # mytest.test_plot_multicomparison() + plt.show() TestWidgets.tearDownClass() - plt.show() + diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 7ca585f9d3..0827b00aed 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -16,8 +16,8 @@ class UnitDepthsWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The input waveform extractor + sorting_result : SortingResult + The SortingResult object unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values depth_axis : int, default: 1 @@ -27,25 +27,24 @@ class UnitDepthsWidget(BaseWidget): """ def __init__( - self, waveform_extractor, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs + self, sorting_result, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs ): - we = waveform_extractor - unit_ids = we.sorting.unit_ids + unit_ids = sorting_result.sorting.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(we.sorting) + unit_colors = get_unit_colors(sorting_result.sorting) colors = [unit_colors[unit_id] for unit_id in unit_ids] - self.check_extensions(waveform_extractor, "unit_locations") - ulc = waveform_extractor.load_extension("unit_locations") + self.check_extensions(sorting_result, "unit_locations") + ulc = sorting_result.get_extension("unit_locations") unit_locations = ulc.get_data(outputs="numpy") unit_depths = unit_locations[:, depth_axis] - unit_amplitudes = get_template_extremum_amplitude(we, peak_sign=peak_sign) + unit_amplitudes = get_template_extremum_amplitude(sorting_result, peak_sign=peak_sign) unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids]) - num_spikes = we.sorting.count_num_spikes_per_unit(outputs="array") + num_spikes = sorting_result.sorting.count_num_spikes_per_unit(outputs="array") plot_data = dict( unit_depths=unit_depths, diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 83e9f583f1..92edd677a6 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -5,10 +5,9 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core import ChannelSparsity -from ..core.waveform_extractor import WaveformExtractor +from ..core import ChannelSparsity, SortingResult from ..core.basesorting import BaseSorting - +from ..core.template_tools import _get_dense_templates_array class UnitWaveformsWidget(BaseWidget): """ @@ -16,8 +15,8 @@ class UnitWaveformsWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The input waveform extractor + sorting_result : SortingResult + The SortingResult channel_ids: list or None, default: None The channel ids to display unit_ids : list or None, default: None @@ -26,7 +25,7 @@ class UnitWaveformsWidget(BaseWidget): If True, templates are plotted over the waveforms sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If WaveformExtractor is already sparse, the argument is ignored + If SortingResult is already sparse, the argument is ignored set_title : bool, default: True Create a plot title with the unit number if True plot_channels : bool, default: False @@ -77,7 +76,7 @@ class UnitWaveformsWidget(BaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, channel_ids=None, unit_ids=None, plot_waveforms=True, @@ -104,26 +103,26 @@ def __init__( backend=None, **backend_kwargs, ): - we = waveform_extractor - sorting: BaseSorting = we.sorting + + sorting: BaseSorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.unit_ids if channel_ids is None: - channel_ids = we.channel_ids + channel_ids = sorting_result.channel_ids if unit_colors is None: unit_colors = get_unit_colors(sorting) - channel_locations = we.get_channel_locations()[we.channel_ids_to_indices(channel_ids)] + channel_locations = sorting_result.get_channel_locations()[sorting_result.channel_ids_to_indices(channel_ids)] extra_sparsity = False - if waveform_extractor.is_sparse(): + if sorting_result.is_sparse(): if sparsity is None: - sparsity = waveform_extractor.sparsity + sparsity = sorting_result.sparsity else: # assert provided sparsity is a subset of waveform sparsity - combined_mask = np.logical_or(we.sparsity.mask, sparsity.mask) - assert np.all(np.sum(combined_mask, 1) - np.sum(we.sparsity.mask, 1) == 0), ( + combined_mask = np.logical_or(sorting_result.sparsity.mask, sparsity.mask) + assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_result.sparsity.mask, 1) == 0), ( "The provided 'sparsity' needs to include only the sparse channels " "used to extract waveforms (for example, by using a smaller 'radius_um')." ) @@ -131,36 +130,47 @@ def __init__( else: if sparsity is None: # in this case, we construct a dense sparsity - unit_id_to_channel_ids = {u: we.channel_ids for u in we.unit_ids} + unit_id_to_channel_ids = {u: sorting_result.channel_ids for u in sorting_result.unit_ids} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( - unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=we.unit_ids, channel_ids=we.channel_ids + unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_result.unit_ids, channel_ids=sorting_result.channel_ids ) else: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" # get templates - templates = we.get_all_templates(unit_ids=unit_ids) - templates_shading = self._get_template_shadings(we, unit_ids, templates_percentile_shading) + ext = sorting_result.get_extension("templates") + assert ext is not None, "plot_waveforms() need extension 'templates'" + templates = ext.get_templates(unit_ids=unit_ids, operator="average") + + + + templates_shading = self._get_template_shadings(sorting_result, unit_ids, templates_percentile_shading) xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( - waveform_extractor, templates, channel_locations, x_offset_units + sorting_result, templates, channel_locations, x_offset_units ) wfs_by_ids = {} if plot_waveforms: + wf_ext = sorting_result.get_extension("waveforms") + assert wf_ext is not None, "plot_waveforms() need extension 'waveforms'" for unit_id in unit_ids: + unit_index = list(sorting.unit_ids).index(unit_id) if not extra_sparsity: - if waveform_extractor.is_sparse(): - wfs = we.get_waveforms(unit_id) + if sorting_result.is_sparse(): + # wfs = we.get_waveforms(unit_id) + wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) else: - wfs = we.get_waveforms(unit_id, sparsity=sparsity) + # wfs = we.get_waveforms(unit_id, sparsity=sparsity) + wfs = wf_ext.get_waveforms_one_unit(unit_id) + wfs = wfs[:, :, sparsity.mask[unit_index]] else: # in this case we have to slice the waveform sparsity based on the extra sparsity - unit_index = list(sorting.unit_ids).index(unit_id) # first get the sparse waveforms - wfs = we.get_waveforms(unit_id) + # wfs = we.get_waveforms(unit_id) + wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) # find additional slice to apply to sparse waveforms - (wfs_sparse_indices,) = np.nonzero(waveform_extractor.sparsity.mask[unit_index]) + (wfs_sparse_indices,) = np.nonzero(sorting_result.sparsity.mask[unit_index]) (extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index]) (extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices)) # apply extra sparsity @@ -168,8 +178,8 @@ def __init__( wfs_by_ids[unit_id] = wfs plot_data = dict( - waveform_extractor=waveform_extractor, - sampling_frequency=waveform_extractor.sampling_frequency, + sorting_result=sorting_result, + sampling_frequency=sorting_result.sampling_frequency, unit_ids=unit_ids, channel_ids=channel_ids, sparsity=sparsity, @@ -243,6 +253,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if len(wfs) > dp.max_spikes_per_unit: random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] wfs = wfs[random_idxs] + wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T @@ -333,7 +344,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - self.we = we = data_plot["waveform_extractor"] + self.we = we = data_plot["sorting_result"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -402,10 +413,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) - def _get_template_shadings(self, we, unit_ids, templates_percentile_shading): - templates = we.get_all_templates(unit_ids=unit_ids) + def _get_template_shadings(self, sorting_result, unit_ids, templates_percentile_shading): + ext = sorting_result.get_extension("templates") + templates = ext.get_templates(unit_ids=unit_ids, operator="average") + if templates_percentile_shading is None: - templates_std = we.get_all_templates(unit_ids=unit_ids, mode="std") + templates_std = ext.get_templates(unit_ids=unit_ids, operator="std") templates_shading = [templates - templates_std, templates + templates_std] else: if isinstance(templates_percentile_shading, (int, float)): @@ -419,7 +432,8 @@ def _get_template_shadings(self, we, unit_ids, templates_percentile_shading): ), "'templates_percentile_shading' should be a have an even number of elements." templates_shading = [] for percentile in templates_percentile_shading: - template_percentile = we.get_all_templates(unit_ids=unit_ids, mode="percentile", percentile=percentile) + template_percentile = ext.get_templates(unit_ids=unit_ids, operator="percentile", percentile=percentile) + templates_shading.append(template_percentile) return templates_shading @@ -496,7 +510,7 @@ def _update_plot(self, change): fig_probe.canvas.flush_events() -def get_waveforms_scales(we, templates, channel_locations, x_offset_units=False): +def get_waveforms_scales(sorting_result, templates, channel_locations, x_offset_units=False): """ Return scales and x_vector for templates plotting """ @@ -522,7 +536,10 @@ def get_waveforms_scales(we, templates, channel_locations, x_offset_units=False) y_offset = channel_locations[:, 1][None, :] - xvect = delta_x * (np.arange(we.nsamples) - we.nbefore) / we.nsamples * 0.7 + nbefore = sorting_result.get_extension("waveforms").nbefore + nsamples = templates.shape[1] + + xvect = delta_x * (np.arange(nsamples) - nbefore) / nsamples * 0.7 if x_offset_units: ch_locs = channel_locations From 7e0ca59f6fc03de525c9975a15c96db268cd69dd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 9 Feb 2024 10:12:49 +0100 Subject: [PATCH 044/192] WIP widgets --- src/spikeinterface/core/sorting_tools.py | 1 - src/spikeinterface/core/sortingresult.py | 3 + .../postprocessing/correlograms.py | 25 ++++++-- .../postprocessing/unit_localization.py | 9 ++- src/spikeinterface/widgets/amplitudes.py | 21 +++---- .../widgets/crosscorrelograms.py | 24 ++++---- src/spikeinterface/widgets/quality_metrics.py | 15 +++-- .../widgets/template_metrics.py | 15 +++-- .../widgets/tests/test_widgets.py | 59 ++++++++++--------- src/spikeinterface/widgets/unit_locations.py | 38 ++++++------ src/spikeinterface/widgets/unit_probe_map.py | 27 +++++---- .../widgets/unit_waveforms_density_map.py | 40 +++++++------ 12 files changed, 156 insertions(+), 121 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index df0af26fc0..bd2fd21bf8 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -57,7 +57,6 @@ def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units): spike_trains.append(sample_indices[unit_indices == u]) return spike_trains - def get_numba_vector_to_list_of_spiketrain(): if hasattr(get_numba_vector_to_list_of_spiketrain, "_cached_numba_function"): return get_numba_vector_to_list_of_spiketrain._cached_numba_function diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index d0085925ae..8b4d0022ee 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -1184,6 +1184,9 @@ def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): # backward compatibility with WaveformsExtractor sorting_result = sorting_result.sorting_result + if not isinstance(sorting_result, SortingResult): + raise ValueError(f"compute_{self.extension_name}() need a SortingResult instance") + if load_if_exists is not None: # backward compatibility with "load_if_exists" warnings.warn(f"compute_{cls.extension_name}(..., load_if_exists=True/False) is kept for backward compatibility but should not be used anymore") diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 171e23dfa3..f81556a883 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortingresult import register_result_extension, ResultExtension, SortingResult try: import numba @@ -68,7 +68,7 @@ def _select_extension_data(self, unit_ids): return new_data def _run(self): - ccgs, bins = _compute_correlograms(self.sorting_result.sorting, **self.params) + ccgs, bins = compute_correlograms_on_sorting(self.sorting_result.sorting, **self.params) self.data["ccgs"] = ccgs self.data["bins"] = bins @@ -77,7 +77,22 @@ def _get_data(self): register_result_extension(ComputeCorrelograms) -compute_correlograms = ComputeCorrelograms.function_factory() +compute_correlograms_sorting_result = ComputeCorrelograms.function_factory() + +def compute_correlograms( + sorting_result_or_sorting, + window_ms: float = 50.0, + bin_ms: float = 1.0, + method: str = "auto", +): + if isinstance(sorting_result_or_sorting, SortingResult): + return compute_correlograms_sorting_result(sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + else: + return compute_correlograms_on_sorting(sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + +compute_correlograms.__doc__ = compute_correlograms_sorting_result.__doc__ + + def _make_bins(sorting, window_ms, bin_ms): @@ -189,10 +204,10 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ # ccgs, bins = ccc.get_data() # return ccgs, bins # else: -# return _compute_correlograms(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) +# return compute_correlograms_on_sorting(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) -def _compute_correlograms(sorting, window_ms, bin_ms, method="auto"): +def compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): """ Computes several cross-correlogram in one course from several clusters. """ diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index e7aeed1269..38a714610e 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -81,8 +81,13 @@ def _run(self): self.data["unit_locations"] = unit_location def get_data(self, outputs="numpy"): - return self.data["unit_locations"] - + if outputs == "numpy": + return self.data["unit_locations"] + elif outputs == "by_unit": + locations_by_unit = {} + for unit_ind, unit_id in enumerate(self.sorting_result.unit_ids): + locations_by_unit[unit_id] = self.data["unit_locations"][unit_ind] + return locations_by_unit register_result_extension(ComputeUnitLocations) compute_unit_locations = ComputeUnitLocations.function_factory() diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 316af1472e..2bed83330a 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_some_colors -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult class AmplitudesWidget(BaseWidget): @@ -15,7 +15,7 @@ class AmplitudesWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor + sorting_result : SortingResult The input waveform extractor unit_ids : list or None, default: None List of unit ids @@ -38,7 +38,7 @@ class AmplitudesWidget(BaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, unit_ids=None, unit_colors=None, segment_index=None, @@ -50,10 +50,11 @@ def __init__( backend=None, **backend_kwargs, ): - sorting = waveform_extractor.sorting - self.check_extensions(waveform_extractor, "spike_amplitudes") - sac = waveform_extractor.load_extension("spike_amplitudes") - amplitudes = sac.get_data(outputs="by_unit") + sorting = sorting_result.sorting + self.check_extensions(sorting_result, "spike_amplitudes") + + # TODO + amplitudes = sorting_result.load_extension("spike_amplitudes").get_data(outputs="by_unit") if unit_ids is None: unit_ids = sorting.unit_ids @@ -68,7 +69,7 @@ def __init__( else: segment_index = 0 amplitudes_segment = amplitudes[segment_index] - total_duration = waveform_extractor.get_num_samples(segment_index) / waveform_extractor.sampling_frequency + total_duration = sorting_result.get_num_samples(segment_index) / sorting_result.sampling_frequency spiketrains_segment = {} for i, unit_id in enumerate(sorting.unit_ids): @@ -98,7 +99,7 @@ def __init__( bins = 100 plot_data = dict( - waveform_extractor=waveform_extractor, + sorting_result=sorting_result, amplitudes=amplitudes_to_plot, unit_ids=unit_ids, unit_colors=unit_colors, @@ -186,7 +187,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - we = data_plot["waveform_extractor"] + we = data_plot["sorting_result"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index dfc06180ee..445d7b5a6b 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -4,7 +4,7 @@ from typing import Union from .base import BaseWidget, to_attr -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult from ..core.basesorting import BaseSorting from ..postprocessing import compute_correlograms @@ -15,7 +15,7 @@ class CrossCorrelogramsWidget(BaseWidget): Parameters ---------- - waveform_or_sorting_extractor : WaveformExtractor or BaseSorting + sorting_result_or_sorting : SortingResult or BaseSorting The object to compute/get crosscorrelograms from unit_ids list or None, default: None List of unit ids @@ -23,10 +23,10 @@ class CrossCorrelogramsWidget(BaseWidget): For sortingview backend. Threshold for computing pair-wise cross-correlograms. If template similarity between two units is below this threshold, the cross-correlogram is not displayed window_ms : float, default: 100.0 - Window for CCGs in ms. If correlograms are already computed (e.g. with WaveformExtractor), + Window for CCGs in ms. If correlograms are already computed (e.g. with SortingResult), this argument is ignored bin_ms : float, default: 1.0 - Bin size in ms. If correlograms are already computed (e.g. with WaveformExtractor), + Bin size in ms. If correlograms are already computed (e.g. with SortingResult), this argument is ignored hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed @@ -36,7 +36,7 @@ class CrossCorrelogramsWidget(BaseWidget): def __init__( self, - waveform_or_sorting_extractor: Union[WaveformExtractor, BaseSorting], + sorting_result_or_sorting: Union[SortingResult, BaseSorting], unit_ids=None, min_similarity_for_correlograms=0.2, window_ms=100.0, @@ -49,16 +49,16 @@ def __init__( if min_similarity_for_correlograms is None: min_similarity_for_correlograms = 0 similarity = None - if isinstance(waveform_or_sorting_extractor, WaveformExtractor): - sorting = waveform_or_sorting_extractor.sorting - self.check_extensions(waveform_or_sorting_extractor, "correlograms") - ccc = waveform_or_sorting_extractor.load_extension("correlograms") + if isinstance(sorting_result_or_sorting, SortingResult): + sorting = sorting_result_or_sorting.sorting + self.check_extensions(sorting_result_or_sorting, "correlograms") + ccc = sorting_result_or_sorting.get_extension("correlograms") ccgs, bins = ccc.get_data() if min_similarity_for_correlograms > 0: - self.check_extensions(waveform_or_sorting_extractor, "similarity") - similarity = waveform_or_sorting_extractor.load_extension("similarity").get_data() + self.check_extensions(sorting_result_or_sorting, "template_similarity") + similarity = sorting_result_or_sorting.get_extension("template_similarity").get_data() else: - sorting = waveform_or_sorting_extractor + sorting = sorting_result_or_sorting ccgs, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms) if unit_ids is None: diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 95446b36c1..bf63b0d494 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -1,7 +1,7 @@ from __future__ import annotations from .metrics import MetricsBaseWidget -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult class QualityMetricsWidget(MetricsBaseWidget): @@ -10,8 +10,8 @@ class QualityMetricsWidget(MetricsBaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The object to compute/get quality metrics from + sorting_result : SortingResult + The object to get quality metrics from unit_ids: list or None, default: None List of unit ids include_metrics: list or None, default: None @@ -26,7 +26,7 @@ class QualityMetricsWidget(MetricsBaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, unit_ids=None, include_metrics=None, skip_metrics=None, @@ -35,11 +35,10 @@ def __init__( backend=None, **backend_kwargs, ): - self.check_extensions(waveform_extractor, "quality_metrics") - qlc = waveform_extractor.load_extension("quality_metrics") - quality_metrics = qlc.get_data() + self.check_extensions(sorting_result, "quality_metrics") + quality_metrics = sorting_result.get_extension("quality_metrics").get_data() - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting MetricsBaseWidget.__init__( self, diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index 4789176ced..9aaf071e3d 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -1,7 +1,7 @@ from __future__ import annotations from .metrics import MetricsBaseWidget -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult class TemplateMetricsWidget(MetricsBaseWidget): @@ -10,8 +10,8 @@ class TemplateMetricsWidget(MetricsBaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The object to compute/get template metrics from + sorting_result : SortingResult + The object to get quality metrics from unit_ids : list or None, default: None List of unit ids include_metrics : list or None, default: None @@ -26,7 +26,7 @@ class TemplateMetricsWidget(MetricsBaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, unit_ids=None, include_metrics=None, skip_metrics=None, @@ -35,11 +35,10 @@ def __init__( backend=None, **backend_kwargs, ): - self.check_extensions(waveform_extractor, "template_metrics") - tmc = waveform_extractor.load_extension("template_metrics") - template_metrics = tmc.get_data() + self.check_extensions(sorting_result, "template_metrics") + template_metrics= sorting_result.get_extension("template_metrics").get_data() - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting MetricsBaseWidget.__init__( self, diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 827b764d28..724539c869 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -87,14 +87,14 @@ def setUpClass(cls): extensions_to_compute = dict( waveforms=dict(), templates=dict(), - # noise_levels=dict(), - # spike_amplitudes=dict(), - # unit_locations=dict(), - # spike_locations=dict(), - # quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), - # template_metrics=dict(), - # correlograms=dict(), - # template_similarity=dict(), + noise_levels=dict(), + spike_amplitudes=dict(), + unit_locations=dict(), + spike_locations=dict(), + quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), + template_metrics=dict(), + correlograms=dict(), + template_similarity=dict(), ) job_kwargs = dict(n_jobs=-1) @@ -128,16 +128,15 @@ def setUpClass(cls): cls.backend_kwargs = {"matplotlib": {}, "sortingview": {}, "ipywidgets": {"display": False}} - # cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) + cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) - # from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.sortingcomponents.peak_detection import detect_peaks - # cls.peaks = detect_peaks(cls.recording, method="locally_exclusive", **job_kwargs) + cls.peaks = detect_peaks(cls.recording, method="locally_exclusive", **job_kwargs) - @classmethod - def tearDownClass(cls): - del cls.recording, cls.sorting, cls.peaks, cls.gt_comp, cls.sorting_result_sparse, cls.sorting_result_dense - # cls._delete_widget_folders() + # @classmethod + # def tearDownClass(cls): + # del cls.recording, cls.sorting, cls.peaks, cls.gt_comp, cls.sorting_result_sparse, cls.sorting_result_dense def test_plot_traces(self): possible_backends = list(sw.TracesWidget.get_possible_backends()) @@ -322,6 +321,9 @@ def test_plot_unit_waveforms_density_map(self): sw.plot_unit_waveforms_density_map( self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) + sw.plot_unit_waveforms_density_map( + self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, same_axis=True, **self.backend_kwargs[backend] + ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) @@ -365,7 +367,7 @@ def test_plot_autocorrelograms(self): **self.backend_kwargs[backend], ) - def test_plot_crosscorrelogram(self): + def test_plot_crosscorrelograms(self): possible_backends = list(sw.CrossCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -584,28 +586,29 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() - # mytest.test_plot_all_amplitudes_distributions() + # mytest.test_plot_all_amplitudes_distributions() ## TODO vector amplitudes # mytest.test_plot_traces() # mytest.test_plot_unit_waveforms() - mytest.test_plot_unit_templates() + # mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() - # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() - # mytest.test_crosscorrelogram() - # mytest.test_isi_distribution() - # mytest.test_unit_locations() - # mytest.test_quality_metrics() - # mytest.test_template_metrics() - # mytest.test_amplitudes() + # mytest.test_plot_autocorrelograms() + # mytest.test_plot_crosscorrelograms() + # mytest.test_plot_isi_distribution() + # mytest.test_plot_unit_locations() + # mytest.test_plot_quality_metrics() + # mytest.test_plot_template_metrics() + # mytest.test_plot_amplitudes() ## TODO vector amplitudes # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() - # mytest.test_plot_rasters() + # mytest.test_plot_rasters() # mytest.test_plot_unit_probe_map() # mytest.test_plot_unit_presence() - # mytest.test_plot_multicomparison() + # mytest.test_plot_peak_activity() + mytest.test_plot_multicomparison() plt.show() - TestWidgets.tearDownClass() + # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 3fe2688ce1..f91f7291aa 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -7,7 +7,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult class UnitLocationsWidget(BaseWidget): @@ -16,8 +16,8 @@ class UnitLocationsWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The object to compute/get unit locations from + sorting_result : SortingResult + The SortingResult that must contains "unit_locations" extension unit_ids : list or None, default: None List of unit ids with_channel_ids : bool, default: False @@ -37,7 +37,7 @@ class UnitLocationsWidget(BaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, unit_ids=None, with_channel_ids=False, unit_colors=None, @@ -48,15 +48,15 @@ def __init__( backend=None, **backend_kwargs, ): - self.check_extensions(waveform_extractor, "unit_locations") - ulc = waveform_extractor.load_extension("unit_locations") + self.check_extensions(sorting_result, "unit_locations") + ulc = sorting_result.get_extension("unit_locations") unit_locations = ulc.get_data(outputs="by_unit") - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting - channel_ids = waveform_extractor.channel_ids - channel_locations = waveform_extractor.get_channel_locations() - probegroup = waveform_extractor.get_probegroup() + channel_ids = sorting_result.channel_ids + channel_locations = sorting_result.get_channel_locations() + probegroup = sorting_result.get_probegroup() if unit_colors is None: unit_colors = get_unit_colors(sorting) @@ -127,11 +127,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.plot_all_units: unit_colors = {} unit_ids = dp.all_unit_ids - for unit in dp.all_unit_ids: - if unit not in dp.unit_ids: - unit_colors[unit] = "gray" + for unit_id in dp.all_unit_ids: + if unit_id not in dp.unit_ids: + unit_colors[unit_id] = "gray" else: - unit_colors[unit] = dp.unit_colors[unit] + unit_colors[unit_id] = dp.unit_colors[unit_id] else: unit_ids = dp.unit_ids unit_colors = dp.unit_colors @@ -139,13 +139,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): patches = [ Ellipse( - (unit_locations[unit]), - color=unit_colors[unit], - zorder=5 if unit in dp.unit_ids else 3, - alpha=0.9 if unit in dp.unit_ids else 0.5, + (unit_locations[unit_id]), + color=unit_colors[unit_id], + zorder=5 if unit_id in dp.unit_ids else 3, + alpha=0.9 if unit_id in dp.unit_ids else 0.5, **ellipse_kwargs, ) - for i, unit in enumerate(unit_ids) + for unit_ind, unit_id in enumerate(unit_ids) ] for p in patches: self.ax.add_patch(p) diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index e1439e7356..fe6e9f3c03 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -8,8 +8,8 @@ from .base import BaseWidget, to_attr # from .utils import get_unit_colors -from ..core.waveform_extractor import WaveformExtractor - +from ..core.sortingresult import SortingResult +from ..core.template_tools import _get_dense_templates_array class UnitProbeMapWidget(BaseWidget): """ @@ -19,7 +19,7 @@ class UnitProbeMapWidget(BaseWidget): Parameters ---------- - waveform_extractor: WaveformExtractor + sorting_result: SortingResult unit_ids: list List of unit ids. channel_ids: list @@ -32,7 +32,7 @@ class UnitProbeMapWidget(BaseWidget): def __init__( self, - waveform_extractor, + sorting_result, unit_ids=None, channel_ids=None, animated=None, @@ -42,14 +42,14 @@ def __init__( **backend_kwargs, ): if unit_ids is None: - unit_ids = waveform_extractor.sorting.unit_ids + unit_ids = sorting_result.unit_ids self.unit_ids = unit_ids if channel_ids is None: - channel_ids = waveform_extractor.recording.channel_ids + channel_ids = sorting_result.channel_ids self.channel_ids = channel_ids data_plot = dict( - waveform_extractor=waveform_extractor, + sorting_result=sorting_result, unit_ids=unit_ids, channel_ids=channel_ids, animated=animated, @@ -73,15 +73,19 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - we = dp.waveform_extractor - probe = we.get_probe() + sorting_result = dp.sorting_result + probe = sorting_result.get_probe() probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) + templates = _get_dense_templates_array(sorting_result, return_scaled=True) + templates = templates[sorting_result.sorting.ids_to_indices(dp.unit_ids), :, :] + all_poly_contact = [] for i, unit_id in enumerate(dp.unit_ids): ax = self.axes.flatten()[i] - template = we.get_template(unit_id) + # template = we.get_template(unit_id) + template = templates[i, :, :] # static if dp.animated: contacts_values = np.zeros(template.shape[1]) @@ -116,7 +120,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def animate_func(frame): for i, unit_id in enumerate(self.unit_ids): - template = we.get_template(unit_id) + # template = we.get_template(unit_id) + template = templates[i, :, :] contacts_values = np.abs(template[frame, :]) poly_contact = all_poly_contact[i] poly_contact.set_array(contacts_values) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 631600c919..d990d2055a 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -14,15 +14,15 @@ class UnitWaveformDensityMapWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The waveformextractor for calculating waveforms + sorting_result : SortingResult + The SortingResult for calculating waveforms channel_ids : list or None, default: None The channel ids to display unit_ids : list or None, default: None List of unit ids sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If WaveformExtractor is already sparse, the argument is ignored + If SortingResult is already sparse, the argument is ignored use_max_channel : bool, default: False Use only the max channel peak_sign : "neg" | "pos" | "both", default: "neg" @@ -37,7 +37,7 @@ class UnitWaveformDensityMapWidget(BaseWidget): def __init__( self, - waveform_extractor, + sorting_result, channel_ids=None, unit_ids=None, sparsity=None, @@ -48,36 +48,36 @@ def __init__( backend=None, **backend_kwargs, ): - we = waveform_extractor if channel_ids is None: - channel_ids = we.channel_ids + channel_ids = sorting_result.channel_ids if unit_ids is None: - unit_ids = we.unit_ids + unit_ids = sorting_result.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(we.sorting) + unit_colors = get_unit_colors(sorting_result.sorting) if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = get_template_extremum_channel(we, mode="extremum", peak_sign=peak_sign, outputs="index") + max_channels = get_template_extremum_channel(sorting_result, mode="extremum", peak_sign=peak_sign, outputs="index") # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all - if waveform_extractor.is_sparse(): - assert sparsity is None, "UnitWaveformDensity WaveformExtractor is already sparse" - used_sparsity = waveform_extractor.sparsity + if sorting_result.is_sparse(): + assert sparsity is None, "UnitWaveformDensity SortingResult is already sparse" + used_sparsity = sorting_result.sparsity elif sparsity is not None: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" used_sparsity = sparsity else: # in this case, we construct a dense sparsity - used_sparsity = ChannelSparsity.create_dense(we) + used_sparsity = ChannelSparsity.create_dense(sorting_result) channel_inds = used_sparsity.unit_id_to_channel_indices # bins - templates = we.get_all_templates(unit_ids=unit_ids) + # templates = we.get_all_templates(unit_ids=unit_ids) + templates = sorting_result.get_extension("templates").get_templates(unit_ids=unit_ids) bin_min = np.min(templates) * 1.3 bin_max = np.max(templates) * 1.3 bin_size = (bin_max - bin_min) / 100 @@ -87,16 +87,22 @@ def __init__( if same_axis: all_hist2d = None # channel union across units - unit_inds = we.sorting.ids_to_indices(unit_ids) + unit_inds = sorting_result.sorting.ids_to_indices(unit_ids) (shared_chan_inds,) = np.nonzero(np.sum(used_sparsity.mask[unit_inds, :], axis=0)) else: all_hist2d = {} + wf_ext = sorting_result.get_extension("waveforms") for unit_index, unit_id in enumerate(unit_ids): chan_inds = channel_inds[unit_id] # this have already the sparsity - wfs = we.get_waveforms(unit_id, sparsity=sparsity) + # wfs = we.get_waveforms(unit_id, sparsity=sparsity) + + wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) + if sparsity is not None: + # external sparsity + wfs = wfs[:, sparsity.mask[:, unit_index]] if use_max_channel: chan_ind = max_channels[unit_id] @@ -145,7 +151,7 @@ def __init__( plot_data = dict( unit_ids=unit_ids, unit_colors=unit_colors, - channel_ids=we.channel_ids, + channel_ids=sorting_result.channel_ids, channel_inds=channel_inds, same_axis=same_axis, bin_min=bin_min, From cec1fd37a673683b7421b34b1756ab7e93003e6c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 9 Feb 2024 11:04:38 +0100 Subject: [PATCH 045/192] wip widgets --- src/spikeinterface/core/sorting_tools.py | 45 +++++++++++++++++++ .../core/tests/test_sorting_tools.py | 21 +++++++-- .../postprocessing/spike_amplitudes.py | 20 ++++++++- .../widgets/all_amplitudes_distributions.py | 19 ++++---- src/spikeinterface/widgets/amplitudes.py | 3 +- .../widgets/tests/test_widgets.py | 22 ++++----- 6 files changed, 103 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index bd2fd21bf8..bccf313c87 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -46,6 +46,51 @@ def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.arra return spike_trains +def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array): + """ + Similar to spike_vector_to_spike_trains but instead having the spike_trains (aka spike times) return + spike indices by segment and units. + + This is usefull to split back other unique vector like "spike_amplitudes", "spike_locations" into dict of dict + Internally calls numba if numba is installed. + + Parameters + ---------- + spike_vector: list[np.ndarray] + List of spike vectors optained with sorting.to_spike_vector(concatenated=False) + unit_ids: np.array + Unit ids + Returns + ------- + spike_indices: dict[dict]: + A dict containing, for each segment, the spike indices of all units + (as a dict: unit_id --> index). + """ + try: + import numba + + HAVE_NUMBA = True + except: + HAVE_NUMBA = False + + if HAVE_NUMBA: + # the trick here is to have a function getter + vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain() + else: + vector_to_list_of_spiketrain = vector_to_list_of_spiketrain_numpy + + num_units = unit_ids.size + spike_indices = {} + for segment_index, spikes in enumerate(spike_vector): + indices = np.arange(spikes.size, dtype=np.int64) + unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False) + list_of_spike_indices = vector_to_list_of_spiketrain(indices, unit_indices, num_units) + spike_indices[segment_index] = dict(zip(unit_ids, list_of_spike_indices)) + + return spike_indices + + + def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units): """ diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 06819fa514..fe169e7448 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -5,7 +5,7 @@ from spikeinterface.core import NumpySorting from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains, random_spikes_selection +from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, spike_vector_to_indices @pytest.mark.skipif( @@ -20,6 +20,20 @@ def test_spike_vector_to_spike_trains(): for unit_index, unit_id in enumerate(sorting.unit_ids): assert np.array_equal(spike_trains[0][unit_id], sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0)) +def test_spike_vector_to_indices(): + sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000) + spike_vector = sorting.to_spike_vector(concatenated=False) + spike_indices = spike_vector_to_indices(spike_vector, sorting.unit_ids) + + segment_index = 0 + assert len(spike_indices[segment_index]) == sorting.get_num_units() + for unit_index, unit_id in enumerate(sorting.unit_ids): + inds = spike_indices[segment_index][unit_id] + assert np.array_equal( + spike_vector[segment_index][inds]["sample_index"], + sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + ) + def test_random_spikes_selection(): recording, sorting = generate_ground_truth_recording( @@ -52,5 +66,6 @@ def test_random_spikes_selection(): if __name__ == "__main__": - test_spike_vector_to_spike_trains() - test_random_spikes_selection() + # test_spike_vector_to_spike_trains() + test_spike_vector_to_indices() + # test_random_spikes_selection() diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index b9000ddc06..ac1c2079d5 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -9,6 +9,7 @@ from spikeinterface.core.sortingresult import register_result_extension, ResultExtension from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type +from spikeinterface.core.sorting_tools import spike_vector_to_indices class ComputeSpikeAmplitudes(ResultExtension): """ @@ -118,8 +119,23 @@ def _run(self, **job_kwargs): ) self.data["amplitudes"] = amps - def _get_data(self): - return self.data["amplitudes"] + def _get_data(self, outputs="numpy"): + all_amplitudes = self.data["amplitudes"] + if outputs == "numpy": + return all_amplitudes + elif outputs == "by_unit": + unit_ids = self.sorting_result.unit_ids + spike_vector = self.sorting_result.sorting.to_spike_vector(concatenated=False) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids) + amplitudes_by_units = {} + for segment_index in range(self.sorting_result.sorting.get_num_segments()): + amplitudes_by_units[segment_index] = {} + for unit_id in unit_ids: + inds = spike_indices[segment_index][unit_id] + amplitudes_by_units[segment_index][unit_id] = all_amplitudes[inds] + return amplitudes_by_units + else: + raise ValueError(f"Wrong .get_data(outputs={outputs})") register_result_extension(ComputeSpikeAmplitudes) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index dcaa8653fd..595ade591b 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -39,11 +39,19 @@ def __init__( if unit_colors is None: unit_colors = get_some_colors(sorting_result.unit_ids) + amplitudes_by_units = {} + spikes = sorting_result.sorting.to_spike_vector() + for unit_id in unit_ids: + unit_index = sorting_result.sorting.id_to_index(unit_id) + spike_mask = spikes["unit_index"] == unit_index + amplitudes_by_units[unit_id] = amplitudes[spike_mask] + + plot_data = dict( unit_ids=unit_ids, unit_colors=unit_colors, num_segments=num_segments, - amplitudes=amplitudes, + amplitudes_by_units=amplitudes_by_units, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -58,14 +66,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = self.ax - unit_amps = [] - for i, unit_id in enumerate(dp.unit_ids): - amps = [] - for segment_index in range(dp.num_segments): - amps.append(dp.amplitudes[segment_index][unit_id]) - amps = np.concatenate(amps) - unit_amps.append(amps) - parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) + parts = ax.violinplot(list(dp.amplitudes_by_units.values()), showmeans=False, showmedians=False, showextrema=False) for i, pc in enumerate(parts["bodies"]): color = dp.unit_colors[dp.unit_ids[i]] diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 2bed83330a..67432d7b66 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -53,8 +53,7 @@ def __init__( sorting = sorting_result.sorting self.check_extensions(sorting_result, "spike_amplitudes") - # TODO - amplitudes = sorting_result.load_extension("spike_amplitudes").get_data(outputs="by_unit") + amplitudes = sorting_result.get_extension("spike_amplitudes").get_data(outputs="by_unit") if unit_ids is None: unit_ids = sorting.unit_ids diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 724539c869..b8c9aae3dc 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -89,12 +89,12 @@ def setUpClass(cls): templates=dict(), noise_levels=dict(), spike_amplitudes=dict(), - unit_locations=dict(), - spike_locations=dict(), - quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), - template_metrics=dict(), - correlograms=dict(), - template_similarity=dict(), + # unit_locations=dict(), + # spike_locations=dict(), + # quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), + # template_metrics=dict(), + # correlograms=dict(), + # template_similarity=dict(), ) job_kwargs = dict(n_jobs=-1) @@ -130,9 +130,9 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) - from spikeinterface.sortingcomponents.peak_detection import detect_peaks + # from spikeinterface.sortingcomponents.peak_detection import detect_peaks - cls.peaks = detect_peaks(cls.recording, method="locally_exclusive", **job_kwargs) + # cls.peaks = detect_peaks(cls.recording, method="locally_exclusive", **job_kwargs) # @classmethod # def tearDownClass(cls): @@ -586,7 +586,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() - # mytest.test_plot_all_amplitudes_distributions() ## TODO vector amplitudes + # mytest.test_plot_all_amplitudes_distributions() # mytest.test_plot_traces() # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() @@ -598,7 +598,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_locations() # mytest.test_plot_quality_metrics() # mytest.test_plot_template_metrics() - # mytest.test_plot_amplitudes() ## TODO vector amplitudes + mytest.test_plot_amplitudes() # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() @@ -606,7 +606,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_probe_map() # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() - mytest.test_plot_multicomparison() + # mytest.test_plot_multicomparison() plt.show() # TestWidgets.tearDownClass() From c465f56af05775a1e9f19c159dfa6a64e1679a08 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 9 Feb 2024 11:43:22 +0100 Subject: [PATCH 046/192] after 0.100 release --- pyproject.toml | 14 +++++++------- src/spikeinterface/__init__.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3384a5482..804c89178e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.100.0" +version = "0.101.0" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -119,8 +119,8 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test = [ @@ -152,8 +152,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -170,8 +170,8 @@ docs = [ "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions # for release we need pypi, so this needs to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 97fb95b623..306c12d516 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -30,5 +30,5 @@ # This flag must be set to False for release # This avoids using versioning that contains ".dev0" (and this is a better choice) # This is mainly useful when using run_sorter in a container and spikeinterface install -# DEV_MODE = True -DEV_MODE = False +DEV_MODE = True +# DEV_MODE = False From 2809c9e7cd55ea94f0daaebb7993ff05fc7dbc35 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 9 Feb 2024 17:54:21 +0100 Subject: [PATCH 047/192] Widgets are passing tests with SortingResult --- .../postprocessing/spike_locations.py | 20 ++++- src/spikeinterface/widgets/peak_activity.py | 5 +- src/spikeinterface/widgets/sorting_summary.py | 31 ++++--- src/spikeinterface/widgets/spike_locations.py | 23 +++-- .../widgets/spikes_on_traces.py | 47 +++++----- .../widgets/template_similarity.py | 14 +-- .../widgets/tests/test_widgets.py | 85 ++++++++----------- src/spikeinterface/widgets/unit_summary.py | 37 ++++---- .../widgets/unit_waveforms_density_map.py | 8 +- 9 files changed, 135 insertions(+), 135 deletions(-) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 78c9e5bb39..03a1e7d52a 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -6,6 +6,7 @@ from spikeinterface.core.sortingresult import register_result_extension, ResultExtension from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline @@ -122,8 +123,23 @@ def _run(self, **job_kwargs): ) self.data["spike_locations"] = spike_locations - def _get_data(self, outputs="concatenated"): - return self.data["spike_locations"] + def _get_data(self, outputs="numpy"): + all_spike_locations = self.data["spike_locations"] + if outputs == "numpy": + return all_spike_locations + elif outputs == "by_unit": + unit_ids = self.sorting_result.unit_ids + spike_vector = self.sorting_result.sorting.to_spike_vector(concatenated=False) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids) + spike_locations_by_units = {} + for segment_index in range(self.sorting_result.sorting.get_num_segments()): + spike_locations_by_units[segment_index] = {} + for unit_id in unit_ids: + inds = spike_indices[segment_index][unit_id] + spike_locations_by_units[segment_index][unit_id] = all_spike_locations[inds] + return spike_locations_by_units + else: + raise ValueError(f"Wrong .get_data(outputs={outputs})") ComputeSpikeLocations.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/widgets/peak_activity.py b/src/spikeinterface/widgets/peak_activity.py index 62f0c2d6d1..6375d0ff29 100644 --- a/src/spikeinterface/widgets/peak_activity.py +++ b/src/spikeinterface/widgets/peak_activity.py @@ -1,13 +1,10 @@ from __future__ import annotations import numpy as np -from typing import Union -from probeinterface import ProbeGroup from .base import BaseWidget, to_attr -from .utils import get_unit_colors -from ..core.waveform_extractor import WaveformExtractor + class PeakActivityMapWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 296d854222..aa952bc6ef 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -11,7 +11,7 @@ from .unit_templates import UnitTemplatesWidget -from ..core import WaveformExtractor +from ..core import SortingResult class SortingSummaryWidget(BaseWidget): @@ -20,13 +20,13 @@ class SortingSummaryWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object + sorting_result : SortingResult + The SortingResult object unit_ids : list or None, default: None List of unit ids sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If WaveformExtractor is already sparse, the argument is ignored + If SortingResult is already sparse, the argument is ignored max_amplitudes_per_unit : int or None, default: None Maximum number of spikes per unit for plotting amplitudes. If None, all spikes are plotted @@ -47,7 +47,7 @@ class SortingSummaryWidget(BaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, unit_ids=None, sparsity=None, max_amplitudes_per_unit=None, @@ -58,15 +58,14 @@ def __init__( backend=None, **backend_kwargs, ): - self.check_extensions(waveform_extractor, ["correlograms", "spike_amplitudes", "unit_locations", "similarity"]) - we = waveform_extractor - sorting = we.sorting + self.check_extensions(sorting_result, ["correlograms", "spike_amplitudes", "unit_locations", "similarity"]) + sorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.get_unit_ids() plot_data = dict( - waveform_extractor=waveform_extractor, + sorting_result=sorting_result, unit_ids=unit_ids, sparsity=sparsity, min_similarity_for_correlograms=min_similarity_for_correlograms, @@ -83,7 +82,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) - we = dp.waveform_extractor + sorting_result = dp.sorting_result unit_ids = dp.unit_ids sparsity = dp.sparsity min_similarity_for_correlograms = dp.min_similarity_for_correlograms @@ -91,7 +90,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = make_serializable(dp.unit_ids) v_spike_amplitudes = AmplitudesWidget( - we, + sorting_result, unit_ids=unit_ids, max_spikes_per_unit=dp.max_amplitudes_per_unit, hide_unit_selector=True, @@ -100,7 +99,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): backend="sortingview", ).view v_average_waveforms = UnitTemplatesWidget( - we, + sorting_result, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True, @@ -109,7 +108,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): backend="sortingview", ).view v_cross_correlograms = CrossCorrelogramsWidget( - we, + sorting_result, unit_ids=unit_ids, min_similarity_for_correlograms=min_similarity_for_correlograms, hide_unit_selector=True, @@ -119,11 +118,11 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ).view v_unit_locations = UnitLocationsWidget( - we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" + sorting_result, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" ).view w = TemplateSimilarityWidget( - we, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" + sorting_result, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" ) similarity = w.data_plot["similarity"] @@ -137,7 +136,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # unit ids v_units_table = generate_unit_table_view( - dp.waveform_extractor.sorting, dp.unit_table_properties, similarity_scores=similarity_scores + dp.sorting_result.sorting, dp.unit_table_properties, similarity_scores=similarity_scores ) if dp.curation: diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 736e6193a9..77c6537ea4 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -4,7 +4,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult class SpikeLocationsWidget(BaseWidget): @@ -13,8 +13,8 @@ class SpikeLocationsWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The object to compute/get spike locations from + sorting_result : SortingResult + The object to get spike locations from unit_ids : list or None, default: None List of unit ids segment_index : int or None, default: None @@ -40,7 +40,7 @@ class SpikeLocationsWidget(BaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, unit_ids=None, segment_index=None, max_spikes_per_unit=500, @@ -53,15 +53,14 @@ def __init__( backend=None, **backend_kwargs, ): - self.check_extensions(waveform_extractor, "spike_locations") - slc = waveform_extractor.load_extension("spike_locations") - spike_locations = slc.get_data(outputs="by_unit") + self.check_extensions(sorting_result, "spike_locations") + spike_locations_by_units = sorting_result.get_extension("spike_locations").get_data(outputs="by_unit") - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting - channel_ids = waveform_extractor.channel_ids - channel_locations = waveform_extractor.get_channel_locations() - probegroup = waveform_extractor.get_probegroup() + channel_ids = sorting_result.channel_ids + channel_locations = sorting_result.get_channel_locations() + probegroup = sorting_result.get_probegroup() if sorting.get_num_segments() > 1: assert segment_index is not None, "Specify segment index for multi-segment object" @@ -74,7 +73,7 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - all_spike_locs = spike_locations[segment_index] + all_spike_locs = spike_locations_by_units[segment_index] if max_spikes_per_unit is None: spike_locs = all_spike_locs else: diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 42fbd623cd..651170bd3d 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -7,7 +7,7 @@ from .traces import TracesWidget from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult from ..core.baserecording import BaseRecording from ..core.basesorting import BaseSorting from ..postprocessing import compute_unit_locations @@ -19,8 +19,8 @@ class SpikesOnTracesWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor + sorting_result : SortingResult + The SortingResult channel_ids : list or None, default: None The channel ids to display unit_ids : list or None, default: None @@ -31,7 +31,7 @@ class SpikesOnTracesWidget(BaseWidget): List with start time and end time in seconds sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If WaveformExtractor is already sparse, the argument is ignored + If SortingResult is already sparse, the argument is ignored unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values If None, then the get_unit_colors() is internally used. (matplotlib backend) @@ -62,7 +62,7 @@ class SpikesOnTracesWidget(BaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, segment_index=None, channel_ids=None, unit_ids=None, @@ -83,8 +83,10 @@ def __init__( backend=None, **backend_kwargs, ): - we = waveform_extractor - sorting: BaseSorting = we.sorting + + self.check_extensions(sorting_result, "unit_locations") + + sorting: BaseSorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.get_unit_ids() @@ -94,21 +96,21 @@ def __init__( unit_colors = get_unit_colors(sorting) # sparsity is done on all the units even if unit_ids is a few ones because some backend need then all - if waveform_extractor.is_sparse(): - sparsity = waveform_extractor.sparsity + if sorting_result.is_sparse(): + sparsity = sorting_result.sparsity else: if sparsity is None: # in this case, we construct a sparsity dictionary only with the best channel - extremum_channel_ids = get_template_extremum_channel(we) + extremum_channel_ids = get_template_extremum_channel(sorting_result) unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( - unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=we.unit_ids, channel_ids=we.channel_ids + unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_result.unit_ids, + channel_ids=sorting_result.channel_ids ) else: assert isinstance(sparsity, ChannelSparsity) - # get templates - unit_locations = compute_unit_locations(we, outputs="by_unit") + unit_locations = sorting_result.get_extension("unit_locations").get_data(outputs="by_unit") options = dict( segment_index=segment_index, @@ -127,7 +129,7 @@ def __init__( ) plot_data = dict( - waveform_extractor=waveform_extractor, + sorting_result=sorting_result, options=options, unit_ids=unit_ids, sparsity=sparsity, @@ -145,9 +147,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.lines import Line2D dp = to_attr(data_plot) - we = dp.waveform_extractor - recording = we.recording - sorting = we.sorting + sorting_result = dp.sorting_result + recording = sorting_result.recording + sorting = sorting_result.sorting # first plot time series traces_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) @@ -209,8 +211,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if len(spike_frames_to_plot) > 0: vspacing = traces_widget.data_plot["vspacing"] traces = traces_widget.data_plot["list_traces"][0] - - waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] + + # TODO find a better way + nbefore = 30 + nafter = 60 + waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-nbefore, nafter) - frame_range[0] waveform_idxs = np.clip(waveform_idxs, 0, len(traces_widget.data_plot["times"]) - 1) times = traces_widget.data_plot["times"][waveform_idxs] @@ -242,7 +247,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() dp = to_attr(data_plot) - we = dp.waveform_extractor + sorting_result = dp.sorting_result ratios = [0.2, 0.8] @@ -253,7 +258,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - self._traces_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self._traces_widget = TracesWidget(sorting_result.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) self.ax = self._traces_widget.ax self.axes = self._traces_widget.axes self.figure = self._traces_widget.figure diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index fe4db0cc6d..9c36b22309 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -3,7 +3,7 @@ import numpy as np from .base import BaseWidget, to_attr -from ..core.waveform_extractor import WaveformExtractor +from ..core.sortingresult import SortingResult class TemplateSimilarityWidget(BaseWidget): @@ -12,8 +12,8 @@ class TemplateSimilarityWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The object to compute/get template similarity from + sorting_result : SortingResult + The object to get template similarity from unit_ids : list or None, default: None List of unit ids default: None display_diagonal_values : bool, default: False @@ -29,7 +29,7 @@ class TemplateSimilarityWidget(BaseWidget): def __init__( self, - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, unit_ids=None, cmap="viridis", display_diagonal_values=False, @@ -38,11 +38,11 @@ def __init__( backend=None, **backend_kwargs, ): - self.check_extensions(waveform_extractor, "similarity") - tsc = waveform_extractor.load_extension("similarity") + self.check_extensions(sorting_result, "template_similarity") + tsc = sorting_result.get_extension("template_similarity") similarity = tsc.get_data().copy() - sorting = waveform_extractor.sorting + sorting = sorting_result.sorting if unit_ids is None: unit_ids = sorting.unit_ids else: diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index b8c9aae3dc..4628682ed5 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -2,7 +2,6 @@ import pytest import os from pathlib import Path -import shutil if __name__ != "__main__": import matplotlib @@ -16,22 +15,12 @@ compute_sparsity, generate_ground_truth_recording, start_sorting_result, - load_sorting_result, ) -import spikeinterface.extractors as se + import spikeinterface.widgets as sw import spikeinterface.comparison as sc from spikeinterface.preprocessing import scale -from spikeinterface.postprocessing import ( - compute_correlograms, - compute_spike_amplitudes, - compute_spike_locations, - compute_unit_locations, - compute_template_metrics, - compute_template_similarity, -) -from spikeinterface.qualitymetrics import compute_quality_metrics if hasattr(pytest, "global_test_folder"): @@ -45,20 +34,9 @@ class TestWidgets(unittest.TestCase): - # @classmethod - # def _delete_widget_folders(cls): - # for name in ( - # "recording", - # "sorting", - # "we_dense", - # "we_sparse", - # ): - # if (cache_folder / name).is_dir(): - # shutil.rmtree(cache_folder / name) @classmethod def setUpClass(cls): - # cls._delete_widget_folders() recording, sorting = generate_ground_truth_recording( durations=[30.0], @@ -89,12 +67,12 @@ def setUpClass(cls): templates=dict(), noise_levels=dict(), spike_amplitudes=dict(), - # unit_locations=dict(), - # spike_locations=dict(), - # quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), - # template_metrics=dict(), - # correlograms=dict(), - # template_similarity=dict(), + unit_locations=dict(), + spike_locations=dict(), + quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), + template_metrics=dict(), + correlograms=dict(), + template_similarity=dict(), ) job_kwargs = dict(n_jobs=-1) @@ -130,13 +108,9 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) - # from spikeinterface.sortingcomponents.peak_detection import detect_peaks - - # cls.peaks = detect_peaks(cls.recording, method="locally_exclusive", **job_kwargs) + from spikeinterface.sortingcomponents.peak_detection import detect_peaks - # @classmethod - # def tearDownClass(cls): - # del cls.recording, cls.sorting, cls.peaks, cls.gt_comp, cls.sorting_result_sparse, cls.sorting_result_dense + cls.peaks = detect_peaks(cls.recording, method="locally_exclusive", **job_kwargs) def test_plot_traces(self): possible_backends = list(sw.TracesWidget.get_possible_backends()) @@ -175,6 +149,14 @@ def test_plot_traces(self): **self.backend_kwargs[backend], ) + def test_plot_spikes_on_traces(self): + possible_backends = list(sw.SpikesOnTracesWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_spikes_on_traces(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + + + def test_plot_unit_waveforms(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: @@ -318,18 +300,18 @@ def test_plot_unit_waveforms_density_map(self): for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] + + # on dense sw.plot_unit_waveforms_density_map( - self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, + unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) + # on sparse sw.plot_unit_waveforms_density_map( - self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, same_axis=True, **self.backend_kwargs[backend] + self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) - def test_plot_unit_waveforms_density_map_sparsity_radius(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) - for backend in possible_backends: - if backend not in self.skip_backends: - unit_ids = self.sorting.unit_ids[:2] + # externals parsity sw.plot_unit_waveforms_density_map( self.sorting_result_dense, sparsity=self.sparsity_radius, @@ -339,11 +321,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): **self.backend_kwargs[backend], ) - def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) - for backend in possible_backends: - if backend not in self.skip_backends: - unit_ids = self.sorting.unit_ids[:2] + # on sparse with same_axis sw.plot_unit_waveforms_density_map( self.sorting_result_sparse, sparsity=None, @@ -576,6 +554,8 @@ def test_plot_multicomparison(self): if backend == "matplotlib": _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) + + if __name__ == "__main__": @@ -584,29 +564,32 @@ def test_plot_multicomparison(self): TestWidgets.setUpClass() mytest = TestWidgets() - # mytest.test_plot_unit_waveforms_density_map() + mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() # mytest.test_plot_all_amplitudes_distributions() # mytest.test_plot_traces() + # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() - # mytest.test_plot_unit_summary() # mytest.test_plot_autocorrelograms() # mytest.test_plot_crosscorrelograms() # mytest.test_plot_isi_distribution() # mytest.test_plot_unit_locations() + # mytest.test_plot_spike_locations() + # mytest.test_plot_similarity() # mytest.test_plot_quality_metrics() # mytest.test_plot_template_metrics() - mytest.test_plot_amplitudes() + # mytest.test_plot_amplitudes() # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() # mytest.test_plot_rasters() # mytest.test_plot_unit_probe_map() # mytest.test_plot_unit_presence() - # mytest.test_plot_peak_activity() + # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() + # mytest.test_plot_sorting_summary() plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index cef09e29a7..016738d393 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -21,22 +21,22 @@ class UnitSummaryWidget(BaseWidget): Parameters ---------- - waveform_extractor : WaveformExtractor - The waveform extractor object + sorting_result : SortingResult + The SortingResult object unit_id : int or str The unit id to plot the summary of unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values, sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. - If WaveformExtractor is already sparse, the argument is ignored + If SortingResult is already sparse, the argument is ignored """ # possible_backends = {} def __init__( self, - waveform_extractor, + sorting_result, unit_id, unit_colors=None, sparsity=None, @@ -44,13 +44,12 @@ def __init__( backend=None, **backend_kwargs, ): - we = waveform_extractor if unit_colors is None: - unit_colors = get_unit_colors(we.sorting) + unit_colors = get_unit_colors(sorting_result.sorting) plot_data = dict( - we=we, + sorting_result=sorting_result, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, @@ -65,7 +64,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) unit_id = dp.unit_id - we = dp.we + sorting_result = dp.sorting_result unit_colors = dp.unit_colors sparsity = dp.sparsity @@ -82,20 +81,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 3 - if we.has_extension("correlograms") or we.has_extension("spike_amplitudes"): + if sorting_result.has_extension("correlograms") or sorting_result.has_extension("spike_amplitudes"): ncols += 1 - if we.has_extension("spike_amplitudes"): + if sorting_result.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - if we.has_extension("unit_locations"): + if sorting_result.has_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, backend="matplotlib", ax=ax1 + sorting_result, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, backend="matplotlib", ax=ax1 ) - unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit") + unit_locations = sorting_result.get_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] ax1.set_xlim(x - 80, x + 80) @@ -106,7 +105,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax2 = fig.add_subplot(gs[:2, 1]) w = UnitWaveformsWidget( - we, + sorting_result, unit_ids=[unit_id], unit_colors=unit_colors, plot_templates=True, @@ -121,7 +120,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax3 = fig.add_subplot(gs[:2, 2]) UnitWaveformDensityMapWidget( - we, + sorting_result, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, @@ -131,10 +130,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - if we.has_extension("correlograms"): + if sorting_result.has_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) AutoCorrelogramsWidget( - we, + sorting_result, unit_ids=[unit_id], unit_colors=unit_colors, backend="matplotlib", @@ -144,12 +143,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - if we.has_extension("spike_amplitudes"): + if sorting_result.has_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6]) AmplitudesWidget( - we, + sorting_result, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index d990d2055a..b2ff92605d 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -93,7 +93,8 @@ def __init__( all_hist2d = {} wf_ext = sorting_result.get_extension("waveforms") - for unit_index, unit_id in enumerate(unit_ids): + for i, unit_id in enumerate(unit_ids): + unit_index = sorting_result.sorting.id_to_index(unit_id) chan_inds = channel_inds[unit_id] # this have already the sparsity @@ -102,7 +103,7 @@ def __init__( wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) if sparsity is not None: # external sparsity - wfs = wfs[:, sparsity.mask[:, unit_index]] + wfs = wfs[:, :, sparsity.mask[unit_index, :]] if use_max_channel: chan_ind = max_channels[unit_id] @@ -142,7 +143,8 @@ def __init__( # plot median templates_flat = {} - for unit_index, unit_id in enumerate(unit_ids): + for i, unit_id in enumerate(unit_ids): + unit_index = sorting_result.sorting.id_to_index(unit_id) chan_inds = channel_inds[unit_id] template = templates[unit_index, :, chan_inds] template_flat = template.flatten() From 9ad5a0d3f9b5c25db13e0c37affdbb7a4f4a8521 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 12 Feb 2024 13:34:33 +0100 Subject: [PATCH 048/192] wip export_report() and export_to_phy() --- src/spikeinterface/core/recording_tools.py | 1 + src/spikeinterface/core/sortingresult.py | 3 + src/spikeinterface/exporters/report.py | 47 ++++--- src/spikeinterface/exporters/tests/common.py | 35 ++--- .../exporters/tests/test_export_to_phy.py | 59 ++++---- .../exporters/tests/test_report.py | 7 +- src/spikeinterface/exporters/to_phy.py | 131 ++++++++---------- .../postprocessing/tests/test_correlograms.py | 10 +- .../widgets/unit_waveforms_density_map.py | 2 +- 9 files changed, 143 insertions(+), 152 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index ecb8564163..26f94cb84e 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -898,5 +898,6 @@ def get_rec_attributes(recording): num_samples=[recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())], is_filtered=recording.is_filtered(), properties=properties_to_attrs, + dtype=recording.get_dtype(), ) return rec_attributes diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 8b4d0022ee..f4fc6c1a71 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -735,6 +735,9 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) + + def get_dtype(self): + return self.rec_attributes["dtype"] ## extensions zone diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index c937f9cb4c..9c85a3c860 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -11,7 +11,7 @@ def export_report( - waveform_extractor, + sorting_result, output_folder, remove_if_exists=False, format="png", @@ -28,8 +28,8 @@ def export_report( Parameters ---------- - waveform_extractor: a WaveformExtractor or None - If WaveformExtractor is provide then the compute is faster otherwise + sorting_result: SortingResult + A SortingResult object output_folder: str The output folder where the report files are saved remove_if_exists: bool, default: False @@ -48,15 +48,15 @@ def export_report( import matplotlib.pyplot as plt job_kwargs = fix_job_kwargs(job_kwargs) - we = waveform_extractor - sorting = we.sorting - unit_ids = sorting.unit_ids + sorting = sorting_result.sorting + unit_ids = sorting_result.unit_ids # load or compute spike_amplitudes - if we.has_extension("spike_amplitudes"): - spike_amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit") + if sorting_result.has_extension("spike_amplitudes"): + spike_amplitudes = sorting_result.get_extension("spike_amplitudes").get_data(outputs="by_unit") elif force_computation: - spike_amplitudes = compute_spike_amplitudes(we, peak_sign=peak_sign, outputs="by_unit", **job_kwargs) + sorting_result.compute("spike_amplitudes", **job_kwargs) + spike_amplitudes = sorting_result.get_extension("spike_amplitudes").get_data(outputs="by_unit") else: spike_amplitudes = None print( @@ -64,10 +64,11 @@ def export_report( ) # load or compute quality_metrics - if we.has_extension("quality_metrics"): - metrics = we.load_extension("quality_metrics").get_data() + if sorting_result.has_extension("quality_metrics"): + metrics = sorting_result.get_extension("quality_metrics").get_data() elif force_computation: - metrics = compute_quality_metrics(we) + sorting_result.compute("quality_metrics") + metrics = sorting_result.get_extension("quality_metrics").get_data() else: metrics = None print( @@ -75,10 +76,10 @@ def export_report( ) # load or compute correlograms - if we.has_extension("correlograms"): - correlograms, bins = we.load_extension("correlograms").get_data() + if sorting_result.has_extension("correlograms"): + correlograms, bins = sorting_result.get_extension("correlograms").get_data() elif force_computation: - correlograms, bins = compute_correlograms(we, window_ms=100.0, bin_ms=1.0) + correlograms, bins = compute_correlograms(sorting_result, window_ms=100.0, bin_ms=1.0) else: correlograms = None print( @@ -86,8 +87,8 @@ def export_report( ) # pre-compute unit locations if not done - if not we.has_extension("unit_locations"): - unit_locations = compute_unit_locations(we) + if not sorting_result.has_extension("unit_locations"): + sorting_result.compute("unit_locations") output_folder = Path(output_folder).absolute() if output_folder.is_dir(): @@ -100,28 +101,28 @@ def export_report( # unit list units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" - units["max_on_channel_id"] = pd.Series(get_template_extremum_channel(we, peak_sign="neg", outputs="id")) - units["amplitude"] = pd.Series(get_template_extremum_amplitude(we, peak_sign="neg")) + units["max_on_channel_id"] = pd.Series(get_template_extremum_channel(sorting_result, peak_sign="neg", outputs="id")) + units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_result, peak_sign="neg")) units.to_csv(output_folder / "unit list.csv", sep="\t") unit_colors = sw.get_unit_colors(sorting) # global figures fig = plt.figure(figsize=(20, 10)) - w = sw.plot_unit_locations(we, figure=fig, unit_colors=unit_colors) + w = sw.plot_unit_locations(sorting_result, figure=fig, unit_colors=unit_colors) fig.savefig(output_folder / f"unit_localization.{format}") if not show_figures: plt.close(fig) fig, ax = plt.subplots(figsize=(20, 10)) - sw.plot_unit_depths(we, ax=ax, unit_colors=unit_colors) + sw.plot_unit_depths(sorting_result, ax=ax, unit_colors=unit_colors) fig.savefig(output_folder / f"unit_depths.{format}") if not show_figures: plt.close(fig) if spike_amplitudes and len(unit_ids) < 100: fig = plt.figure(figsize=(20, 10)) - sw.plot_all_amplitudes_distributions(we, figure=fig, unit_colors=unit_colors) + sw.plot_all_amplitudes_distributions(sorting_result, figure=fig, unit_colors=unit_colors) fig.savefig(output_folder / f"amplitudes_distribution.{format}") if not show_figures: plt.close(fig) @@ -138,7 +139,7 @@ def export_report( constrained_layout=False, figsize=(15, 7), ) - sw.plot_unit_summary(we, unit_id, figure=fig) + sw.plot_unit_summary(sorting_result, unit_id, figure=fig) fig.suptitle(f"unit {unit_id}") fig.savefig(units_folder / f"{unit_id}.{format}") if not show_figures: diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index f2a7e6c034..1ea22144ff 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -3,7 +3,7 @@ import pytest from pathlib import Path -from spikeinterface.core import generate_ground_truth_recording, extract_waveforms +from spikeinterface.core import generate_ground_truth_recording, start_sorting_result from spikeinterface.postprocessing import ( compute_spike_amplitudes, compute_template_similarity, @@ -17,7 +17,7 @@ cache_folder = Path("cache_folder") / "exporters" -def make_waveforms_extractor(sparse=True, with_group=False): +def make_sorting_result(sparse=True, with_group=False): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=28000.0, @@ -39,30 +39,33 @@ def make_waveforms_extractor(sparse=True, with_group=False): recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1]) sorting.set_property("group", [0, 0, 1, 1]) - we = extract_waveforms(recording=recording, sorting=sorting, folder=None, mode="memory", sparse=sparse) - compute_principal_components(we) - compute_spike_amplitudes(we) - compute_template_similarity(we) - compute_quality_metrics(we, metric_names=["snr"]) + sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_result.select_random_spikes() + sorting_result.compute("waveforms") + sorting_result.compute("templates") + sorting_result.compute("noise_levels") + sorting_result.compute("principal_components") + sorting_result.compute("template_similarity") + sorting_result.compute("quality_metrics", metric_names=["snr"]) - return we + return sorting_result @pytest.fixture(scope="module") -def waveforms_extractor_dense_for_export(): - return make_waveforms_extractor(sparse=False) +def sorting_result_dense_for_export(): + return make_sorting_result(sparse=False) @pytest.fixture(scope="module") -def waveforms_extractor_with_group_for_export(): - return make_waveforms_extractor(sparse=False, with_group=True) +def sorting_result_with_group_for_export(): + return make_sorting_result(sparse=False, with_group=True) @pytest.fixture(scope="module") -def waveforms_extractor_sparse_for_export(): - return make_waveforms_extractor(sparse=True) +def sorting_result_sparse_for_export(): + return make_sorting_result(sparse=True) if __name__ == "__main__": - we = make_waveforms_extractor(sparse=False) - print(we) + sorting_result = make_sorting_result(sparse=False) + print(sorting_result) diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 52dd383913..0d73c2ce0c 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -10,26 +10,21 @@ from spikeinterface.core import compute_sparsity from spikeinterface.exporters import export_to_phy -from spikeinterface.exporters.tests.common import ( - cache_folder, - make_waveforms_extractor, - waveforms_extractor_sparse_for_export, - waveforms_extractor_dense_for_export, - waveforms_extractor_with_group_for_export, -) +from spikeinterface.exporters.tests.common import (cache_folder, make_sorting_result) -def test_export_to_phy(waveforms_extractor_sparse_for_export): + +def test_export_to_phy(sorting_result_sparse_for_export): output_folder1 = cache_folder / "phy_output_1" output_folder2 = cache_folder / "phy_output_2" for f in (output_folder1, output_folder2): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = waveforms_extractor_sparse_for_export + sorting_result = sorting_result_sparse_for_export export_to_phy( - waveform_extractor, + sorting_result, output_folder1, compute_pc_features=True, compute_amplitudes=True, @@ -40,7 +35,7 @@ def test_export_to_phy(waveforms_extractor_sparse_for_export): # Test for previous crash when copy_binary=False. export_to_phy( - waveform_extractor, + sorting_result, output_folder2, compute_pc_features=False, compute_amplitudes=False, @@ -51,7 +46,7 @@ def test_export_to_phy(waveforms_extractor_sparse_for_export): ) -def test_export_to_phy_by_property(waveforms_extractor_with_group_for_export): +def test_export_to_phy_by_property(sorting_result_with_group_for_export): output_folder = cache_folder / "phy_output" output_folder_rm = cache_folder / "phy_output_rm" @@ -59,11 +54,11 @@ def test_export_to_phy_by_property(waveforms_extractor_with_group_for_export): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = waveforms_extractor_with_group_for_export + sorting_result = sorting_result_with_group_for_export - sparsity_group = compute_sparsity(waveform_extractor, method="by_property", by_property="group") + sparsity_group = compute_sparsity(sorting_result, method="by_property", by_property="group") export_to_phy( - waveform_extractor, + sorting_result, output_folder, compute_pc_features=True, compute_amplitudes=True, @@ -74,15 +69,15 @@ def test_export_to_phy_by_property(waveforms_extractor_with_group_for_export): ) template_inds = np.load(output_folder / "template_ind.npy") - assert template_inds.shape == (waveform_extractor.unit_ids.size, 4) + assert template_inds.shape == (sorting_result.unit_ids.size, 4) # Remove one channel # recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - # waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) - # sparsity_group = compute_sparsity(waveform_extractor_rm, method="by_property", by_property="group") + # sorting_result_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) + # sparsity_group = compute_sparsity(sorting_result_rm, method="by_property", by_property="group") # export_to_phy( - # waveform_extractor_rm, + # sorting_result_rm, # output_folder_rm, # compute_pc_features=True, # compute_amplitudes=True, @@ -97,18 +92,18 @@ def test_export_to_phy_by_property(waveforms_extractor_with_group_for_export): # assert len(np.where(template_inds == -1)[0]) > 0 -def test_export_to_phy_by_sparsity(waveforms_extractor_dense_for_export): +def test_export_to_phy_by_sparsity(sorting_result_dense_for_export): output_folder_radius = cache_folder / "phy_output_radius" output_folder_multi_sparse = cache_folder / "phy_output_multi_sparse" for f in (output_folder_radius, output_folder_multi_sparse): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = waveforms_extractor_dense_for_export + sorting_result = sorting_result_dense_for_export - sparsity_radius = compute_sparsity(waveform_extractor, method="radius", radius_um=50.0) + sparsity_radius = compute_sparsity(sorting_result, method="radius", radius_um=50.0) export_to_phy( - waveform_extractor, + sorting_result, output_folder_radius, compute_pc_features=True, compute_amplitudes=True, @@ -125,10 +120,10 @@ def test_export_to_phy_by_sparsity(waveforms_extractor_dense_for_export): assert -1 in pc_ind # pre-compute PC with another sparsity - sparsity_radius_small = compute_sparsity(waveform_extractor, method="radius", radius_um=30.0) - pc = compute_principal_components(waveform_extractor, sparsity=sparsity_radius_small) + sparsity_radius_small = compute_sparsity(sorting_result, method="radius", radius_um=30.0) + pc = compute_principal_components(sorting_result, sparsity=sparsity_radius_small) export_to_phy( - waveform_extractor, + sorting_result, output_folder_multi_sparse, compute_pc_features=True, compute_amplitudes=True, @@ -148,10 +143,10 @@ def test_export_to_phy_by_sparsity(waveforms_extractor_dense_for_export): if __name__ == "__main__": - we_sparse = make_waveforms_extractor(sparse=True) - we_group = make_waveforms_extractor(sparse=False, with_group=True) - we_dense = make_waveforms_extractor(sparse=False) + sorting_result_sparse = make_sorting_result(sparse=True) + sorting_result_group = make_sorting_result(sparse=False, with_group=True) + sorting_result_dense = make_sorting_result(sparse=False) - test_export_to_phy(we_sparse) - test_export_to_phy_by_property(we_group) - test_export_to_phy_by_sparsity(we_dense) + test_export_to_phy(sorting_result_sparse) + test_export_to_phy_by_property(sorting_result_group) + test_export_to_phy_by_sparsity(sorting_result_dense) diff --git a/src/spikeinterface/exporters/tests/test_report.py b/src/spikeinterface/exporters/tests/test_report.py index ee1a9b6b31..5ad01f7609 100644 --- a/src/spikeinterface/exporters/tests/test_report.py +++ b/src/spikeinterface/exporters/tests/test_report.py @@ -7,8 +7,7 @@ from spikeinterface.exporters.tests.common import ( cache_folder, - make_waveforms_extractor, - waveforms_extractor_sparse_for_export, + make_sorting_result, ) @@ -24,5 +23,5 @@ def test_export_report(waveforms_extractor_sparse_for_export): if __name__ == "__main__": - we = make_waveforms_extractor(sparse=True) - test_export_report(we) + sorting_result = make_sorting_result(sparse=True) + test_export_report(sorting_result) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 607aa3e846..03555c049e 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -11,20 +11,21 @@ from spikeinterface.core import ( write_binary_recording, BinaryRecordingExtractor, - WaveformExtractor, BinaryFolderRecording, ChannelSparsity, + SortingResult ) from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs from spikeinterface.postprocessing import ( compute_spike_amplitudes, compute_template_similarity, compute_principal_components, + ) def export_to_phy( - waveform_extractor: WaveformExtractor, + sorting_result: SortingResult, output_folder: str | Path, compute_pc_features: bool = True, compute_amplitudes: bool = True, @@ -43,8 +44,8 @@ def export_to_phy( Parameters ---------- - waveform_extractor: a WaveformExtractor or None - If WaveformExtractor is provide then the compute is faster otherwise + sorting_result: SortingResult + A SortingResult object output_folder: str | Path The output folder where the phy template-gui files are saved compute_pc_features: bool, default: True @@ -60,7 +61,7 @@ def export_to_phy( peak_sign: "neg" | "pos" | "both", default: "neg" Used by compute_spike_amplitudes template_mode: str, default: "median" - Parameter "mode" to be given to WaveformExtractor.get_template() + Parameter "mode" to be given to SortingResult.get_template() dtype: dtype or None, default: None Dtype to save binary data verbose: bool, default: True @@ -74,35 +75,34 @@ def export_to_phy( import pandas as pd assert isinstance( - waveform_extractor, spikeinterface.core.waveform_extractor.WaveformExtractor - ), "waveform_extractor must be a WaveformExtractor object" - sorting = waveform_extractor.sorting + sorting_result, SortingResult), "sorting_result must be a SortingResult object" + sorting = sorting_result.sorting assert ( - waveform_extractor.get_num_segments() == 1 - ), f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments" - num_chans = waveform_extractor.get_num_channels() - fs = waveform_extractor.sampling_frequency + sorting_result.get_num_segments() == 1 + ), f"Export to phy only works with one segment, your extractor has {sorting_result.get_num_segments()} segments" + num_chans = sorting_result.get_num_channels() + fs = sorting_result.sampling_frequency job_kwargs = fix_job_kwargs(job_kwargs) # check sparsity - if (num_chans > 64) and (sparsity is None and not waveform_extractor.is_sparse()): + if (num_chans > 64) and (sparsity is None and not sorting_result.is_sparse()): warnings.warn( "Exporting to Phy with many channels and without sparsity might result in a heavy and less " - "informative visualization. You can use use a sparse WaveformExtractor or you can use the 'sparsity' " + "informative visualization. You can use use a sparse SortingResult or you can use the 'sparsity' " "argument to enforce sparsity (see compute_sparsity())" ) save_sparse = True - if waveform_extractor.is_sparse(): - used_sparsity = waveform_extractor.sparsity + if sorting_result.is_sparse(): + used_sparsity = sorting_result.sparsity if sparsity is not None: - warnings.warn("If the waveform_extractor is sparse the 'sparsity' argument is ignored") + warnings.warn("If the sorting_result is sparse the 'sparsity' argument is ignored") elif sparsity is not None: used_sparsity = sparsity else: - used_sparsity = ChannelSparsity.create_dense(waveform_extractor) + used_sparsity = ChannelSparsity.create_dense(sorting_result) save_sparse = False # convenient sparsity dict for the 3 cases to retrieve channl_inds sparse_dict = used_sparsity.unit_id_to_channel_indices @@ -121,7 +121,7 @@ def export_to_phy( if len(unit_ids) == 0: raise Exception("No non-empty units in the sorting result, can't save to Phy.") - output_folder = Path(output_folder).absolute() + output_folder = Path(output_folder).resolve() if output_folder.is_dir(): if remove_if_exists: shutil.rmtree(output_folder) @@ -132,22 +132,19 @@ def export_to_phy( # save dat file if dtype is None: - if waveform_extractor.has_recording(): - dtype = waveform_extractor.recording.get_dtype() - else: - dtype = waveform_extractor.dtype + dtype = sorting_result.get_dtype() - if waveform_extractor.has_recording(): + if sorting_result.has_recording(): if copy_binary: rec_path = output_folder / "recording.dat" - write_binary_recording(waveform_extractor.recording, file_paths=rec_path, dtype=dtype, **job_kwargs) - elif isinstance(waveform_extractor.recording, BinaryRecordingExtractor): - if isinstance(waveform_extractor.recording, BinaryFolderRecording): - bin_kwargs = waveform_extractor.recording._bin_kwargs + write_binary_recording(sorting_result.recording, file_paths=rec_path, dtype=dtype, **job_kwargs) + elif isinstance(sorting_result.recording, BinaryRecordingExtractor): + if isinstance(sorting_result.recording, BinaryFolderRecording): + bin_kwargs = sorting_result.recording._bin_kwargs else: - bin_kwargs = waveform_extractor.recording._kwargs + bin_kwargs = sorting_result.recording._kwargs rec_path = bin_kwargs["file_paths"][0] - dtype = waveform_extractor.recording.get_dtype() + dtype = sorting_result.recording.get_dtype() else: rec_path = "None" else: # don't save recording.dat @@ -172,7 +169,7 @@ def export_to_phy( f.write(f"dtype = '{dtype_str}'\n") f.write(f"offset = 0\n") f.write(f"sample_rate = {fs}\n") - f.write(f"hp_filtered = {waveform_extractor.is_filtered()}") + f.write(f"hp_filtered = {sorting_result.recording.is_filtered()}") # export spike_times/spike_templates/spike_clusters # here spike_labels is a remapping to unit_index @@ -185,22 +182,23 @@ def export_to_phy( # export templates/templates_ind/similar_templates # shape (num_units, num_samples, max_num_channels) + templates_ext = sorting_result.get_extension("templates") + templates_ext is not None, "export_to_phy need SortingResult with extension 'templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) - num_samples = waveform_extractor.nbefore + waveform_extractor.nafter + dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) + num_samples = dense_templates.shape[1] templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64") # here we pad template inds with -1 if len of sparse channels is unequal templates_ind = -np.ones((len(unit_ids), max_num_channels), dtype="int64") for unit_ind, unit_id in enumerate(unit_ids): chan_inds = sparse_dict[unit_id] - template = waveform_extractor.get_template(unit_id, mode=template_mode, sparsity=sparsity) + template = dense_templates[unit_ind][:, chan_inds] templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if waveform_extractor.has_extension("similarity"): - tmc = waveform_extractor.load_extension("similarity") - template_similarity = tmc.get_data() - else: - template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") + if not sorting_result.has_extension("template_similarity"): + sorting_result.compute("template_similarity") + template_similarity = sorting_result.get_extension("template_similarity").get_data() np.save(str(output_folder / "templates.npy"), templates) if save_sparse: @@ -208,9 +206,9 @@ def export_to_phy( np.save(str(output_folder / "similar_templates.npy"), template_similarity) channel_maps = np.arange(num_chans, dtype="int32") - channel_map_si = waveform_extractor.channel_ids - channel_positions = waveform_extractor.get_channel_locations().astype("float32") - channel_groups = waveform_extractor.get_recording_property("group") + channel_map_si = sorting_result.channel_ids + channel_positions = sorting_result.get_channel_locations().astype("float32") + channel_groups = sorting_result.get_recording_property("group") if channel_groups is None: channel_groups = np.zeros(num_chans, dtype="int32") np.save(str(output_folder / "channel_map.npy"), channel_maps) @@ -219,36 +217,28 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if waveform_extractor.has_extension("spike_amplitudes"): - sac = waveform_extractor.load_extension("spike_amplitudes") - amplitudes = sac.get_data(outputs="concatenated") - else: - amplitudes = compute_spike_amplitudes( - waveform_extractor, peak_sign=peak_sign, outputs="concatenated", **job_kwargs - ) - # one segment only - amplitudes = amplitudes[0][:, np.newaxis] + if not sorting_result.has_extension("spike_amplitudes"): + sorting_result.compute("spike_amplitudes", **job_kwargs) + amplitudes = sorting_result.get_extension("spike_amplitudes").get_data() + amplitudes = amplitudes[:, np.newaxis] np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if waveform_extractor.has_extension("principal_components"): - pc = waveform_extractor.load_extension("principal_components") - else: - pc = compute_principal_components( - waveform_extractor, n_components=5, mode="by_channel_local", sparsity=sparsity - ) - pc_sparsity = pc.get_sparsity() - if pc_sparsity is None: - pc_sparsity = used_sparsity - max_num_channels_pc = max(len(chan_inds) for chan_inds in pc_sparsity.unit_id_to_channel_indices.values()) - - pc.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs) - - pc_feature_ind = -np.ones((len(unit_ids), max_num_channels_pc), dtype="int64") - for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = pc_sparsity.unit_id_to_channel_indices[unit_id] - pc_feature_ind[unit_ind, : len(chan_inds)] = chan_inds - np.save(str(output_folder / "pc_feature_ind.npy"), pc_feature_ind) + if not sorting_result.has_extension("principal_components"): + sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + + # pc_sparsity = pc.get_sparsity() + # if pc_sparsity is None: + # pc_sparsity = used_sparsity + # max_num_channels_pc = max(len(chan_inds) for chan_inds in pc_sparsity.unit_id_to_channel_indices.values()) + raise NotImplementedError() + # pc.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs) + + # pc_feature_ind = -np.ones((len(unit_ids), max_num_channels_pc), dtype="int64") + # for unit_ind, unit_id in enumerate(unit_ids): + # chan_inds = pc_sparsity.unit_id_to_channel_indices[unit_id] + # pc_feature_ind[unit_ind, : len(chan_inds)] = chan_inds + # np.save(str(output_folder / "pc_feature_ind.npy"), pc_feature_ind) # Save .tsv metadata cluster_group = pd.DataFrame( @@ -264,9 +254,8 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if waveform_extractor.has_extension("quality_metrics"): - qm = waveform_extractor.load_extension("quality_metrics") - qm_data = qm.get_data() + if sorting_result.has_extension("quality_metrics"): + qm_data = sorting_result.get_extension("quality_metrics").get_data() for column_name in qm_data.columns: # already computed by phy if column_name not in ["num_spikes", "firing_rate"]: diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 0c17371576..40a9a603b2 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -12,7 +12,7 @@ from spikeinterface import NumpySorting, generate_sorting from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeCorrelograms -from spikeinterface.postprocessing.correlograms import _compute_correlograms, _make_bins +from spikeinterface.postprocessing.correlograms import compute_correlograms_on_sorting, _make_bins @@ -45,7 +45,7 @@ def test_make_bins(): def _test_correlograms(sorting, window_ms, bin_ms, methods): for method in methods: - correlograms, bins = _compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) if method == "numpy": ref_correlograms = correlograms ref_bins = bins @@ -89,7 +89,7 @@ def test_flat_cross_correlogram(): # ~ fig, ax = plt.subplots() for method in methods: - correlograms, bins = _compute_correlograms(sorting, window_ms=50.0, bin_ms=1.0, method=method) + correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=50.0, bin_ms=1.0, method=method) cc = correlograms[0, 1, :].copy() m = np.mean(cc) assert np.all(cc > (m * 0.90)) @@ -121,7 +121,7 @@ def test_auto_equal_cross_correlograms(): sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=10000.0) for method in methods: - correlograms, bins = _compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) + correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=10.0, bin_ms=0.1, method=method) num_half_bins = correlograms.shape[2] // 2 @@ -171,7 +171,7 @@ def test_detect_injected_correlation(): sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=sampling_frequency) for method in methods: - correlograms, bins = _compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) + correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=10.0, bin_ms=0.1, method=method) cc_01 = correlograms[0, 1, :] cc_10 = correlograms[1, 0, :] diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index b2ff92605d..7617154b3e 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -146,7 +146,7 @@ def __init__( for i, unit_id in enumerate(unit_ids): unit_index = sorting_result.sorting.id_to_index(unit_id) chan_inds = channel_inds[unit_id] - template = templates[unit_index, :, chan_inds] + template = templates[i, :, chan_inds] template_flat = template.flatten() templates_flat[unit_id] = template_flat From 1845d745720b9ad4418ecf1411906ed5db70d80d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 12 Feb 2024 15:48:09 +0100 Subject: [PATCH 049/192] Remove WaveformExtractor at many places in curation/core/export --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/result_core.py | 4 +- src/spikeinterface/core/sparsity.py | 6 +- .../core/tests/test_generate.py | 2 +- .../core/tests/test_waveform_tools.py | 45 ++++++++++++- src/spikeinterface/core/waveform_tools.py | 64 ++++++++++++++++++- src/spikeinterface/curation/auto_merge.py | 52 ++++++--------- .../curation/remove_redundant.py | 30 ++++----- src/spikeinterface/curation/tests/common.py | 46 +++++++++++++ .../curation/tests/test_auto_merge.py | 37 +++++++---- .../curation/tests/test_remove_redundant.py | 40 +++++------- src/spikeinterface/exporters/tests/common.py | 6 -- .../exporters/tests/test_export_to_phy.py | 2 +- .../postprocessing/principal_component.py | 11 ++-- .../preprocessing/remove_artifacts.py | 38 +++++------ .../tests/test_remove_artifacts.py | 12 ++-- .../tests/test_metrics_functions.py | 1 - 17 files changed, 264 insertions(+), 134 deletions(-) create mode 100644 src/spikeinterface/curation/tests/common.py diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 0f07139934..1bac4697e2 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -103,7 +103,7 @@ ) from .sorting_tools import spike_vector_to_spike_trains -from .waveform_tools import extract_waveforms_to_buffers +from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_average from .snippets_tools import snippets_from_sorting # waveform extractor diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 5eaaff3f26..573596a205 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -10,7 +10,7 @@ import numpy as np from .sortingresult import ResultExtension, register_result_extension -from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates +from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_average from .recording_tools import get_noise_levels class ComputeWaveforms(ResultExtension): @@ -355,7 +355,7 @@ def _run(self, **job_kwargs): return_scaled = self.params["return_scaled"] # TODO jobw_kwargs - self.data["average"] = estimate_templates(recording, some_spikes, unit_ids, self.nbefore, self.nafter, return_scaled=return_scaled, **job_kwargs) + self.data["average"] = estimate_templates_average(recording, some_spikes, unit_ids, self.nbefore, self.nafter, return_scaled=return_scaled, **job_kwargs) def _set_params(self, ms_before: float = 1.0, diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index c4e703d911..b21cf66e1d 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -7,7 +7,7 @@ from .recording_tools import get_noise_levels from .sorting_tools import random_spikes_selection from .job_tools import _shared_job_kwargs_doc -from .waveform_tools import estimate_templates +from .waveform_tools import estimate_templates_average _sparsity_doc = """ @@ -497,7 +497,7 @@ def estimate_sparsity( * all units are computed in one read of recording * it doesn't require a folder * it doesn't consume too much memory - * it uses internally the `estimate_templates()` which is fast and parallel + * it uses internally the `estimate_templates_average()` which is fast and parallel Parameters ---------- @@ -552,7 +552,7 @@ def estimate_sparsity( spikes = sorting.to_spike_vector() spikes = spikes[random_spikes_indices] - templates_array = estimate_templates( + templates_array = estimate_templates_average( recording, spikes, sorting.unit_ids, diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 79d61a5ff2..fa92542596 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -3,7 +3,7 @@ import numpy as np -from spikeinterface.core import load_extractor, extract_waveforms +from spikeinterface.core import load_extractor from probeinterface import generate_multi_columns_probe from spikeinterface.core.generate import ( diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 71d30495d8..5e36df5186 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -11,6 +11,7 @@ extract_waveforms_to_single_buffer, split_waveforms_by_units, estimate_templates, + estimate_templates_average, ) @@ -162,7 +163,7 @@ def test_waveform_tools(): _check_all_wf_equal(list_wfs_sparse) -def test_estimate_templates(): +def test_estimate_templates_average(): recording, sorting = get_dataset() ms_before = 1.0 @@ -177,7 +178,7 @@ def test_estimate_templates(): job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - templates = estimate_templates( + templates = estimate_templates_average( recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs ) print(templates.shape) @@ -193,7 +194,45 @@ def test_estimate_templates(): # ax.plot(templates[unit_index, :, :].T.flatten()) # plt.show() +def test_estimate_templates(): + recording, sorting = get_dataset() + + ms_before = 1.0 + ms_after = 1.5 + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + spikes = sorting.to_spike_vector() + # take one spikes every 10 + spikes = spikes[::10] + + job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + for operator in ("average", "median"): + templates = estimate_templates( + recording, spikes, sorting.unit_ids, nbefore, nafter, operator=operator, return_scaled=True, **job_kwargs + ) + # print(templates.shape) + assert templates.shape[0] == sorting.unit_ids.size + assert templates.shape[1] == nbefore + nafter + assert templates.shape[2] == recording.get_num_channels() + + assert np.any(templates != 0) + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # for unit_index, unit_id in enumerate(sorting.unit_ids): + # ax.plot(templates[unit_index, :, :].T.flatten()) + + # plt.show() + + + + + if __name__ == "__main__": # test_waveform_tools() - test_estimate_templates() + # test_estimate_templates_average() + test_estimate_templates() \ No newline at end of file diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 870846148e..af40ea047a 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -408,7 +408,7 @@ def extract_waveforms_to_single_buffer( file_path=None, dtype=None, sparsity_mask=None, - copy=False, + copy=True, job_name=None, **job_kwargs, ): @@ -700,6 +700,66 @@ def has_exceeding_spikes(recording, sorting): def estimate_templates( + recording: BaseRecording, + spikes: np.ndarray, + unit_ids: list | np.ndarray, + nbefore: int, + nafter: int, + operator: str ="average", + return_scaled: bool = True, + job_name=None, + **job_kwargs, +): + """ + Estimate dense templates with "average" or "median". + If "average" internaly estimate_templates_average() is used to saved memory/ + + Parameters + ---------- + + recording: BaseRecording + The recording object + spikes: 1d numpy array with several fields + Spikes handled as a unique vector. + This vector can be obtained with: `spikes = sorting.to_spike_vector()` + unit_ids: list ot numpy + List of unit_ids + nbefore: int + Number of samples to cut out before a spike + nafter: int + Number of samples to cut out after a spike + return_scaled: bool, default: True + If True, the traces are scaled before averaging + + Returns + ------- + templates_array: np.array + The average templates with shape (num_units, nbefore + nafter, num_channels) + + """ + + if job_name is None: + job_name = "estimate_templates" + + if operator == "average": + templates_array = estimate_templates_average(recording, spikes, unit_ids, nbefore, nafter, return_scaled=return_scaled, job_name=job_name, **job_kwargs) + elif operator == "median": + all_waveforms, wf_array_info = extract_waveforms_to_single_buffer( recording, spikes, unit_ids, nbefore, nafter, + mode="shared_memory", return_scaled=return_scaled, copy=False,**job_kwargs,) + templates_array = np.zeros((len(unit_ids), all_waveforms.shape[1], all_waveforms.shape[2]), dtype=all_waveforms.dtype) + for unit_index , unit_id in enumerate(unit_ids): + wfs = all_waveforms[spikes["unit_index"] == unit_index] + templates_array[unit_index, :, :] = np.median(wfs, axis=0) + # release shared memory after the median + wf_array_info["shm"].unlink() + + else: + raise ValueError(f"estimate_templates(..., operator={operator}) wrong operator must be average or median") + + return templates_array + + +def estimate_templates_average( recording: BaseRecording, spikes: np.ndarray, unit_ids: list | np.ndarray, @@ -773,7 +833,7 @@ def estimate_templates( ) if job_name is None: - job_name = "estimate_templates" + job_name = "estimate_templates_average" processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 099a8337ea..85a44331f0 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -2,6 +2,7 @@ import numpy as np +from ..core import start_sorting_result from ..core.template_tools import get_template_extremum_channel from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates @@ -10,7 +11,7 @@ def get_potential_auto_merge( - waveform_extractor, + sorting_result, minimum_spikes=1000, maximum_distance_um=150.0, peak_sign="neg", @@ -56,8 +57,8 @@ def get_potential_auto_merge( Parameters ---------- - waveform_extractor: WaveformExtractor - The waveform extractor + sorting_result: SortingResult + The SortingResult minimum_spikes: int, default: 1000 Minimum number of spikes for each unit to consider a potential merge. Enough spikes are needed to estimate the correlogram @@ -112,8 +113,7 @@ def get_potential_auto_merge( """ import scipy - we = waveform_extractor - sorting = we.sorting + sorting = sorting_result.sorting unit_ids = sorting.unit_ids # to get fast computation we will not analyse pairs when: @@ -144,7 +144,7 @@ def get_potential_auto_merge( # STEP 2 : remove contaminated auto corr if "remove_contaminated" in steps: contaminations, nb_violations = compute_refrac_period_violations( - we, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_result, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) @@ -154,8 +154,8 @@ def get_potential_auto_merge( # STEP 3 : unit positions are estimated roughly with channel if "unit_positions" in steps: - chan_loc = we.get_channel_locations() - unit_max_chan = get_template_extremum_channel(we, peak_sign=peak_sign, mode="extremum", outputs="index") + chan_loc = sorting_result.get_channel_locations() + unit_max_chan = get_template_extremum_channel(sorting_result, peak_sign=peak_sign, mode="extremum", outputs="index") unit_max_chan = list(unit_max_chan.values()) unit_locations = chan_loc[unit_max_chan, :] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") @@ -187,7 +187,7 @@ def get_potential_auto_merge( # STEP 5 : check if potential merge with CC also have template similarity if "template_similarity" in steps: - templates = we.get_all_templates(mode="average") + templates = sorting_result.get_extension("templates").get_templates(operator="average") templates_diff = compute_templates_diff( sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask ) @@ -196,7 +196,7 @@ def get_potential_auto_merge( # STEP 6 : validate the potential merges with CC increase the contamination quality metrics if "check_increase_score" in steps: pair_mask, pairs_decreased_score = check_improve_contaminations_score( - we, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms + sorting_result, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms ) # FINAL STEP : create the final list from pair_mask boolean matrix @@ -421,25 +421,9 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair return templates_diff -class MockWaveformExtractor: - """ - Mock WaveformExtractor to be able to run compute_refrac_period_violations() - needed for the auto_merge() function. - """ - - def __init__(self, recording, sorting): - self.recording = recording - self.sorting = sorting - - def get_total_samples(self): - return self.recording.get_total_samples() - - def get_total_duration(self): - return self.recording.get_total_duration() - def check_improve_contaminations_score( - we, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms + sorting_result, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms ): """ Check that the score is improve afeter a potential merge @@ -451,12 +435,12 @@ def check_improve_contaminations_score( Check that the contamination score is improved (decrease) after a potential merge """ - recording = we.recording - sorting = we.sorting + recording = sorting_result.recording + sorting = sorting_result.sorting pair_mask = pair_mask.copy() pairs_removed = [] - firing_rates = list(compute_firing_rates(we).values()) + firing_rates = list(compute_firing_rates(sorting_result).values()) inds1, inds2 = np.nonzero(pair_mask) for i in range(inds1.size): @@ -473,14 +457,14 @@ def check_improve_contaminations_score( sorting_merged = MergeUnitsSorting( sorting, [[unit_id1, unit_id2]], new_unit_ids=[unit_id1], delta_time_ms=censored_period_ms ).select_units([unit_id1]) - # make a lazy fake WaveformExtractor to compute contamination and firing rate - we_new = MockWaveformExtractor(recording, sorting_merged) + + sorting_result_new = start_sorting_result(sorting_merged, recording, format="memory", sparse=False) new_contaminations, _ = compute_refrac_period_violations( - we_new, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_result_new, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms ) c_new = new_contaminations[unit_id1] - f_new = compute_firing_rates(we_new)[unit_id1] + f_new = compute_firing_rates(sorting_result_new)[unit_id1] # old and new scores k = 1 + firing_contamination_balance diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index 21162b0bda..11bf6b15e2 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -1,7 +1,7 @@ from __future__ import annotations import numpy as np -from spikeinterface import WaveformExtractor +from spikeinterface import SortingResult from ..core.template_tools import get_template_extremum_channel_peak_shift, get_template_amplitudes from ..postprocessing import align_sorting @@ -11,7 +11,7 @@ def remove_redundant_units( - sorting_or_waveform_extractor, + sorting_or_sorting_result, align=True, unit_peak_shifts=None, delta_time=0.4, @@ -33,12 +33,12 @@ def remove_redundant_units( Parameters ---------- - sorting_or_waveform_extractor : BaseSorting or WaveformExtractor - If WaveformExtractor, the spike trains can be optionally realigned using the peak shift in the + sorting_or_sorting_result : BaseSorting or SortingResult + If SortingResult, the spike trains can be optionally realigned using the peak shift in the template to improve the matching procedure. If BaseSorting, the spike trains are not aligned. align : bool, default: False - If True, spike trains are aligned (if a WaveformExtractor is used) + If True, spike trains are aligned (if a SortingResult is used) delta_time : float, default: 0.4 The time in ms to consider matching spikes agreement_threshold : float, default: 0.2 @@ -65,17 +65,17 @@ def remove_redundant_units( Sorting object without redundant units """ - if isinstance(sorting_or_waveform_extractor, WaveformExtractor): - sorting = sorting_or_waveform_extractor.sorting - we = sorting_or_waveform_extractor + if isinstance(sorting_or_sorting_result, SortingResult): + sorting = sorting_or_sorting_result.sorting + sorting_result = sorting_or_sorting_result else: - assert not align, "The 'align' option is only available when a WaveformExtractor is used as input" - sorting = sorting_or_waveform_extractor - we = None + assert not align, "The 'align' option is only available when a SortingResult is used as input" + sorting = sorting_or_sorting_result + sorting_result = None if align and unit_peak_shifts is None: - assert we is not None, "For align=True must give a WaveformExtractor or explicit unit_peak_shifts" - unit_peak_shifts = get_template_extremum_channel_peak_shift(we) + assert sorting_result is not None, "For align=True must give a SortingResult or explicit unit_peak_shifts" + unit_peak_shifts = get_template_extremum_channel_peak_shift(sorting_result) if align: sorting_aligned = align_sorting(sorting, unit_peak_shifts) @@ -93,7 +93,7 @@ def remove_redundant_units( if remove_strategy in ("minimum_shift", "highest_amplitude"): # this is the values at spike index ! - peak_values = get_template_amplitudes(we, peak_sign=peak_sign, mode="at_index") + peak_values = get_template_amplitudes(sorting_result, peak_sign=peak_sign, mode="at_index") peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()} if remove_strategy == "minimum_shift": @@ -125,7 +125,7 @@ def remove_redundant_units( elif remove_strategy == "with_metrics": # TODO # @aurelien @alessio - # here we can implement the choice of the best one given an external metrics table + # here sorting_result can implement the choice of the best one given an external metrics table # this will be implemented in a futur PR by the first who need it! raise NotImplementedError() else: diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py new file mode 100644 index 0000000000..af1163fb4b --- /dev/null +++ b/src/spikeinterface/curation/tests/common.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import pytest +from pathlib import Path + +from spikeinterface.core import generate_ground_truth_recording, start_sorting_result +from spikeinterface.qualitymetrics import compute_quality_metrics + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "curation" +else: + cache_folder = Path("cache_folder") / "curation" + + +job_kwargs = dict(n_jobs=-1) +def make_sorting_result(sparse=True): + recording, sorting = generate_ground_truth_recording( + durations=[300.0], + sampling_frequency=30000.0, + num_channels=4, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=20.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), + seed=2205, + ) + + sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_result.select_random_spikes() + sorting_result.compute("waveforms", **job_kwargs) + sorting_result.compute("templates") + sorting_result.compute("noise_levels") + # sorting_result.compute("principal_components") + # sorting_result.compute("template_similarity") + # sorting_result.compute("quality_metrics", metric_names=["snr"]) + + return sorting_result + + +@pytest.fixture(scope="module") +def sorting_result_for_curation(): + return make_sorting_result(sparse=True) + + +if __name__ == "__main__": + sorting_result = make_sorting_result(sparse=False) + print(sorting_result) diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 068d3e824b..ffdfdfbd81 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -3,13 +3,14 @@ from pathlib import Path import numpy as np -from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, NumpySorting, set_global_tmp_folder -from spikeinterface.extractors import toy_example - +from spikeinterface.core import start_sorting_result from spikeinterface.core.generate import inject_some_split_units from spikeinterface.curation import get_potential_auto_merge -from spikeinterface.curation.auto_merge import normalize_correlogram + + + +from spikeinterface.curation.tests.common import make_sorting_result if hasattr(pytest, "global_test_folder"): @@ -17,12 +18,12 @@ else: cache_folder = Path("cache_folder") / "curation" -set_global_tmp_folder(cache_folder) - -def test_get_auto_merge_list(): - rec, sorting = toy_example(num_segments=1, num_units=5, duration=[300.0], firing_rate=20.0, seed=42) +def test_get_auto_merge_list(sorting_result_for_curation): + + sorting = sorting_result_for_curation.sorting + recording = sorting_result_for_curation.recording num_unit_splited = 1 num_split = 2 @@ -30,6 +31,8 @@ def test_get_auto_merge_list(): sorting, split_ids=sorting.unit_ids[:num_unit_splited], num_split=num_split, output_ids=True, seed=42 ) + + print(sorting_with_split) print(sorting_with_split.unit_ids) print(other_ids) @@ -41,11 +44,19 @@ def test_get_auto_merge_list(): # shutil.rmtree(wf_folder) # we = extract_waveforms(rec, sorting_with_split, mode="folder", folder=wf_folder, n_jobs=1) - we = extract_waveforms(rec, sorting_with_split, mode="memory", folder=None, n_jobs=1) + # we = extract_waveforms(rec, sorting_with_split, mode="memory", folder=None, n_jobs=1) # print(we) + job_kwargs = dict(n_jobs=-1) + + sorting_result = start_sorting_result(sorting_with_split, recording, format="memory") + sorting_result.select_random_spikes() + sorting_result.compute("waveforms", **job_kwargs) + sorting_result.compute("templates") + + potential_merges, outs = get_potential_auto_merge( - we, + sorting_result, minimum_spikes=1000, maximum_distance_um=150.0, peak_sign="neg", @@ -71,7 +82,10 @@ def test_get_auto_merge_list(): true_pair = tuple(true_pair) assert true_pair in potential_merges + + # import matplotlib.pyplot as plt + # from spikeinterface.curation.auto_merge import normalize_correlogram # templates_diff = outs['templates_diff'] # correlogram_diff = outs['correlogram_diff'] # bins = outs['bins'] @@ -122,4 +136,5 @@ def test_get_auto_merge_list(): if __name__ == "__main__": - test_get_auto_merge_list() + sorting_result = make_sorting_result(sparse=True) + test_get_auto_merge_list(sorting_result) diff --git a/src/spikeinterface/curation/tests/test_remove_redundant.py b/src/spikeinterface/curation/tests/test_remove_redundant.py index 9e27374de1..b304ab19b9 100644 --- a/src/spikeinterface/curation/tests/test_remove_redundant.py +++ b/src/spikeinterface/curation/tests/test_remove_redundant.py @@ -6,46 +6,38 @@ import numpy as np -from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, NumpySorting, set_global_tmp_folder +from spikeinterface import start_sorting_result from spikeinterface.core.generate import inject_some_duplicate_units -from spikeinterface.extractors import toy_example - -from spikeinterface.curation import remove_redundant_units +from spikeinterface.curation.tests.common import make_sorting_result -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "curation" -else: - cache_folder = Path("cache_folder") / "curation" +from spikeinterface.curation import remove_redundant_units -set_global_tmp_folder(cache_folder) +def test_remove_redundant_units(sorting_result_for_curation): -def test_remove_redundant_units(): - rec, sorting = toy_example(num_segments=1, duration=[100.0], seed=2205) + sorting = sorting_result_for_curation.sorting + recording = sorting_result_for_curation.recording sorting_with_dup = inject_some_duplicate_units(sorting, ratio=0.8, num=4, seed=2205) - print(sorting.unit_ids) - print(sorting_with_dup.unit_ids) - - # rec = rec.save() - # sorting_with_dup = sorting_with_dup.save() - # wf_folder = cache_folder / "wf_dup" - # if wf_folder.exists(): - # shutil.rmtree(wf_folder) - # we = extract_waveforms(rec, sorting_with_dup, folder=wf_folder) + # print(sorting.unit_ids) + # print(sorting_with_dup.unit_ids) - we = extract_waveforms(rec, sorting_with_dup, mode="memory", folder=None, n_jobs=1) + job_kwargs = dict(n_jobs=-1) + sorting_result = start_sorting_result(sorting_with_dup, recording, format="memory") + sorting_result.select_random_spikes() + sorting_result.compute("waveforms", **job_kwargs) + sorting_result.compute("templates") - # print(we) for remove_strategy in ("max_spikes", "minimum_shift", "highest_amplitude"): - sorting_clean = remove_redundant_units(we, remove_strategy=remove_strategy) + sorting_clean = remove_redundant_units(sorting_result, remove_strategy=remove_strategy) # print(sorting_clean) # print(sorting_clean.unit_ids) assert np.array_equal(sorting_clean.unit_ids, sorting.unit_ids) if __name__ == "__main__": - test_remove_redundant_units() + sorting_result = make_sorting_result(sparse=True) + test_remove_redundant_units(sorting_result) diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index 1ea22144ff..bc4a636684 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -4,12 +4,6 @@ from pathlib import Path from spikeinterface.core import generate_ground_truth_recording, start_sorting_result -from spikeinterface.postprocessing import ( - compute_spike_amplitudes, - compute_template_similarity, - compute_principal_components, -) -from spikeinterface.qualitymetrics import compute_quality_metrics if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "exporters" diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 0d73c2ce0c..b10e0c5a1d 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -73,7 +73,7 @@ def test_export_to_phy_by_property(sorting_result_with_group_for_export): # Remove one channel # recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - # sorting_result_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) + # sorting_result_rm = start_sorting_result(sorting, recording_rm, , sparse=False) # sparsity_group = compute_sparsity(sorting_result_rm, method="by_property", by_property="group") # export_to_phy( diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 20c1a2ca4e..b39ece1a05 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -45,14 +45,15 @@ class ComputePrincipalComponents(ResultExtension): Examples -------- - >>> we = si.extract_waveforms(recording, sorting, folder='waveforms') - >>> pc = st.compute_principal_components(we, n_components=3, mode='by_channel_local') + >>> sorting_result = start_sorting_result(sorting, recording) + >>> sorting_result.compute("principal_components", n_components=3, mode='by_channel_local') + >>> ext_pca = sorting_result.get_extension("principal_components") >>> # get pre-computed projections for unit_id=1 - >>> projections = pc.get_projections(unit_id=1) + >>> projections = ext_pca.get_projections(unit_id=1) >>> # retrieve fitted pca model(s) - >>> pca_model = pc.get_pca_model() + >>> pca_model = ext_pca.get_pca_model() >>> # compute projections on new waveforms - >>> proj_new = pc.project_new(new_waveforms) + >>> proj_new = ext_pca.project_new(new_waveforms) >>> # run for all spikes in the SortingExtractor >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 793b44f099..0fadf3ccc6 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,11 +1,13 @@ from __future__ import annotations +import warnings + import numpy as np from spikeinterface.core.core_tools import define_function_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import NumpySorting, extract_waveforms +from spikeinterface.core import NumpySorting, estimate_templates class RemoveArtifactsRecording(BasePreprocessor): @@ -80,11 +82,8 @@ class RemoveArtifactsRecording(BasePreprocessor): time_jitter: float, default: 0 If non 0, then for mode "median" or "average", a time jitter in ms can be allowed to minimize the residuals - waveforms_kwargs: dict or None, default: None - The arguments passed to the WaveformExtractor object when extracting the - artifacts, for mode "median" or "average". - By default, the global job kwargs are used, in addition to {"allow_unfiltered" : True, "mode":"memory"}. - To estimate sparse artifact + waveforms_kwargs: None + Depracted and ignored Returns ------- @@ -107,8 +106,12 @@ def __init__( sparsity=None, scale_amplitude=False, time_jitter=0, - waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"}, - ): + waveforms_kwargs=None, + ): + if waveforms_kwargs is not None: + warnings("remove_artifacts() waveforms_kwargs is deprecated and ignored") + + available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -169,19 +172,16 @@ def __init__( ms_before is not None and ms_after is not None ), f"ms_before/after should not be None for mode {mode}" sorting = NumpySorting.from_times_labels(list_triggers, list_labels, recording.get_sampling_frequency()) - sorting = sorting.save() - waveforms_kwargs.update({"ms_before": ms_before, "ms_after": ms_after}) - w = extract_waveforms(recording, sorting, None, **waveforms_kwargs) + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + templates = estimate_templates(recording=recording, spikes=sorting.to_spike_vector(), + unit_ids=sorting.unit_ids, nbefore=nbefore, nafter=nafter, + operator=mode, return_scaled=False) artifacts = {} - sparsity = {} - for label in w.sorting.unit_ids: - artifacts[label] = w.get_template(label, mode=mode).astype(recording.dtype) - if w.is_sparse(): - unit_ind = w.sorting.id_to_index(label) - sparsity[label] = w.sparsity.mask[unit_ind] - else: - sparsity = None + for i, label in enumerate(sorting.unit_ids): + artifacts[label] = templates[i, :, :] if sparsity is not None: labels = [] diff --git a/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py b/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py index dd9fd84fbd..b8a6e83f67 100644 --- a/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_remove_artifacts.py @@ -6,18 +6,18 @@ from spikeinterface.core import generate_recording from spikeinterface.preprocessing import remove_artifacts -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "preprocessing" -else: - cache_folder = Path("cache_folder") / "preprocessing" +# if hasattr(pytest, "global_test_folder"): +# cache_folder = pytest.global_test_folder / "preprocessing" +# else: +# cache_folder = Path("cache_folder") / "preprocessing" -set_global_tmp_folder(cache_folder) +# set_global_tmp_folder(cache_folder) def test_remove_artifacts(): # one segment only rec = generate_recording(durations=[10.0]) - rec = rec.save(folder=cache_folder / "recording") + # rec = rec.save(folder=cache_folder / "recording") rec.annotate(is_filtered=True) triggers = [15000, 30000] diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index e8d446f152..6d821a5115 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -324,7 +324,6 @@ def test_synchrony_metrics(sorting_result_simple): previous_sorting_result = sorting_result for sync_level in added_synchrony_levels: sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) - #waveform_extractor_sync = extract_waveforms(previous_waveform_extractor.recording, sorting_sync, mode="memory") sorting_result_sync = start_sorting_result(sorting_sync, sorting_result.recording, format="memory") From 7ed9192ff017ef89288cc510a6616a57a49c9db8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 12 Feb 2024 16:33:28 +0100 Subject: [PATCH 050/192] wip --- ...forms_extractor_backwards_compatibility.py | 5 +++ ...forms_extractor_backwards_compatibility.py | 2 +- .../curation/tests/test_auto_merge.py | 17 ++------ .../tests/test_sortingview_curation.py | 41 ++++++++++--------- 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index 443d75c08a..2a602b1f38 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -86,5 +86,10 @@ def test_extract_waveforms(): +# @pytest.mark.skip(): +# def test_read_old_waveforms_extractor_binary(): +# folder = "" + + if __name__ == "__main__": test_extract_waveforms() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index d4c6d0e01f..af08385dcb 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -457,7 +457,7 @@ def _read_old_waveforms_extractor_binary(folder): # TODO : implement this when extension will be prted in the new API # old_extension_to_new_class : { - # old extensions with same names and equvalent data + # old extensions with same names and equvalent data except similarity>template_similarity # "spike_amplitudes": , # "spike_locations": , # "amplitude_scalings": , diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index ffdfdfbd81..45b6ba370b 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -33,19 +33,10 @@ def test_get_auto_merge_list(sorting_result_for_curation): - print(sorting_with_split) - print(sorting_with_split.unit_ids) - print(other_ids) - - # rec = rec.save() - # sorting_with_split = sorting_with_split.save() - # wf_folder = cache_folder / "wf_auto_merge" - # if wf_folder.exists(): - # shutil.rmtree(wf_folder) - # we = extract_waveforms(rec, sorting_with_split, mode="folder", folder=wf_folder, n_jobs=1) - - # we = extract_waveforms(rec, sorting_with_split, mode="memory", folder=None, n_jobs=1) - # print(we) + # print(sorting_with_split) + # print(sorting_with_split.unit_ids) + # print(other_ids) + job_kwargs = dict(n_jobs=-1) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index ce6c7dd5a6..44191a2bed 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -29,25 +29,28 @@ set_global_tmp_folder(cache_folder) -# this needs to be run only once -def generate_sortingview_curation_dataset(): - import spikeinterface.widgets as sw - - local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") - recording, sorting = read_mearec(local_path) - - we = si.extract_waveforms(recording, sorting, folder=None, mode="memory") - - _ = compute_spike_amplitudes(we) - _ = compute_correlograms(we) - _ = compute_template_similarity(we) - _ = compute_unit_locations(we) - - # plot_sorting_summary with curation - w = sw.plot_sorting_summary(we, curation=True, backend="sortingview") - - # curation_link: - # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary +# this needs to be run only once: if we want to regenerate we need to start with sorting result +# TODO : regenerate the +# def generate_sortingview_curation_dataset(): +# import spikeinterface.widgets as sw + +# local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") +# recording, sorting = read_mearec(local_path) + +# sorting_result = si.start_sorting_result(sorting, recording, format="memory") +# sorting_result.select_random_spikes() +# sorting_result.compute("waveforms") +# sorting_result.compute("templates") +# sorting_result.compute("noise_levels") +# sorting_result.compute("spike_amplitudes") +# sorting_result.compute("template_similarity") +# sorting_result.compute("unit_locations") + +# # plot_sorting_summary with curation +# w = sw.plot_sorting_summary(sorting_result, curation=True, backend="sortingview") + +# # curation_link: +# # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") From 9220e9591be1d5748c44eaa83ed57327faf8ad60 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 12 Feb 2024 17:26:23 +0100 Subject: [PATCH 051/192] Starting to generalize templates in SC2 and matching --- .../sorters/internal/spyking_circus2.py | 38 +++++---- .../clustering/clustering_tools.py | 37 ++++----- .../clustering/random_projections.py | 81 ++++++++++--------- .../sortingcomponents/matching/naive.py | 16 ++-- 4 files changed, 91 insertions(+), 81 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 37c4ea74b3..57ae132cf0 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -8,6 +8,8 @@ from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.template import Templates +from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter from spikeinterface.sortingcomponents.tools import cache_preprocessing from spikeinterface.core.basesorting import minimum_spike_dtype @@ -41,7 +43,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "select_per_channel": False, }, "clustering": {"legacy": False}, - "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, + "matching": {"method": "naive", "method_kwargs": {}}, "apply_preprocessing": True, "shared_memory": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -218,27 +220,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): mode = "folder" waveforms_folder = sorter_output_folder / "waveforms" - we = extract_waveforms( - recording_f, - sorting, - waveforms_folder, - return_scaled=False, - precompute_template=["median"], - mode=mode, - **waveforms_params, - ) + # we = extract_waveforms( + # recording_f, + # sorting, + # waveforms_folder, + # return_scaled=False, + # precompute_template=["median"], + # mode=mode, + # **waveforms_params, + # ) + + nbefore = int(params["general"]["ms_before"] * sampling_frequency / 1000.0) + nafter = int(params["general"]["ms_after"] * sampling_frequency / 1000.0) + + templates_array = estimate_templates(recording, labeled_peaks, unit_ids, nbefore, nafter, + False, job_name=None, **job_kwargs) + + templates = Templates(templates_array, + sampling_frequency, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_method = params["matching"]["method"] matching_params = params["matching"]["method_kwargs"].copy() matching_job_params = {} matching_job_params.update(job_kwargs) - if matching_method == "wobble": - matching_params["templates"] = we.get_all_templates(mode="median") - matching_params["nbefore"] = we.nbefore - matching_params["nafter"] = we.nafter - else: - matching_params["waveform_extractor"] = we + matching_params["templates"] = templates if matching_method == "circus-omp-svd": for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b7f17d99e3..25d4f64456 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -536,12 +536,12 @@ def remove_duplicates( def remove_duplicates_via_matching( - waveform_extractor, + templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, - method="circus-omp-svd", + method="naive", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import BinaryRecordingExtractor @@ -553,21 +553,19 @@ def remove_duplicates_via_matching( job_kwargs = fix_job_kwargs(job_kwargs) - if waveform_extractor.is_sparse(): - sparsity = waveform_extractor.sparsity.mask + templates_array = templates.templates_array - templates = waveform_extractor.get_all_templates(mode="median").copy() - nb_templates = len(templates) - duration = waveform_extractor.nbefore + waveform_extractor.nafter + nb_templates = len(templates_array) + duration = templates.nbefore + templates.nafter - fs = waveform_extractor.recording.get_sampling_frequency() - num_chans = waveform_extractor.recording.get_num_channels() + fs = templates.sampling_frequency + num_chans = len(templates.channel_ids) - if waveform_extractor.is_sparse(): - for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - templates[count][:, ~sparsity[count]] = 0 + #if waveform_extractor.is_sparse(): + # for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): + # templates[count][:, ~sparsity[count]] = 0 - zdata = templates.reshape(nb_templates, -1) + zdata = templates_array.reshape(nb_templates, -1) padding = 2 * duration blanck = np.zeros(padding * num_chans, dtype=np.float32) @@ -586,10 +584,10 @@ def remove_duplicates_via_matching( f.close() recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") - recording = recording.set_probe(waveform_extractor.recording.get_probe()) + recording = recording.set_probe(templates.probe) recording.annotate(is_filtered=True) - margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) + margin = 2 * max(templates.nbefore, templates.nafter) half_marging = margin // 2 chunk_size = duration + 3 * margin @@ -597,16 +595,15 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() local_params.update( - {"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "optimize_amplitudes": False} + {"templates": templates, "amplitudes": [0.975, 1.025], "optimize_amplitudes": False} ) - spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) - indices = np.argsort(counts) + ignore_ids = [] similar_templates = [[], []] - for i in indices: + for i in range(nb_templates): t_start = padding + i * duration t_stop = padding + (i + 1) * duration @@ -662,7 +659,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, local_params, waveform_extractor + del recording, sub_recording, local_params, templates os.remove(tmp_filename) return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 48660cc80c..b91153a5b8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -17,12 +17,14 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core import get_global_tmp_folder, get_channel_distances, get_random_data_chunks from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler -from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers +from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers, estimate_templates from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms +from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature +from spikeinterface.core.template import Templates from spikeinterface.core.node_pipeline import ( run_node_pipeline, ExtractDenseWaveforms, @@ -70,6 +72,8 @@ def main_function(cls, recording, peaks, params): if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() + job_kwargs = fix_job_kwargs(params["job_kwargs"]) + d = params verbose = d["job_kwargs"]["verbose"] @@ -145,37 +149,42 @@ def main_function(cls, recording, peaks, params): labels = np.unique(peak_labels) labels = labels[labels >= 0] - best_spikes = {} - nb_spikes = 0 + # best_spikes = {} + # nb_spikes = 0 + + # all_indices = np.arange(0, peak_labels.size) - all_indices = np.arange(0, peak_labels.size) + # max_spikes = params["waveforms"]["max_spikes_per_unit"] + # selection_method = params["selection_method"] - max_spikes = params["waveforms"]["max_spikes_per_unit"] - selection_method = params["selection_method"] + # for unit_ind in labels: + # mask = peak_labels == unit_ind + # if selection_method == "closest_to_centroid": + # data = hdbscan_data[mask] + # centroid = np.median(data, axis=0) + # distances = sklearn.metrics.pairwise_distances(centroid[np.newaxis, :], data)[0] + # best_spikes[unit_ind] = all_indices[mask][np.argsort(distances)[:max_spikes]] + # elif selection_method == "random": + # best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes] + # nb_spikes += best_spikes[unit_ind].size - for unit_ind in labels: - mask = peak_labels == unit_ind - if selection_method == "closest_to_centroid": - data = hdbscan_data[mask] - centroid = np.median(data, axis=0) - distances = sklearn.metrics.pairwise_distances(centroid[np.newaxis, :], data)[0] - best_spikes[unit_ind] = all_indices[mask][np.argsort(distances)[:max_spikes]] - elif selection_method == "random": - best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes] - nb_spikes += best_spikes[unit_ind].size + # spikes = np.zeros(nb_spikes, dtype=minimum_spike_dtype) - spikes = np.zeros(nb_spikes, dtype=minimum_spike_dtype) + # mask = np.zeros(0, dtype=np.int32) + # for unit_ind in labels: + # mask = np.concatenate((mask, best_spikes[unit_ind])) - mask = np.zeros(0, dtype=np.int32) - for unit_ind in labels: - mask = np.concatenate((mask, best_spikes[unit_ind])) + # idx = np.argsort(mask) + # mask = mask[idx] - idx = np.argsort(mask) - mask = mask[idx] + # spikes["sample_index"] = peaks[mask]["sample_index"] + # spikes["segment_index"] = peaks[mask]["segment_index"] + # spikes["unit_index"] = peak_labels[mask] - spikes["sample_index"] = peaks[mask]["sample_index"] - spikes["segment_index"] = peaks[mask]["segment_index"] - spikes["unit_index"] = peak_labels[mask] + spikes = np.zeros(len(peaks), dtype=minimum_spike_dtype) + spikes["sample_index"] = peaks["sample_index"] + spikes["segment_index"] = peaks["segment_index"] + spikes["unit_index"] = peak_labels if verbose: print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) @@ -192,16 +201,14 @@ def main_function(cls, recording, peaks, params): mode = "folder" sorting = sorting.save(folder=sorting_folder) - we = extract_waveforms( - recording, - sorting, - waveform_folder, - return_scaled=False, - mode=mode, - precompute_template=["median"], - **params["job_kwargs"], - **params["waveforms"], - ) + nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) + nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + + templates_array = estimate_templates(recording, spikes, unit_ids, nbefore, nafter, + False, job_name=None, **job_kwargs) + + templates = Templates(templates_array, + fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: @@ -216,10 +223,10 @@ def main_function(cls, recording, peaks, params): cleaning_params["tmp_folder"] = tmp_folder labels, peak_labels = remove_duplicates_via_matching( - we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) - del we, sorting + del sorting if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index f79f5c3f08..9cb3815f1e 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -31,7 +31,7 @@ class NaiveMatching(BaseTemplateMatchingEngine): """ default_params = { - "waveform_extractor": None, + "templates": None, "peak_sign": "neg", "exclude_sweep_ms": 0.1, "detect_threshold": 5, @@ -45,20 +45,20 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - assert d["waveform_extractor"] is not None, "'waveform_extractor' must be supplied" + assert d["templates"] is not None, "'templates' must be supplied" - we = d["waveform_extractor"] + templates = d["templates"] if d["noise_levels"] is None: - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"]) + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["radius_um"] - d["nbefore"] = we.nbefore - d["nafter"] = we.nafter + d["nbefore"] = templates.nbefore + d["nafter"] = templates.nafter d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) @@ -73,8 +73,8 @@ def get_margin(cls, recording, kwargs): def serialize_method_kwargs(cls, kwargs): kwargs = dict(kwargs) - we = kwargs.pop("waveform_extractor") - kwargs["templates"] = we.get_all_templates(mode="average") + templates = kwargs.pop("templates") + kwargs["templates"] = templates.templates_array return kwargs From 8fd37811959fb294cc43c1a08734d01d1beaba20 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 12 Feb 2024 17:46:56 +0100 Subject: [PATCH 052/192] WIP --- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/matching/naive.py | 14 ++--- .../sortingcomponents/matching/tdc.py | 58 +++++++++---------- .../sortingcomponents/matching/wobble.py | 25 +++++--- 4 files changed, 54 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 57ae132cf0..ce6e3c26ca 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -43,7 +43,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "select_per_channel": False, }, "clustering": {"legacy": False}, - "matching": {"method": "naive", "method_kwargs": {}}, + "matching": {"method": "wobble", "method_kwargs": {}}, "apply_preprocessing": True, "shared_memory": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 9cb3815f1e..f64f4d8176 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -7,6 +7,7 @@ from spikeinterface.core import WaveformExtractor, get_template_channel_sparsity, get_template_extremum_channel from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive +from spikeinterface.core.template import Templates spike_dtype = [ ("sample_index", "int64"), @@ -45,7 +46,10 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - assert d["templates"] is not None, "'templates' must be supplied" + assert isinstance(d["templates"], Templates), ( + f"The templates supplied is of type {type(d['waveform_extractor'])} " + f"and must be a Templates" + ) templates = d["templates"] @@ -72,10 +76,6 @@ def get_margin(cls, recording, kwargs): @classmethod def serialize_method_kwargs(cls, kwargs): kwargs = dict(kwargs) - - templates = kwargs.pop("templates") - kwargs["templates"] = templates.templates_array - return kwargs @classmethod @@ -88,7 +88,7 @@ def main_function(cls, traces, method_kwargs): abs_threholds = method_kwargs["abs_threholds"] exclude_sweep_size = method_kwargs["exclude_sweep_size"] neighbours_mask = method_kwargs["neighbours_mask"] - templates = method_kwargs["templates"] + templates_array = method_kwargs["templates"].templates_array nbefore = method_kwargs["nbefore"] nafter = method_kwargs["nafter"] @@ -114,7 +114,7 @@ def main_function(cls, traces, method_kwargs): i1 = peak_sample_ind[i] + nafter waveforms = traces[i0:i1, :] - dist = np.sum(np.sum((templates - waveforms[None, :, :]) ** 2, axis=1), axis=1) + dist = np.sum(np.sum((templates_array - waveforms[None, :, :]) ** 2, axis=1), axis=1) cluster_index = np.argmin(dist) spikes["cluster_index"][i] = cluster_index diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 5c2303f9a0..25f8129b3d 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -11,6 +11,7 @@ ) from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive +from spikeinterface.core.template import Templates spike_dtype = [ ("sample_index", "int64"), @@ -47,7 +48,7 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): """ default_params = { - "waveform_extractor": None, + "templates": None, "peak_sign": "neg", "peak_shift_ms": 0.2, "detect_threshold": 5, @@ -68,35 +69,33 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls.default_params.copy() d.update(kwargs) - assert isinstance(d["waveform_extractor"], WaveformExtractor), ( - f"The waveform_extractor supplied is of type {type(d['waveform_extractor'])} " - f"and must be a WaveformExtractor" + assert isinstance(d["templates"], Templates), ( + f"The templates supplied is of type {type(d['templates'])} " + f"and must be a Templates" ) - we = d["waveform_extractor"] - unit_ids = we.unit_ids - channel_ids = we.channel_ids + templates = d["templates"] + unit_ids = templates.unit_ids + channel_ids = templates.channel_ids - sr = we.sampling_frequency + sr = templates.sampling_frequency - # TODO load as sharedmem - templates = we.get_all_templates(mode="average") - d["templates"] = templates - d["nbefore"] = we.nbefore - d["nafter"] = we.nafter + d["nbefore"] = templates.nbefore + d["nafter"] = templates.nafter + templates_array = templates.templates_array nbefore_short = int(d["ms_before"] * sr / 1000.0) nafter_short = int(d["ms_before"] * sr / 1000.0) - assert nbefore_short <= we.nbefore - assert nafter_short <= we.nafter + assert nbefore_short <= templates.nbefore + assert nafter_short <= templates.nafter d["nbefore_short"] = nbefore_short d["nafter_short"] = nafter_short - s0 = we.nbefore - nbefore_short - s1 = -(we.nafter - nafter_short) + s0 = templates.nbefore - nbefore_short + s1 = -(templates.nafter - nafter_short) if s1 == 0: s1 = None - templates_short = templates[:, slice(s0, s1), :].copy() + templates_short = templates_array[:, slice(s0, s1), :].copy() d["templates_short"] = templates_short d["peak_shift"] = int(d["peak_shift_ms"] / 1000 * sr) @@ -110,7 +109,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["radius_um"] - sparsity = compute_sparsity(we, method="snr", peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) + sparsity = compute_sparsity(templates, method="best_channels")#, peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) template_sparsity_inds = sparsity.unit_id_to_channel_indices template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") for unit_index, unit_id in enumerate(unit_ids): @@ -119,12 +118,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["template_sparsity"] = template_sparsity - extremum_channel = get_template_extremum_channel(we, peak_sign=d["peak_sign"], outputs="index") + extremum_channel = get_template_extremum_channel(templates, peak_sign=d["peak_sign"], outputs="index") # as numpy vector extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") d["extremum_channel"] = extremum_channel - channel_locations = we.recording.get_channel_locations() + channel_locations = templates.probe.contact_positions # TODO try it with real locaion unit_locations = channel_locations[extremum_channel] @@ -143,11 +142,11 @@ def initialize_and_check_kwargs(cls, recording, kwargs): # compute unitary discriminent vector (chans,) = np.nonzero(d["template_sparsity"][unit_ind, :]) - template_sparse = templates[unit_ind, :, :][:, chans] + template_sparse = templates_array[unit_ind, :, :][:, chans] closest_vec = [] # against N closets for u in closest_u: - vec = templates[u, :, :][:, chans] - template_sparse + vec = templates_array[u, :, :][:, chans] - template_sparse vec /= np.sum(vec**2) closest_vec.append((u, vec)) # against noise @@ -175,9 +174,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): @classmethod def serialize_method_kwargs(cls, kwargs): kwargs = dict(kwargs) - - # remove waveform_extractor - kwargs.pop("waveform_extractor") return kwargs @classmethod @@ -222,6 +218,8 @@ def _tdc_find_spikes(traces, d, level=0): peak_sign = d["peak_sign"] templates = d["templates"] templates_short = d["templates_short"] + templates_array = templates.templates_array + margin = d["margin"] possible_clusters_by_channel = d["possible_clusters_by_channel"] @@ -266,7 +264,7 @@ def _tdc_find_spikes(traces, d, level=0): # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) - ## numba with cluster+channel spasity + ## numba with cluster+channel spasity union_channels = np.any(d["template_sparsity"][possible_clusters, :], axis=0) # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) @@ -279,7 +277,7 @@ def _tdc_find_spikes(traces, d, level=0): cluster_index = possible_clusters[ind] chan_sparsity = d["template_sparsity"][cluster_index, :] - template_sparse = templates[cluster_index, :, :][:, chan_sparsity] + template_sparse = templates_array[cluster_index, :, :][:, chan_sparsity] # find best shift @@ -293,7 +291,7 @@ def _tdc_find_spikes(traces, d, level=0): ## numba version numba_best_shift( traces, - templates[cluster_index, :, :], + templates_array[cluster_index, :, :], sample_index, d["nbefore"], possible_shifts, @@ -327,7 +325,7 @@ def _tdc_find_spikes(traces, d, level=0): amplitude = 1.0 # remove template - template = templates[cluster_index, :, :] + template = templates_array[cluster_index, :, :] s0 = sample_index - d["nbefore"] s1 = sample_index + d["nafter"] traces[s0:s1, :] -= template * amplitude diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 1b11796fa5..0bbc147dd8 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt from .main import BaseTemplateMatchingEngine +from spikeinterface.core.template import Templates @dataclass @@ -309,7 +310,7 @@ class WobbleMatch(BaseTemplateMatchingEngine): """ default_params = { - "waveform_extractor": None, + "templates": None, } spike_dtype = [ ("sample_index", "int64"), @@ -336,29 +337,37 @@ def initialize_and_check_kwargs(cls, recording, kwargs): Updated Keyword arguments. """ d = cls.default_params.copy() - required_kwargs_keys = ["nbefore", "nafter", "templates"] + + required_kwargs_keys = ["templates"] for required_key in required_kwargs_keys: assert required_key in kwargs, f"`{required_key}` is a required key in the kwargs" + + + parameters = kwargs.get("parameters", {}) templates = kwargs["templates"] - templates = templates.astype(np.float32, casting="safe") + assert isinstance(templates, Templates), ( + f"The templates supplied is of type {type(d['templates'])} " + f"and must be a Templates" + ) + templates_array = templates.templates_array.astype(np.float32, casting="safe") # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) - template_meta = TemplateMetadata.from_parameters_and_templates(params, templates) + template_meta = TemplateMetadata.from_parameters_and_templates(params, templates_array) sparsity = Sparsity.from_parameters_and_templates( - params, templates + params, templates_array ) # TODO: replace with spikeinterface sparsity # Perform initial computations on templates necessary for computing the objective - sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates, 0) + sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates_array, 0) temporal, singular, spatial = compress_templates(sparse_templates, params.approx_rank) temporal_jittered = upsample_and_jitter(temporal, params.jitter_factor, template_meta.num_samples) compressed_templates = (temporal, singular, spatial, temporal_jittered) pairwise_convolution = convolve_templates( compressed_templates, params.jitter_factor, params.approx_rank, template_meta.jittered_indices, sparsity ) - norm_squared = compute_template_norm(sparsity.visible_channels, templates) + norm_squared = compute_template_norm(sparsity.visible_channels, templates_array) template_data = TemplateData( compressed_templates=compressed_templates, pairwise_convolution=pairwise_convolution, @@ -370,6 +379,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): kwargs["template_meta"] = template_meta kwargs["sparsity"] = sparsity kwargs["template_data"] = template_data + kwargs["nbefore"] = templates.nbefore + kwargs["nafter"] = templates.nafter d.update(kwargs) return d From 4eb829e361bc632fda3d0c298d802d8dc3e4f6ac Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 12 Feb 2024 18:56:12 +0100 Subject: [PATCH 053/192] Improve sparsity for snr and ptp when using Templates as input --- src/spikeinterface/core/__init__.py | 1 - src/spikeinterface/core/result_core.py | 40 +++++- src/spikeinterface/core/sparsity.py | 86 ++++++++---- src/spikeinterface/core/template_tools.py | 125 +++++------------- .../core/tests/test_sparsity.py | 20 ++- .../core/tests/test_waveform_tools.py | 4 +- .../sortingcomponents/matching/naive.py | 2 +- 7 files changed, 139 insertions(+), 139 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 1bac4697e2..53925b08be 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -136,7 +136,6 @@ get_template_extremum_channel, get_template_extremum_channel_peak_shift, get_template_extremum_amplitude, - get_template_channel_sparsity, ) # channel sparsity diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 573596a205..1feabf8af1 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -12,6 +12,7 @@ from .sortingresult import ResultExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_average from .recording_tools import get_noise_levels +from .template import Templates class ComputeWaveforms(ResultExtension): """ @@ -257,13 +258,29 @@ def _select_extension_data(self, unit_ids): return new_data - def _get_data(self, operator="average", percentile=None): + def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator else: assert percentile is not None, "You must provide percentile=..." key = f"pencentile_{percentile}" - return self.data[key] + + templates_array = self.data[key] + + if outputs == "numpy": + return templates_array + elif outputs == "Templates": + return Templates( + templates_array=templates_array, + sampling_frequency=self.sorting_result.sampling_frequency, + nbefore=self.nbefore, + channel_ids=self.sorting_result.channel_ids, + unit_ids=self.sorting_result.unit_ids, + probe=self.sorting_result.get_probe(), + ) + else: + raise ValueError("outputs must be numpy or Templates") + def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True): """ @@ -369,8 +386,23 @@ def _set_params(self, ) return params - def _get_data(self): - return self.data["average"] + def _get_data(self, outputs="numpy"): + templates_array = self.data["average"] + + if outputs == "numpy": + return templates_array + elif outputs == "Templates": + return Templates( + templates_array=templates_array, + sampling_frequency=self.sorting_result.sampling_frequency, + nbefore=self.nbefore, + channel_ids=self.sorting_result.channel_ids, + unit_ids=self.sorting_result.unit_ids, + probe=self.sorting_result.get_probe(), + ) + else: + raise ValueError("outputs must be numpy or Templates") + def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index b21cf66e1d..62bf75c25d 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -10,6 +10,7 @@ from .waveform_tools import estimate_templates_average + _sparsity_doc = """ method: str * "best_channels": N best channels with the largest amplitude. Use the "num_channels" argument to specify the @@ -302,50 +303,72 @@ def from_radius(cls, templates_or_sorting_result, radius_um, peak_sign="neg"): return cls(mask, templates_or_sorting_result.unit_ids, templates_or_sorting_result.channel_ids) @classmethod - def from_snr(cls, sorting_result, threshold, peak_sign="neg"): + def from_snr(cls, templates_or_sorting_result, threshold, noise_levels=None, peak_sign="neg"): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. """ from .template_tools import get_template_amplitudes + from .sortingresult import SortingResult + from .template import Templates - assert sorting_result.sparsity is None, "To compute sparsity you need a dense SortingResult" - mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") + assert templates_or_sorting_result.sparsity is None, "To compute sparsity you need a dense SortingResult or Templates" - peak_values = get_template_amplitudes(sorting_result, peak_sign=peak_sign, mode="extremum", return_scaled=True) + unit_ids = templates_or_sorting_result.unit_ids + channel_ids = templates_or_sorting_result.channel_ids - ext = sorting_result.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" - noise = ext.data["noise_levels"] + if isinstance(templates_or_sorting_result, SortingResult): + ext = templates_or_sorting_result.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" + assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" + noise_levels = ext.data["noise_levels"] + elif isinstance(templates_or_sorting_result, Templates): + assert noise_levels is not None - for unit_ind, unit_id in enumerate(sorting_result.unit_ids): - chan_inds = np.nonzero((np.abs(peak_values[unit_id]) / noise) >= threshold) + mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") + + peak_values = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign, mode="extremum", return_scaled=True) + + for unit_ind, unit_id in enumerate(unit_ids): + chan_inds = np.nonzero((np.abs(peak_values[unit_id]) / noise_levels) >= threshold) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) + return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, sorting_result, threshold): + def from_ptp(cls, templates_or_sorting_result, threshold, noise_levels=None): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the SNR threshold. """ - assert sorting_result.sparsity is None, "To compute sparsity with ptp you need a dense SortingResult" + + assert templates_or_sorting_result.sparsity is None, "To compute sparsity you need a dense SortingResult or Templates" + + from .template_tools import get_template_amplitudes + from .sortingresult import SortingResult + from .template import Templates + + unit_ids = templates_or_sorting_result.unit_ids + channel_ids = templates_or_sorting_result.channel_ids + + if isinstance(templates_or_sorting_result, SortingResult): + ext = templates_or_sorting_result.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" + assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" + noise_levels = ext.data["noise_levels"] + elif isinstance(templates_or_sorting_result, Templates): + assert noise_levels is not None from .template_tools import _get_dense_templates_array - mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") - templates_array = _get_dense_templates_array(sorting_result, return_scaled=True) + mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") + + templates_array = _get_dense_templates_array(templates_or_sorting_result, return_scaled=True) templates_ptps = np.ptp(templates_array, axis=1) - ext = sorting_result.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" - assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" - noise = ext.data["noise_levels"] - for unit_ind, unit_id in enumerate(sorting_result.unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise >= threshold) + for unit_ind, unit_id in enumerate(unit_ids): + chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) + return cls(mask, unit_ids, channel_ids) @classmethod def from_energy(cls, sorting_result, threshold): @@ -409,6 +432,7 @@ def create_dense(cls, sorting_result): def compute_sparsity( templates_or_sorting_result, + noise_levels=None, method="radius", peak_sign="neg", num_channels=5, @@ -444,10 +468,14 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_result = templates_or_sorting_result.sorting_result - if method in ("best_channels", "radius"): - assert isinstance(templates_or_sorting_result, (Templates, WaveformExtractor, SortingResult)), "compute_sparsity() need Templates or WaveformExtractor or SortingResult" + if method in ("best_channels", "radius", "snr", "ptp"): + assert isinstance(templates_or_sorting_result, (Templates, SortingResult)), f"compute_sparsity(method='{method}') need Templates or SortingResult" else: - assert isinstance(templates_or_sorting_result, (WaveformExtractor, SortingResult)), f"compute_sparsity(method='{method}') need WaveformExtractor or SortingResult" + assert isinstance(templates_or_sorting_result, SortingResult), f"compute_sparsity(method='{method}') need SortingResult" + + if method in ("snr", "ptp") and isinstance(templates_or_sorting_result, Templates): + assert noise_levels is not None, f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" + if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" @@ -457,13 +485,13 @@ def compute_sparsity( sparsity = ChannelSparsity.from_radius(templates_or_sorting_result, radius_um, peak_sign=peak_sign) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_snr(templates_or_sorting_result, threshold, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_snr(templates_or_sorting_result, threshold, noise_levels=noise_levels, peak_sign=peak_sign) + elif method == "ptp": + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp(templates_or_sorting_result, threshold, noise_levels=noise_levels, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_energy(templates_or_sorting_result, threshold) - elif method == "ptp": - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp(templates_or_sorting_result, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" sparsity = ChannelSparsity.from_property(templates_or_sorting_result, by_property) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 360db97c8e..2a931a8c88 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -3,9 +3,7 @@ import warnings from .template import Templates -from .waveform_extractor import WaveformExtractor -from .sparsity import compute_sparsity, _sparsity_doc -from .recording_tools import get_channel_distances, get_noise_levels +from .sparsity import _sparsity_doc from .sortingresult import SortingResult @@ -13,8 +11,6 @@ def _get_dense_templates_array(one_object, return_scaled=True): if isinstance(one_object, Templates): templates_array = one_object.get_dense_templates() - elif isinstance(one_object, WaveformExtractor): - templates_array = one_object.get_all_templates(mode="average") elif isinstance(one_object, SortingResult): ext = one_object.get_extension("templates") if ext is not None: @@ -28,15 +24,13 @@ def _get_dense_templates_array(one_object, return_scaled=True): else: raise ValueError("SortingResult need extension 'templates' or 'fast_templates' to be computed") else: - raise ValueError("Input should be Templates or WaveformExtractor or SortingResult") + raise ValueError("Input should be Templates or SortingResult or SortingResult") return templates_array def _get_nbefore(one_object): if isinstance(one_object, Templates): return one_object.nbefore - elif isinstance(one_object, WaveformExtractor): - return one_object.nbefore elif isinstance(one_object, SortingResult): ext = one_object.get_extension("templates") if ext is not None: @@ -46,21 +40,21 @@ def _get_nbefore(one_object): return ext.nbefore raise ValueError("SortingResult need extension 'templates' or 'fast_templates' to be computed") else: - raise ValueError("Input should be Templates or WaveformExtractor or SortingResult") + raise ValueError("Input should be Templates or SortingResult or SortingResult") def get_template_amplitudes( - templates_or_waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", return_scaled: bool = True + templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", return_scaled: bool = True ): """ Get amplitude per channel for each unit. Parameters ---------- - templates_or_waveform_extractor: Templates | WaveformExtractor - A Templates or a WaveformExtractor object + templates_or_sorting_result: Templates | SortingResult + A Templates or a SortingResult object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" @@ -77,12 +71,12 @@ def get_template_amplitudes( assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" - unit_ids = templates_or_waveform_extractor.unit_ids - before = _get_nbefore(templates_or_waveform_extractor) + unit_ids = templates_or_sorting_result.unit_ids + before = _get_nbefore(templates_or_sorting_result) peak_values = {} - templates_array = _get_dense_templates_array(templates_or_waveform_extractor, return_scaled=return_scaled) + templates_array = _get_dense_templates_array(templates_or_sorting_result, return_scaled=return_scaled) for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] @@ -108,7 +102,7 @@ def get_template_amplitudes( def get_template_extremum_channel( - templates_or_waveform_extractor, + templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", outputs: "id" | "index" = "id", @@ -118,8 +112,8 @@ def get_template_extremum_channel( Parameters ---------- - templates_or_waveform_extractor: Templates | WaveformExtractor - A Templates or a WaveformExtractor object + templates_or_sorting_result: Templates | SortingResult + A Templates or a SortingResult object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" @@ -139,10 +133,10 @@ def get_template_extremum_channel( assert mode in ("extremum", "at_index") assert outputs in ("id", "index") - unit_ids = templates_or_waveform_extractor.unit_ids - channel_ids = templates_or_waveform_extractor.channel_ids + unit_ids = templates_or_sorting_result.unit_ids + channel_ids = templates_or_sorting_result.channel_ids - peak_values = get_template_amplitudes(templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode) + peak_values = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign, mode=mode) extremum_channels_id = {} extremum_channels_index = {} for unit_id in unit_ids: @@ -156,67 +150,8 @@ def get_template_extremum_channel( return extremum_channels_index -def get_template_channel_sparsity( - templates_or_waveform_extractor, - method="radius", - peak_sign="neg", - num_channels=5, - radius_um=100.0, - threshold=5, - by_property=None, - outputs="id", -): - """ - Get channel sparsity (subset of channels) for each template with several methods. - - Parameters - ---------- - templates_or_waveform_extractor: Templates | WaveformExtractor - A Templates or a WaveformExtractor object - - {} - - outputs: str - * "id": channel id - * "index": channel index - - Returns - ------- - sparsity: dict - Dictionary with unit ids as keys and sparse channel ids or indices (id or index based on "outputs") - as values - """ - from spikeinterface.core.sparsity import compute_sparsity - - warnings.warn( - "The 'get_template_channel_sparsity()' function is deprecated. " "Use 'compute_sparsity()' instead", - DeprecationWarning, - stacklevel=2, - ) - - assert outputs in ("id", "index"), "'outputs' can either be 'id' or 'index'" - sparsity = compute_sparsity( - templates_or_waveform_extractor, - method=method, - peak_sign=peak_sign, - num_channels=num_channels, - radius_um=radius_um, - threshold=threshold, - by_property=by_property, - ) - - # handle output ids or indexes - if outputs == "id": - return sparsity.unit_id_to_channel_ids - elif outputs == "index": - return sparsity.unit_id_to_channel_indices - - -get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc) - - def get_template_extremum_channel_peak_shift( - templates_or_waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg" + templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg" ): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. @@ -225,8 +160,8 @@ def get_template_extremum_channel_peak_shift( Parameters ---------- - templates_or_waveform_extractor: Templates | WaveformExtractor - A Templates or a WaveformExtractor object + templates_or_sorting_result: Templates | SortingResult + A Templates or a SortingResult object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels @@ -235,15 +170,15 @@ def get_template_extremum_channel_peak_shift( shifts: dict Dictionary with unit ids as keys and shifts as values """ - unit_ids = templates_or_waveform_extractor.unit_ids - channel_ids = templates_or_waveform_extractor.channel_ids - nbefore = _get_nbefore(templates_or_waveform_extractor) + unit_ids = templates_or_sorting_result.unit_ids + channel_ids = templates_or_sorting_result.channel_ids + nbefore = _get_nbefore(templates_or_sorting_result) - extremum_channels_ids = get_template_extremum_channel(templates_or_waveform_extractor, peak_sign=peak_sign) + extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_result, peak_sign=peak_sign) shifts = {} - templates_array = _get_dense_templates_array(templates_or_waveform_extractor) + templates_array = _get_dense_templates_array(templates_or_sorting_result) for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] @@ -264,7 +199,7 @@ def get_template_extremum_channel_peak_shift( def get_template_extremum_amplitude( - templates_or_waveform_extractor, + templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index", ): @@ -273,8 +208,8 @@ def get_template_extremum_amplitude( Parameters ---------- - templates_or_waveform_extractor: Templates | WaveformExtractor - A Templates or a WaveformExtractor object + templates_or_sorting_result: Templates | SortingResult + A Templates or a SortingResult object peak_sign: "neg" | "pos" | "both" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "at_index" @@ -289,12 +224,12 @@ def get_template_extremum_amplitude( """ assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" - unit_ids = templates_or_waveform_extractor.unit_ids - channel_ids = templates_or_waveform_extractor.channel_ids + unit_ids = templates_or_sorting_result.unit_ids + channel_ids = templates_or_sorting_result.channel_ids - extremum_channels_ids = get_template_extremum_channel(templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode) + extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_result, peak_sign=peak_sign, mode=mode) - extremum_amplitudes = get_template_amplitudes(templates_or_waveform_extractor, peak_sign=peak_sign, mode=mode) + extremum_amplitudes = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign, mode=mode) unit_amplitudes = {} for unit_id in unit_ids: diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index d7c25a63d5..3f702971d6 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -3,7 +3,7 @@ import numpy as np import json -from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity +from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, Templates from spikeinterface.core.core_tools import check_json from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import start_sorting_result @@ -197,15 +197,16 @@ def test_estimate_sparsity(): def test_compute_sparsity(): recording, sorting = get_dataset() - - # using SortingResult + sorting_result = start_sorting_result(sorting=sorting, recording=recording, sparse=False) sorting_result.select_random_spikes() sorting_result.compute("fast_templates", return_scaled=True) sorting_result.compute("noise_levels", return_scaled=True) + # this is needed for method="energy" sorting_result.compute("waveforms", return_scaled=True) - print(sorting_result) + + # using object SortingResult sparsity = compute_sparsity(sorting_result, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(sorting_result, method="radius", radius_um=50., peak_sign="neg") sparsity = compute_sparsity(sorting_result, method="snr", threshold=5, peak_sign="neg") @@ -213,11 +214,16 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_result, method="energy", threshold=5) sparsity = compute_sparsity(sorting_result, method="by_property", by_property="group") - # using Templates - # TODO later + # using object Templates + templates = sorting_result.get_extension("fast_templates").get_data(outputs="Templates") + noise_levels = sorting_result.get_extension("noise_levels").get_data() + sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") + sparsity = compute_sparsity(templates, method="radius", radius_um=50., peak_sign="neg") + sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") + sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) -if __name__ == "__main__": # test_ChannelSparsity() +if __name__ == "__main__": # test_ChannelSparsity() # test_estimate_sparsity() test_compute_sparsity() diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 5e36df5186..9afc3771aa 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -233,6 +233,6 @@ def test_estimate_templates(): if __name__ == "__main__": - # test_waveform_tools() - # test_estimate_templates_average() + test_waveform_tools() + test_estimate_templates_average() test_estimate_templates() \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index f79f5c3f08..e4b89c999a 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,7 +4,7 @@ import numpy as np -from spikeinterface.core import WaveformExtractor, get_template_channel_sparsity, get_template_extremum_channel +from spikeinterface.core import WaveformExtractor, get_template_extremum_channel from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive From da26f389adbcfbc97249a2f2d4ddd872b857eea8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 12 Feb 2024 19:12:38 +0100 Subject: [PATCH 054/192] Apply black --- .../comparison/groundtruthstudy.py | 1 - src/spikeinterface/comparison/hybrid.py | 2 - .../comparison/tests/test_hybrid.py | 3 +- src/spikeinterface/core/__init__.py | 13 +- src/spikeinterface/core/generate.py | 2 +- src/spikeinterface/core/node_pipeline.py | 13 +- src/spikeinterface/core/result_core.py | 83 ++++--- src/spikeinterface/core/sorting_tools.py | 4 +- src/spikeinterface/core/sortingresult.py | 221 ++++++++++-------- src/spikeinterface/core/sparsity.py | 68 ++++-- src/spikeinterface/core/template_tools.py | 24 +- .../core/tests/test_node_pipeline.py | 2 +- .../core/tests/test_result_core.py | 46 ++-- .../core/tests/test_sorting_tools.py | 9 +- .../core/tests/test_sortingresult.py | 27 ++- .../core/tests/test_sparsity.py | 10 +- .../core/tests/test_template_tools.py | 8 +- .../core/tests/test_waveform_tools.py | 7 +- ...forms_extractor_backwards_compatibility.py | 13 +- src/spikeinterface/core/waveform_tools.py | 27 ++- ...forms_extractor_backwards_compatibility.py | 74 +++--- src/spikeinterface/curation/auto_merge.py | 12 +- src/spikeinterface/curation/tests/common.py | 4 +- .../curation/tests/test_auto_merge.py | 10 +- .../curation/tests/test_remove_redundant.py | 1 - src/spikeinterface/exporters/tests/common.py | 2 +- .../exporters/tests/test_export_to_phy.py | 3 +- src/spikeinterface/exporters/to_phy.py | 6 +- src/spikeinterface/full.py | 9 +- .../postprocessing/amplitude_scalings.py | 29 +-- .../postprocessing/correlograms.py | 15 +- src/spikeinterface/postprocessing/isi.py | 1 - .../postprocessing/principal_component.py | 44 ++-- .../postprocessing/spike_amplitudes.py | 24 +- .../postprocessing/spike_locations.py | 33 ++- .../postprocessing/template_metrics.py | 12 +- .../postprocessing/template_similarity.py | 16 +- .../tests/common_extension_tests.py | 22 +- .../tests/test_amplitude_scalings.py | 3 +- .../postprocessing/tests/test_correlograms.py | 3 +- .../postprocessing/tests/test_isi.py | 1 - .../tests/test_principal_component.py | 7 +- .../tests/test_spike_amplitudes.py | 1 + .../tests/test_spike_locations.py | 15 +- .../tests/test_template_metrics.py | 2 - .../tests/test_template_similarity.py | 1 + .../tests/test_unit_localization.py | 5 +- .../postprocessing/unit_localization.py | 17 +- .../preprocessing/remove_artifacts.py | 17 +- .../qualitymetrics/misc_metrics.py | 36 ++- .../qualitymetrics/pca_metrics.py | 3 - .../quality_metric_calculator.py | 10 +- .../tests/test_metrics_functions.py | 53 +++-- .../qualitymetrics/tests/test_pca_metrics.py | 19 +- .../tests/test_quality_metric_calculator.py | 82 ++++--- .../sortingcomponents/matching/naive.py | 2 +- .../widgets/all_amplitudes_distributions.py | 11 +- src/spikeinterface/widgets/peak_activity.py | 1 - src/spikeinterface/widgets/sorting_summary.py | 14 +- .../widgets/spikes_on_traces.py | 15 +- .../widgets/template_metrics.py | 2 +- .../widgets/tests/test_widgets.py | 44 ++-- src/spikeinterface/widgets/unit_depths.py | 4 +- src/spikeinterface/widgets/unit_probe_map.py | 3 +- src/spikeinterface/widgets/unit_summary.py | 7 +- src/spikeinterface/widgets/unit_waveforms.py | 11 +- .../widgets/unit_waveforms_density_map.py | 4 +- 67 files changed, 732 insertions(+), 561 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 178273f90d..3c390434bb 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -303,7 +303,6 @@ def start_sorting_result_gt(self, case_keys=None, **kwargs): sorting_result.select_random_spikes(**select_params) sorting_result.compute("fast_templates", **job_kwargs) - def get_waveform_extractor(self, case_key=None, dataset_key=None): if case_key is not None: dataset_key = self.cases[case_key]["dataset"] diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index 3a87ab1832..aaf2898987 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -7,7 +7,6 @@ BaseRecording, BaseSorting, load_waveforms, - ) from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.core.generate import ( @@ -20,7 +19,6 @@ # TODO aurelien : this is still using the WaveformExtractor!!! can you change it to use SortingResult ? - class HybridUnitsRecording(InjectTemplatesRecording): """ Class for creating a hybrid recording where additional units are added diff --git a/src/spikeinterface/comparison/tests/test_hybrid.py b/src/spikeinterface/comparison/tests/test_hybrid.py index ab371a38bc..d1da0005f9 100644 --- a/src/spikeinterface/comparison/tests/test_hybrid.py +++ b/src/spikeinterface/comparison/tests/test_hybrid.py @@ -1,7 +1,7 @@ import pytest import shutil from pathlib import Path -from spikeinterface.core import extract_waveforms, load_waveforms,load_extractor +from spikeinterface.core import extract_waveforms, load_waveforms, load_extractor from spikeinterface.core.testing import check_recordings_equal from spikeinterface.comparison import ( create_hybrid_units_recording, @@ -38,7 +38,6 @@ def test_hybrid_units_recording(): print(wvf_extractor) print(wvf_extractor.sorting_result) - recording = wvf_extractor.recording templates = wvf_extractor.get_all_templates() templates[:, 0, :] = 0 diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 53925b08be..2a24e42f50 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -146,13 +146,16 @@ # SortingResult and ResultExtension from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result from .result_core import ( - ComputeWaveforms, compute_waveforms, - ComputeTemplates, compute_templates, - ComputeFastTemplates, compute_fast_templates, - ComputeNoiseLevels, compute_noise_levels, + ComputeWaveforms, + compute_waveforms, + ComputeTemplates, + compute_templates, + ComputeFastTemplates, + compute_fast_templates, + ComputeNoiseLevels, + compute_noise_levels, ) # Important not for compatibility!! # This wil be uncommented after 0.100 from .waveforms_extractor_backwards_compatibility import extract_waveforms, load_waveforms - diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 84b8b4a90e..1494a3b03e 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1259,7 +1259,7 @@ def generate_single_fake_waveform( bins = np.arange(-n, n + 1) smooth_kernel = np.exp(-(bins**2) / (2 * smooth_size**2)) smooth_kernel /= np.sum(smooth_kernel) - # smooth_kernel = smooth_kernel[4:] + # smooth_kernel = smooth_kernel[4:] wf = np.convolve(wf, smooth_kernel, mode="same") # ensure the the peak to be extatly at nbefore (smooth can modify this) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 93f06751bf..78e9a82cf0 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -172,7 +172,13 @@ class SpikeRetriever(PeakSource): """ def __init__( - self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg", + self, + recording, + sorting, + channel_from_template=True, + extremum_channel_inds=None, + radius_um=50, + peak_sign="neg", include_spikes_in_margin=False, ): PipelineNode.__init__(self, recording, return_output=False) @@ -199,7 +205,6 @@ def __init__( for segment_index in range(recording.get_num_segments()): i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) - def get_trace_margin(self): return 0 @@ -212,7 +217,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] if self.include_spikes_in_margin: - i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]) + i0, i1 = np.searchsorted( + peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin] + ) else: i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 1feabf8af1..60dab77994 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -14,12 +14,14 @@ from .recording_tools import get_noise_levels from .template import Templates + class ComputeWaveforms(ResultExtension): """ ResultExtension that extract some waveforms of each units. The sparsity is controlled by the SortingResult sparsity. """ + extension_name = "waveforms" depend_on = [] need_recording = True @@ -47,7 +49,7 @@ def _run(self, **job_kwargs): # retrieve spike vector and the sampling spikes = sorting.to_spike_vector() some_spikes = spikes[self.sorting_result.random_spikes_indices] - + if self.format == "binary_folder": # in that case waveforms are extacted directly in files file_path = self._get_binary_extension_folder() / "waveforms.npy" @@ -81,12 +83,13 @@ def _run(self, **job_kwargs): self.data["waveforms"] = all_waveforms - def _set_params(self, - ms_before: float = 1.0, - ms_after: float = 2.0, - return_scaled: bool = True, - dtype=None, - ): + def _set_params( + self, + ms_before: float = 1.0, + ms_after: float = 2.0, + return_scaled: bool = True, + dtype=None, + ): recording = self.sorting_result.recording if dtype is None: dtype = recording.get_dtype() @@ -121,17 +124,21 @@ def _select_extension_data(self, unit_ids): return new_data - def get_waveforms_one_unit(self, unit_id, force_dense: bool = False,): + def get_waveforms_one_unit( + self, + unit_id, + force_dense: bool = False, + ): sorting = self.sorting_result.sorting unit_index = sorting.id_to_index(unit_id) spikes = sorting.to_spike_vector() some_spikes = spikes[self.sorting_result.random_spikes_indices] spike_mask = some_spikes["unit_index"] == unit_index wfs = self.data["waveforms"][spike_mask, :, :] - + if self.sorting_result.sparsity is not None: chan_inds = self.sorting_result.sparsity.unit_id_to_channel_indices[unit_id] - wfs = wfs[:, :, :chan_inds.size] + wfs = wfs[:, :, : chan_inds.size] if force_dense: num_channels = self.get_num_channels() dense_wfs = np.zeros((wfs.shape[0], wfs.shape[1], num_channels), dtype=wfs.dtype) @@ -144,8 +151,6 @@ def _get_data(self): return self.data["waveforms"] - - compute_waveforms = ComputeWaveforms.function_factory() register_result_extension(ComputeWaveforms) @@ -153,23 +158,24 @@ def _get_data(self): class ComputeTemplates(ResultExtension): """ ResultExtension that compute templates (average, str, median, percentile, ...) - + This must be run after "waveforms" extension (`SortingResult.compute("waveforms")`) Note that when "waveforms" is already done, then the recording is not needed anymore for this extension. Note: by default only the average is computed. Other operator (std, median, percentile) can be computed on demand after the SortingResult.compute("templates") and then the data dict is updated on demand. - + """ + extension_name = "templates" depend_on = ["waveforms"] need_recording = False use_nodepipeline = False need_job_kwargs = False - def _set_params(self, operators = ["average", "std"]): + def _set_params(self, operators=["average", "std"]): assert isinstance(operators, list) for operator in operators: if isinstance(operator, str): @@ -186,21 +192,20 @@ def _set_params(self, operators = ["average", "std"]): nbefore=waveforms_extension.nbefore, nafter=waveforms_extension.nafter, return_scaled=waveforms_extension.params["return_scaled"], - ) + ) return params def _run(self): self._compute_and_append(self.params["operators"]) - def _compute_and_append(self, operators): unit_ids = self.sorting_result.unit_ids channel_ids = self.sorting_result.channel_ids waveforms_extension = self.sorting_result.get_extension("waveforms") waveforms = waveforms_extension.data["waveforms"] - + num_samples = waveforms.shape[1] - + for operator in operators: if isinstance(operator, str) and operator in ("average", "std", "median"): key = operator @@ -219,7 +224,7 @@ def _compute_and_append(self, operators): wfs = waveforms[spike_mask, :, :] if wfs.shape[0] == 0: continue - + for operator in operators: if operator == "average": arr = np.average(wfs, axis=0) @@ -239,7 +244,7 @@ def _compute_and_append(self, operators): self.data[key][unit_index, :, :] = arr else: channel_indices = self.sparsity.unit_id_to_channel_indices[unit_id] - self.data[key][unit_index, :, :][:, channel_indices] = arr[:, :channel_indices.size] + self.data[key][unit_index, :, :][:, channel_indices] = arr[:, : channel_indices.size] @property def nbefore(self): @@ -264,7 +269,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): else: assert percentile is not None, "You must provide percentile=..." key = f"pencentile_{percentile}" - + templates_array = self.data[key] if outputs == "numpy": @@ -280,7 +285,6 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): ) else: raise ValueError("outputs must be numpy or Templates") - def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True): """ @@ -331,7 +335,6 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save return np.array(templates) - compute_templates = ComputeTemplates.function_factory() register_result_extension(ComputeTemplates) @@ -339,8 +342,9 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save class ComputeFastTemplates(ResultExtension): """ ResultExtension which is similar to the extension "templates" (ComputeTemplates) **but only for average**. - This is way faster because it do not need "waveforms" to be computed first. + This is way faster because it do not need "waveforms" to be computed first. """ + extension_name = "fast_templates" depend_on = [] need_recording = True @@ -368,17 +372,20 @@ def _run(self, **job_kwargs): # retrieve spike vector and the sampling spikes = sorting.to_spike_vector() some_spikes = spikes[self.sorting_result.random_spikes_indices] - + return_scaled = self.params["return_scaled"] # TODO jobw_kwargs - self.data["average"] = estimate_templates_average(recording, some_spikes, unit_ids, self.nbefore, self.nafter, return_scaled=return_scaled, **job_kwargs) - - def _set_params(self, - ms_before: float = 1.0, - ms_after: float = 2.0, - return_scaled: bool = True, - ): + self.data["average"] = estimate_templates_average( + recording, some_spikes, unit_ids, self.nbefore, self.nafter, return_scaled=return_scaled, **job_kwargs + ) + + def _set_params( + self, + ms_before: float = 1.0, + ms_after: float = 2.0, + return_scaled: bool = True, + ): params = dict( ms_before=float(ms_before), ms_after=float(ms_after), @@ -403,7 +410,6 @@ def _get_data(self, outputs="numpy"): else: raise ValueError("outputs must be numpy or Templates") - def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) @@ -412,6 +418,7 @@ def _select_extension_data(self, unit_ids): return new_data + compute_fast_templates = ComputeFastTemplates.function_factory() register_result_extension(ComputeFastTemplates) @@ -439,6 +446,7 @@ class ComputeNoiseLevels(ResultExtension): noise_levels: np.array noise level vector. """ + extension_name = "noise_levels" depend_on = [] need_recording = True @@ -449,7 +457,9 @@ def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, return_scaled=True, seed=None): - params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, return_scaled=return_scaled, seed=seed) + params = dict( + num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, return_scaled=return_scaled, seed=seed + ) return params def _select_extension_data(self, unit_ids): @@ -457,7 +467,7 @@ def _select_extension_data(self, unit_ids): return self.data def _run(self): - self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) + self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) def _get_data(self): return self.data["noise_levels"] @@ -465,4 +475,3 @@ def _get_data(self): register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() - diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index bccf313c87..2b4af70ebf 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -46,6 +46,7 @@ def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.arra return spike_trains + def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array): """ Similar to spike_vector_to_spike_trains but instead having the spike_trains (aka spike times) return @@ -90,8 +91,6 @@ def spike_vector_to_indices(spike_vector: list[np.array], unit_ids: np.array): return spike_indices - - def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units): """ Slower implementation of vetor_to_dict using numpy boolean mask. @@ -102,6 +101,7 @@ def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units): spike_trains.append(sample_indices[unit_indices == u]) return spike_trains + def get_numba_vector_to_list_of_spiketrain(): if hasattr(get_numba_vector_to_list_of_spiketrain, "_cached_numba_function"): return get_numba_vector_to_list_of_spiketrain._cached_numba_function diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index f4fc6c1a71..a95bec94a5 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -30,12 +30,10 @@ from .node_pipeline import run_node_pipeline - - # high level function -def start_sorting_result(sorting, recording, format="memory", folder=None, - sparse=True, sparsity=None, **sparsity_kwargs - ): +def start_sorting_result( + sorting, recording, format="memory", folder=None, sparse=True, sparsity=None, **sparsity_kwargs +): """ Create a SortingResult by pairing a Sorting and the corresponding Recording. @@ -96,15 +94,18 @@ def start_sorting_result(sorting, recording, format="memory", folder=None, if sparsity is not None: # some checks assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object" - assert np.array_equal(sorting.unit_ids, sparsity.unit_ids), "start_sorting_result(): if external sparsity is given unit_ids must correspond" - assert np.array_equal(recording.channel_ids, recording.channel_ids), "start_sorting_result(): if external sparsity is given unit_ids must correspond" + assert np.array_equal( + sorting.unit_ids, sparsity.unit_ids + ), "start_sorting_result(): if external sparsity is given unit_ids must correspond" + assert np.array_equal( + recording.channel_ids, recording.channel_ids + ), "start_sorting_result(): if external sparsity is given unit_ids must correspond" elif sparse: - sparsity = estimate_sparsity( recording, sorting, **sparsity_kwargs) + sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs) else: sparsity = None - sorting_result = SortingResult.create( - sorting, recording, format=format, folder=folder, sparsity=sparsity) + sorting_result = SortingResult.create(sorting, recording, format=format, folder=folder, sparsity=sparsity) return sorting_result @@ -150,7 +151,10 @@ class SortingResult: the SortingResult object can be reload even if references to the original sorting and/or to the original recording are lost. """ - def __init__(self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None, random_spikes_indices=None): + + def __init__( + self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None, random_spikes_indices=None + ): # very fast init because checks are done in load and create self.sorting = sorting # self.recorsding will be a property @@ -180,13 +184,18 @@ def __repr__(self) -> str: ## create and load zone @classmethod - def create(cls, - sorting: BaseSorting, - recording: BaseRecording, - format: Literal["memory", "binary_folder", "zarr", ] = "memory", - folder=None, - sparsity=None, - ): + def create( + cls, + sorting: BaseSorting, + recording: BaseRecording, + format: Literal[ + "memory", + "binary_folder", + "zarr", + ] = "memory", + folder=None, + sparsity=None, + ): # some checks assert sorting.sampling_frequency == recording.sampling_frequency # check that multiple probes are non-overlapping @@ -223,13 +232,12 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): format = "zarr" else: format = "binary_folder" - + if format == "binary_folder": sortres = SortingResult.load_from_binary_folder(folder, recording=recording) elif format == "zarr": sortres = SortingResult.load_from_zarr(folder, recording=recording) - - + sortres.folder = folder if load_extensions: @@ -251,8 +259,9 @@ def create_memory(cls, sorting, recording, sparsity, rec_attributes): # a copy of sorting is created directly in shared memory format to avoid further duplication of spikes. sorting_copy = SharedMemorySorting.from_sorting(sorting, with_metadata=True) - sortres = SortingResult(sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, - format="memory", sparsity=sparsity) + sortres = SortingResult( + sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, format="memory", sparsity=sparsity + ) return sortres @classmethod @@ -266,7 +275,6 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attribut raise ValueError(f"Folder already exists {folder}") folder.mkdir(parents=True) - info_file = folder / f"spikeinterface_info.json" info = dict( version=spikeinterface.__version__, @@ -320,7 +328,7 @@ def load_from_binary_folder(cls, folder, recording=None): # load internal sorting copy and make it sharedmem sorting = SharedMemorySorting.from_sorting(NumpyFolderSorting(folder / "sorting"), with_metadata=True) - + # load recording if possible if recording is None: # try to load the recording if not provided @@ -350,7 +358,7 @@ def load_from_binary_folder(cls, folder, recording=None): rec_attributes["probegroup"] = probeinterface.read_probeinterface(probegroup_file) else: rec_attributes["probegroup"] = None - + # sparsity # sparsity_file = folder / "sparsity.json" sparsity_file = folder / "sparsity_mask.npy" @@ -374,12 +382,14 @@ def load_from_binary_folder(cls, folder, recording=None): rec_attributes=rec_attributes, format="binary_folder", sparsity=sparsity, - random_spikes_indices=random_spikes_indices) + random_spikes_indices=random_spikes_indices, + ) return sortres - + def _get_zarr_root(self, mode="r+"): import zarr + zarr_root = zarr.open(self.folder, mode=mode) return zarr_root @@ -399,13 +409,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): zarr_root = zarr.open(folder, mode="w") - info = dict( - version=spikeinterface.__version__, - dev_mode=spikeinterface.DEV_MODE, - object="SortingResult" - ) + info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingResult") zarr_root.attrs["spikeinterface_info"] = check_json(info) - + # the recording rec_dict = recording.to_dict(relative_to=folder, recursive=True) zarr_rec = np.array([rec_dict], dtype=object) @@ -416,8 +422,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) else: - warnings.warn("SortingResult with zarr : the Recording is not json serializable, the recording link will be lost for futur load") - + warnings.warn( + "SortingResult with zarr : the Recording is not json serializable, the recording link will be lost for futur load" + ) # sorting provenance sort_dict = sorting.to_dict(relative_to=folder, recursive=True) @@ -442,7 +449,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): else: rec_attributes = rec_attributes.copy() probegroup = rec_attributes.pop("probegroup") - + recording_info.attrs["recording_attributes"] = check_json(rec_attributes) # recording_info.create_dataset("recording_attributes", data=check_json(rec_attributes), object_codec=numcodecs.JSON()) @@ -457,6 +464,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # write sorting copy from .zarrextractors import add_sorting_to_zarr_group + # Alessio : we need to find a way to propagate compressor for all steps. # kwargs = dict(compressor=...) zarr_kwargs = dict() @@ -464,10 +472,10 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): recording_info = zarr_root.create_group("extensions") - @classmethod def load_from_zarr(cls, folder, recording=None): import zarr + folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -475,13 +483,15 @@ def load_from_zarr(cls, folder, recording=None): # load internal sorting and make it sharedmem # TODO propagate storage_options - sorting = SharedMemorySorting.from_sorting(ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True) - + sorting = SharedMemorySorting.from_sorting( + ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True + ) + # load recording if possible if recording is None: rec_dict = zarr_root["recording"][0] try: - + recording = load_extractor(rec_dict, base_folder=folder) except: recording = None @@ -518,11 +528,11 @@ def load_from_zarr(cls, folder, recording=None): rec_attributes=rec_attributes, format="zarr", sparsity=sparsity, - random_spikes_indices=random_spikes_indices) + random_spikes_indices=random_spikes_indices, + ) return sortres - def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingResult": """ Internal used by both save_as(), copy() and select_units() which are more or less the same. @@ -532,7 +542,7 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> recording = self.recording else: recording = None - + if self.sparsity is not None and unit_ids is None: sparsity = self.sparsity elif self.sparsity is not None and unit_ids is not None: @@ -546,7 +556,7 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> if sorting_provenance is None: # if the original sorting objetc is not available anymore (kilosort folder deleted, ....), take the copy sorting_provenance = self.sorting - + if unit_ids is not None: # when only some unit_ids then the sorting must be sliced # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! @@ -582,7 +592,7 @@ def save_as(self, format="memory", folder=None) -> "SortingResult": """ Save SortingResult object into another format. Uselfull for memory to zarr or memory to binray. - + Note that the recording provenance or sorting provenance can be lost. Mainly propagate the copied sorting and recording property. @@ -596,7 +606,6 @@ def save_as(self, format="memory", folder=None) -> "SortingResult": """ return self._save_or_select(format=format, folder=folder, unit_ids=None) - def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult": """ This method is equivalent to `save_as()`but with a subset of units. @@ -610,7 +619,7 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult The unit ids to keep in the new WaveformExtractor object folder : Path or None The new folder where selected waveforms are copied - format: + format: a Returns ------- @@ -631,7 +640,6 @@ def is_read_only(self) -> bool: return False return not os.access(self.folder, os.W_OK) - ## map attribute and property zone @property @@ -657,14 +665,14 @@ def has_recording(self) -> bool: def is_sparse(self) -> bool: return self.sparsity is not None - + def get_sorting_provenance(self): """ Get the original sorting if possible otherwise return None """ if self.format == "memory": # the orginal sorting provenance is not keps in that case - sorting_provenance = None + sorting_provenance = None elif self.format == "binary_folder": for type in ("json", "pickle"): @@ -735,14 +743,12 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) - + def get_dtype(self): return self.rec_attributes["dtype"] ## extensions zone - - def compute(self, input, save=True, **kwargs): """ Compute one extension or several extension. @@ -775,7 +781,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): If not then the extension will only live in memory as long as the object is deleted. save=False is convinient to try some parameters without changing an already saved extension. - **kwargs: + **kwargs: All other kwargs are transimited to extension.set_params() or job_kwargs Returns @@ -792,11 +798,8 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): """ - - extension_class = get_extension_class(extension_name) - if extension_class.need_job_kwargs: params, job_kwargs = split_job_kwargs(kwargs) else: @@ -809,15 +812,15 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): for dependency_name in extension_class.depend_on: if "|" in dependency_name: # at least one extension must be done : usefull for "templates|fast_templates" for instance - ok = any(self.get_extension(name) is not None for name in dependency_name.split("|")) + ok = any(self.get_extension(name) is not None for name in dependency_name.split("|")) else: ok = self.get_extension(dependency_name) is not None assert ok, f"Extension {extension_name} need {dependency_name} to be computed first" - + extension_instance = extension_class(self) extension_instance.set_params(save=save, **params) extension_instance.run(save=save, **job_kwargs) - + self.extensions[extension_name] = extension_instance # TODO : need discussion @@ -858,7 +861,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): if not extension_class.use_nodepipeline: pipeline_mode = False break - + if not pipeline_mode: # simple loop for extension_name, extension_params in extensions.items(): @@ -868,7 +871,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): else: self.compute_one_extension(extension_name, save=save, **extension_params) else: - + all_nodes = [] result_routage = [] extension_instances = {} @@ -895,15 +898,11 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): extension_name, variable_name = result_routage[r] extension_instances[extension_name].data[variable_name] = result - for extension_name, extension_instance in extension_instances.items(): self.extensions[extension_name] = extension_instance if save: extension_instance.save() - - - def get_saved_extension_names(self): """ Get extension saved in folder or zarr that can be loaded. @@ -921,13 +920,16 @@ def get_saved_extension_names(self): saved_extension_names = [] for extension_class in _possible_extensions: extension_name = extension_class.extension_name - + if self.format == "binary_folder": - extension_folder = self.folder / "extensions" /extension_name + extension_folder = self.folder / "extensions" / extension_name is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file() elif self.format == "zarr": if extension_group is not None: - is_saved = extension_name in extension_group.keys() and "params" in extension_group[extension_name].attrs.keys() + is_saved = ( + extension_name in extension_group.keys() + and "params" in extension_group[extension_name].attrs.keys() + ) else: is_saved = False if is_saved: @@ -941,7 +943,7 @@ def get_extension(self, extension_name: str): If not loaded then load is automatic. Return None if the extension is not computed yet (this avoid the use of has_extension() and then get it) - + """ if extension_name in self.extensions: return self.extensions[extension_name] @@ -949,7 +951,7 @@ def get_extension(self, extension_name: str): if self.has_extension(extension_name): self.load_extension(extension_name) return self.extensions[extension_name] - + return None def load_extension(self, extension_name: str): @@ -967,7 +969,9 @@ def load_extension(self, extension_name: str): The loaded instance of the extension """ - assert self.format != "memory", "SortingResult.load_extension() do not work for format='memory' use SortingResult.get_extension()instead" + assert ( + self.format != "memory" + ), "SortingResult.load_extension() do not work for format='memory' use SortingResult.get_extension()instead" extension_class = get_extension_class(extension_name) @@ -1005,7 +1009,7 @@ def get_loaded_extension_names(self): Return the loaded or already computed extensions names. """ return list(self.extensions.keys()) - + def has_extension(self, extension_name: str) -> bool: """ Check if the extension exists in memory (dict) or in the folder or in zarr. @@ -1024,7 +1028,9 @@ def select_random_spikes(self, **random_kwargs): # random_spikes_indices is a vector that refer to the spike vector of the sorting in absolut index assert self.random_spikes_indices is None, "select random spikes is already computed" - self.random_spikes_indices = random_spikes_selection(self.sorting, self.rec_attributes["num_samples"], **random_kwargs) + self.random_spikes_indices = random_spikes_selection( + self.sorting, self.rec_attributes["num_samples"], **random_kwargs + ) if self.format == "binary_folder": np.save(self.folder / "random_spikes_indices.npy", self.random_spikes_indices) @@ -1038,8 +1044,12 @@ def get_selected_indices_in_spike_train(self, unit_id, segment_index): assert self.random_spikes_indices is not None, "random spikes selection is not computeds" unit_index = self.sorting.id_to_index(unit_id) spikes = self.sorting.to_spike_vector() - spike_indices_in_seg = np.flatnonzero((spikes["segment_index"] == segment_index) & (spikes["unit_index"] == unit_index)) - common_element, inds_left, inds_right = np.intersect1d(spike_indices_in_seg, self.random_spikes_indices, return_indices=True) + spike_indices_in_seg = np.flatnonzero( + (spikes["segment_index"] == segment_index) & (spikes["unit_index"] == unit_index) + ) + common_element, inds_left, inds_right = np.intersect1d( + spike_indices_in_seg, self.random_spikes_indices, return_indices=True + ) selected_spikes_in_spike_train = inds_left return selected_spikes_in_spike_train @@ -1047,6 +1057,7 @@ def get_selected_indices_in_spike_train(self, unit_id, segment_index): global _possible_extensions _possible_extensions = [] + def register_result_extension(extension_class): """ This maintains a list of possible extensions that are available. @@ -1055,7 +1066,7 @@ def register_result_extension(extension_class): For instance with: import spikeinterface as si only one extension will be available - but with + but with import spikeinterface.postprocessing more extensions will be available """ @@ -1088,7 +1099,9 @@ def get_extension_class(extension_name: str): """ global _possible_extensions extensions_dict = {ext.extension_name: ext for ext in _possible_extensions} - assert extension_name in extensions_dict, f"Extension '{extension_name}' is not registered, please import related module before" + assert ( + extension_name in extensions_dict + ), f"Extension '{extension_name}' is not registered, please import related module before" ext_class = extensions_dict[extension_name] return ext_class @@ -1125,10 +1138,10 @@ class ResultExtension: All ResultExtension will have a function associate for instance (this use the function_factory): comptute_unit_location(sorting_result, ...) will be equivalent to sorting_result.compute("unit_location", ...) - - """ - + + """ + extension_name = None depend_on = [] need_recording = False @@ -1158,7 +1171,7 @@ def _set_params(self, **params): def _select_extension_data(self, unit_ids): # must be implemented in subclass raise NotImplementedError - + def _get_pipeline_nodes(self): # must be implemented in subclass only if use_nodepipeline=True raise NotImplementedError @@ -1167,7 +1180,7 @@ def _get_data(self): # must be implemented in subclass raise NotImplementedError - # + # ####### @classmethod @@ -1180,9 +1193,10 @@ def function_factory(cls): class FuncWrapper: def __init__(self, extension_name): self.extension_name = extension_name + def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): from .waveforms_extractor_backwards_compatibility import MockWaveformExtractor - + if isinstance(sorting_result, MockWaveformExtractor): # backward compatibility with WaveformsExtractor sorting_result = sorting_result.sorting_result @@ -1192,12 +1206,13 @@ def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): if load_if_exists is not None: # backward compatibility with "load_if_exists" - warnings.warn(f"compute_{cls.extension_name}(..., load_if_exists=True/False) is kept for backward compatibility but should not be used anymore") + warnings.warn( + f"compute_{cls.extension_name}(..., load_if_exists=True/False) is kept for backward compatibility but should not be used anymore" + ) assert isinstance(load_if_exists, bool) if load_if_exists: ext = sorting_result.get_extension(self.extension_name) return ext - ext = sorting_result.compute(cls.extension_name, *args, **kwargs) return ext.get_data() @@ -1222,7 +1237,7 @@ def sorting_result(self): @property def format(self): return self.sorting_result.format - + @property def sparsity(self): return self.sorting_result.sparsity @@ -1230,13 +1245,12 @@ def sparsity(self): @property def folder(self): return self.sorting_result.folder - + def _get_binary_extension_folder(self): - extension_folder = self.folder / "extensions" /self.extension_name + extension_folder = self.folder / "extensions" / self.extension_name return extension_folder - - def _get_zarr_extension_group(self, mode='r+'): + def _get_zarr_extension_group(self, mode="r+"): zarr_root = self.sorting_result._get_zarr_root(mode=mode) extension_group = zarr_root["extensions"][self.extension_name] return extension_group @@ -1257,7 +1271,7 @@ def load_params(self): params = json.load(f) elif self.format == "zarr": - extension_group = self._get_zarr_extension_group(mode='r') + extension_group = self._get_zarr_extension_group(mode="r") assert "params" in extension_group.attrs, f"No params file in extension {self.extension_name} folder" params = extension_group.attrs["params"] @@ -1280,25 +1294,27 @@ def load_data(self): ext_data = np.load(ext_data_file) elif ext_data_file.suffix == ".csv": import pandas as pd + ext_data = pd.read_csv(ext_data_file, index_col=0) elif ext_data_file.suffix == ".pkl": ext_data = pickle.load(ext_data_file.open("rb")) else: continue self.data[ext_data_name] = ext_data - + elif self.format == "zarr": # Alessio # TODO: we need decide if we make a copy to memory or keep the lazy loading. For binary_folder it used to be lazy with memmap # but this make the garbage complicated when a data is hold by a plot but the o SortingResult is delete # lets talk - extension_group = self._get_zarr_extension_group(mode='r') + extension_group = self._get_zarr_extension_group(mode="r") for ext_data_name in extension_group.keys(): ext_data_ = extension_group[ext_data_name] if "dict" in ext_data_.attrs: ext_data = ext_data_[0] elif "dataframe" in ext_data_.attrs: import xarray + ext_data = xarray.open_zarr( ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" ).to_pandas() @@ -1340,7 +1356,7 @@ def _save_data(self, **kwargs): if self.sorting_result.is_read_only(): raise ValueError(f"The SortingResult is read only save extension {self.extension_name} is not possible") - + if self.format == "binary_folder": import pandas as pd @@ -1366,16 +1382,16 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - + import pandas as pd import numcodecs - + extension_group = self._get_zarr_extension_group(mode="r+") compressor = kwargs.get("compressor", None) if compressor is None: compressor = get_default_zarr_compressor() - + for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: del extension_group[ext_data_name] @@ -1414,6 +1430,7 @@ def _reset_extension_folder(self): elif self.format == "zarr": import zarr + zarr_root = zarr.open(self.folder, mode="r+") extension_group = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) @@ -1426,7 +1443,6 @@ def reset(self): self.params = None self.data = dict() - def set_params(self, save=True, **params): """ Set parameters for the extension and @@ -1457,7 +1473,6 @@ def _save_params(self): # ), "'sparsity' parameter must be a ChannelSparsity object!" # params_to_save["sparsity"] = params_to_save["sparsity"].to_dict() - if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() extension_folder.mkdir(exist_ok=True, parents=True) @@ -1468,7 +1483,9 @@ def _save_params(self): extension_group.attrs["params"] = check_json(params_to_save) def get_pipeline_nodes(self): - assert self.use_nodepipeline, "ResultExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True" + assert ( + self.use_nodepipeline + ), "ResultExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True" return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 62bf75c25d..b7440b79d6 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -10,7 +10,6 @@ from .waveform_tools import estimate_templates_average - _sparsity_doc = """ method: str * "best_channels": N best channels with the largest amplitude. Use the "num_channels" argument to specify the @@ -276,7 +275,9 @@ def from_best_channels(cls, templates_or_sorting_result, num_channels, peak_sign """ from .template_tools import get_template_amplitudes - mask = np.zeros((templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool") + mask = np.zeros( + (templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool" + ) peak_values = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign) for unit_ind, unit_id in enumerate(templates_or_sorting_result.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] @@ -292,7 +293,9 @@ def from_radius(cls, templates_or_sorting_result, radius_um, peak_sign="neg"): """ from .template_tools import get_template_extremum_channel - mask = np.zeros((templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool") + mask = np.zeros( + (templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool" + ) channel_locations = templates_or_sorting_result.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) best_chan = get_template_extremum_channel(templates_or_sorting_result, peak_sign=peak_sign, outputs="index") @@ -312,8 +315,9 @@ def from_snr(cls, templates_or_sorting_result, threshold, noise_levels=None, pea from .sortingresult import SortingResult from .template import Templates - - assert templates_or_sorting_result.sparsity is None, "To compute sparsity you need a dense SortingResult or Templates" + assert ( + templates_or_sorting_result.sparsity is None + ), "To compute sparsity you need a dense SortingResult or Templates" unit_ids = templates_or_sorting_result.unit_ids channel_ids = templates_or_sorting_result.channel_ids @@ -321,14 +325,18 @@ def from_snr(cls, templates_or_sorting_result, threshold, noise_levels=None, pea if isinstance(templates_or_sorting_result, SortingResult): ext = templates_or_sorting_result.get_extension("noise_levels") assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" + assert ext.params[ + "return_scaled" + ], "To compute sparsity from snr you need return_scaled=True for extensions" noise_levels = ext.data["noise_levels"] elif isinstance(templates_or_sorting_result, Templates): assert noise_levels is not None mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - peak_values = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign, mode="extremum", return_scaled=True) + peak_values = get_template_amplitudes( + templates_or_sorting_result, peak_sign=peak_sign, mode="extremum", return_scaled=True + ) for unit_ind, unit_id in enumerate(unit_ids): chan_inds = np.nonzero((np.abs(peak_values[unit_id]) / noise_levels) >= threshold) @@ -342,7 +350,9 @@ def from_ptp(cls, templates_or_sorting_result, threshold, noise_levels=None): Use the "threshold" argument to specify the SNR threshold. """ - assert templates_or_sorting_result.sparsity is None, "To compute sparsity you need a dense SortingResult or Templates" + assert ( + templates_or_sorting_result.sparsity is None + ), "To compute sparsity you need a dense SortingResult or Templates" from .template_tools import get_template_amplitudes from .sortingresult import SortingResult @@ -354,12 +364,15 @@ def from_ptp(cls, templates_or_sorting_result, threshold, noise_levels=None): if isinstance(templates_or_sorting_result, SortingResult): ext = templates_or_sorting_result.get_extension("noise_levels") assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" + assert ext.params[ + "return_scaled" + ], "To compute sparsity from snr you need return_scaled=True for extensions" noise_levels = ext.data["noise_levels"] elif isinstance(templates_or_sorting_result, Templates): assert noise_levels is not None from .template_tools import _get_dense_templates_array + mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") templates_array = _get_dense_templates_array(templates_or_sorting_result, return_scaled=True) @@ -376,7 +389,7 @@ def from_energy(cls, sorting_result, threshold): Construct sparsity from a threshold based on per channel energy ratio. Use the "threshold" argument to specify the SNR threshold. """ - assert sorting_result.sparsity is None, "To compute sparsity with energy you need a dense SortingResult" + assert sorting_result.sparsity is None, "To compute sparsity with energy you need a dense SortingResult" mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") @@ -407,8 +420,12 @@ def from_property(cls, sorting_result, by_property): Use the "by_property" argument to specify the property name. """ # check consistency - assert by_property in sorting_result.recording.get_property_keys(), f"Property {by_property} is not a recording property" - assert by_property in sorting_result.sorting.get_property_keys(), f"Property {by_property} is not a sorting property" + assert ( + by_property in sorting_result.recording.get_property_keys() + ), f"Property {by_property} is not a recording property" + assert ( + by_property in sorting_result.sorting.get_property_keys() + ), f"Property {by_property} is not a sorting property" mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") rec_by = sorting_result.recording.split_by(by_property) @@ -469,13 +486,18 @@ def compute_sparsity( templates_or_sorting_result = templates_or_sorting_result.sorting_result if method in ("best_channels", "radius", "snr", "ptp"): - assert isinstance(templates_or_sorting_result, (Templates, SortingResult)), f"compute_sparsity(method='{method}') need Templates or SortingResult" + assert isinstance( + templates_or_sorting_result, (Templates, SortingResult) + ), f"compute_sparsity(method='{method}') need Templates or SortingResult" else: - assert isinstance(templates_or_sorting_result, SortingResult), f"compute_sparsity(method='{method}') need SortingResult" - - if method in ("snr", "ptp") and isinstance(templates_or_sorting_result, Templates): - assert noise_levels is not None, f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" + assert isinstance( + templates_or_sorting_result, SortingResult + ), f"compute_sparsity(method='{method}') need SortingResult" + if method in ("snr", "ptp") and isinstance(templates_or_sorting_result, Templates): + assert ( + noise_levels is not None + ), f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" @@ -485,10 +507,16 @@ def compute_sparsity( sparsity = ChannelSparsity.from_radius(templates_or_sorting_result, radius_um, peak_sign=peak_sign) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_snr(templates_or_sorting_result, threshold, noise_levels=noise_levels, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_snr( + templates_or_sorting_result, threshold, noise_levels=noise_levels, peak_sign=peak_sign + ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp(templates_or_sorting_result, threshold, noise_levels=noise_levels, ) + sparsity = ChannelSparsity.from_ptp( + templates_or_sorting_result, + threshold, + noise_levels=noise_levels, + ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_energy(templates_or_sorting_result, threshold) @@ -588,7 +616,7 @@ def estimate_sparsity( nafter, return_scaled=False, job_name="estimate_sparsity", - **job_kwargs + **job_kwargs, ) templates = Templates( templates_array=templates_array, diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 2a931a8c88..098b8d7237 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -14,13 +14,17 @@ def _get_dense_templates_array(one_object, return_scaled=True): elif isinstance(one_object, SortingResult): ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data["average"] - assert return_scaled == ext.params["return_scaled"], f"templates have been extracted with return_scaled={not return_scaled} you cannot get then with return_scaled={return_scaled}" + templates_array = ext.data["average"] + assert ( + return_scaled == ext.params["return_scaled"] + ), f"templates have been extracted with return_scaled={not return_scaled} you cannot get then with return_scaled={return_scaled}" else: ext = one_object.get_extension("fast_templates") - assert return_scaled == ext.params["return_scaled"], f"fast_templates have been extracted with return_scaled={not return_scaled} you cannot get then with return_scaled={return_scaled}" + assert ( + return_scaled == ext.params["return_scaled"] + ), f"fast_templates have been extracted with return_scaled={not return_scaled} you cannot get then with return_scaled={return_scaled}" if ext is not None: - templates_array = ext.data["average"] + templates_array = ext.data["average"] else: raise ValueError("SortingResult need extension 'templates' or 'fast_templates' to be computed") else: @@ -28,6 +32,7 @@ def _get_dense_templates_array(one_object, return_scaled=True): return templates_array + def _get_nbefore(one_object): if isinstance(one_object, Templates): return one_object.nbefore @@ -43,10 +48,11 @@ def _get_nbefore(one_object): raise ValueError("Input should be Templates or SortingResult or SortingResult") - - def get_template_amplitudes( - templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", return_scaled: bool = True + templates_or_sorting_result, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "extremum", + return_scaled: bool = True, ): """ Get amplitude per channel for each unit. @@ -150,9 +156,7 @@ def get_template_extremum_channel( return extremum_channels_index -def get_template_extremum_channel_peak_shift( - templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg" -): +def get_template_extremum_channel_peak_shift(templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg"): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 2e25aa618e..a4ae651de6 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -14,7 +14,7 @@ PipelineNode, ExtractDenseWaveforms, sorting_to_peaks, - spike_peak_dtype + spike_peak_dtype, ) diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_result_core.py index c19ff6c4e8..4c562ffb0d 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_result_core.py @@ -16,7 +16,10 @@ def get_sorting_result(format="memory", sparse=True): recording, sorting = generate_ground_truth_recording( - durations=[30.0], sampling_frequency=16000.0, num_channels=20, num_units=5, + durations=[30.0], + sampling_frequency=16000.0, + num_channels=20, + num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), generate_unit_locations_kwargs=dict( margin_um=5.0, @@ -39,8 +42,10 @@ def get_sorting_result(format="memory", sparse=True): folder = cache_folder / f"test_ComputeWaveforms.zarr" if folder and folder.exists(): shutil.rmtree(folder) - - sorting_result = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=None) + + sorting_result = start_sorting_result( + sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=None + ) return sorting_result @@ -48,7 +53,7 @@ def get_sorting_result(format="memory", sparse=True): def _check_result_extension(sorting_result, extension_name): # select unit_ids to several format for format in ("memory", "binary_folder", "zarr"): - # for format in ("memory", ): + # for format in ("memory", ): if format != "memory": if format == "zarr": folder = cache_folder / f"test_SortingResult_{extension_name}_select_units_with_{format}.zarr" @@ -86,32 +91,37 @@ def test_ComputeTemplates(format, sparse): sorting_result = get_sorting_result(format=format, sparse=sparse) sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) - + with pytest.raises(AssertionError): # This require "waveforms first and should trig an error sorting_result.compute("templates") - + job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) sorting_result.compute("waveforms", **job_kwargs) # compute some operators - sorting_result.compute("templates", operators=["average", "std", ("percentile", 95.),]) + sorting_result.compute( + "templates", + operators=[ + "average", + "std", + ("percentile", 95.0), + ], + ) # ask for more operator later ext = sorting_result.get_extension("templates") templated_median = ext.get_templates(operator="median") - templated_per_5 = ext.get_templates(operator="percentile", percentile=5.) + templated_per_5 = ext.get_templates(operator="percentile", percentile=5.0) # they all should be in data data = sorting_result.get_extension("templates").data - for k in ['average', 'std', 'median', 'pencentile_5.0', 'pencentile_95.0']: + for k in ["average", "std", "median", "pencentile_5.0", "pencentile_95.0"]: assert k in data.keys() assert data[k].shape[0] == sorting_result.unit_ids.size assert data[k].shape[2] == sorting_result.channel_ids.size assert np.any(data[k] > 0) - - # import matplotlib.pyplot as plt # for unit_index, unit_id in enumerate(sorting_result.unit_ids): # fig, ax = plt.subplots() @@ -132,8 +142,8 @@ def test_ComputeFastTemplates(format, sparse): # TODO check this because this is not passing with n_jobs=2 job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) - ms_before=1.0 - ms_after=2.5 + ms_before = 1.0 + ms_after = 2.5 sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) sorting_result.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) @@ -144,7 +154,12 @@ def test_ComputeFastTemplates(format, sparse): other_sorting_result = get_sorting_result(format=format, sparse=False) other_sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) other_sorting_result.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) - other_sorting_result.compute("templates", operators=["average",]) + other_sorting_result.compute( + "templates", + operators=[ + "average", + ], + ) templates0 = sorting_result.get_extension("fast_templates").data["average"] templates1 = other_sorting_result.get_extension("templates").data["average"] @@ -160,6 +175,7 @@ def test_ComputeFastTemplates(format, sparse): # ax.legend() # plt.show() + @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeNoiseLevels(format, sparse): @@ -172,7 +188,7 @@ def test_ComputeNoiseLevels(format, sparse): assert noise_levels.shape[0] == sorting_result.channel_ids.size -if __name__ == '__main__': +if __name__ == "__main__": # test_ComputeWaveforms(format="memory", sparse=True) # test_ComputeWaveforms(format="memory", sparse=False) # test_ComputeWaveforms(format="binary_folder", sparse=True) diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index fe169e7448..6100da07dd 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -5,7 +5,11 @@ from spikeinterface.core import NumpySorting from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, spike_vector_to_indices +from spikeinterface.core.sorting_tools import ( + spike_vector_to_spike_trains, + random_spikes_selection, + spike_vector_to_indices, +) @pytest.mark.skipif( @@ -20,6 +24,7 @@ def test_spike_vector_to_spike_trains(): for unit_index, unit_id in enumerate(sorting.unit_ids): assert np.array_equal(spike_trains[0][unit_id], sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0)) + def test_spike_vector_to_indices(): sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000) spike_vector = sorting.to_spike_vector(concatenated=False) @@ -31,7 +36,7 @@ def test_spike_vector_to_indices(): inds = spike_indices[segment_index][unit_id] assert np.array_equal( spike_vector[segment_index][inds]["sample_index"], - sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index), ) diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index bde0210e37..c7ef3e3776 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -17,7 +17,10 @@ def get_dataset(): recording, sorting = generate_ground_truth_recording( - durations=[30.0], sampling_frequency=16000.0, num_channels=10, num_units=5, + durations=[30.0], + sampling_frequency=16000.0, + num_channels=10, + num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), seed=2205, @@ -34,7 +37,6 @@ def test_SortingResult_memory(): _check_sorting_results(sortres, sorting) - def test_SortingResult_binary_folder(): recording, sorting = get_dataset() @@ -42,7 +44,9 @@ def test_SortingResult_binary_folder(): if folder.exists(): shutil.rmtree(folder) - sortres = start_sorting_result(sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None) + sortres = start_sorting_result( + sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None + ) sortres = load_sorting_result(folder, format="auto") _check_sorting_results(sortres, sorting) @@ -54,12 +58,11 @@ def test_SortingResult_zarr(): if folder.exists(): shutil.rmtree(folder) - sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) + sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) sortres = load_sorting_result(folder, format="auto") _check_sorting_results(sortres, sorting) - def _check_sorting_results(sortres, original_sorting): print() @@ -73,7 +76,7 @@ def _check_sorting_results(sortres, original_sorting): probe = sortres.get_probe() sparsity = sortres.sparsity - + # compute sortres.compute("dummy", param1=5.5) # equivalent @@ -89,7 +92,6 @@ def _check_sorting_results(sortres, original_sorting): ext = sortres.get_extension("dummy") assert ext is None - assert sortres.has_recording() if sortres.random_spikes_indices is None: @@ -115,7 +117,7 @@ def _check_sorting_results(sortres, original_sorting): sortres2 = sortres.save_as(format=format, folder=folder) ext = sortres2.get_extension("dummy") assert ext is not None - + data = sortres2.get_extension("dummy").data assert "result_one" in data assert data["result_two"].size == original_sorting.to_spike_vector().size @@ -150,11 +152,11 @@ class DummyResultExtension(ResultExtension): need_recording = False use_nodepipeline = False - def _set_params(self, param0="yep", param1=1.2, param2=[1,2, 3.]): + def _set_params(self, param0="yep", param1=1.2, param2=[1, 2, 3.0]): params = dict(param0=param0, param1=param1, param2=param2) params["more_option"] = "yep" return params - + def _run(self, **kwargs): # print("dummy run") self.data["result_one"] = "abcd" @@ -162,7 +164,7 @@ def _run(self, **kwargs): # and represent nothing (the trick is to use unit_index for testing slice) spikes = self.sorting_result.sorting.to_spike_vector() self.data["result_two"] = spikes["unit_index"].copy() - + def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) @@ -175,10 +177,11 @@ def _select_extension_data(self, unit_ids): new_data["result_two"] = self.data["result_two"][keep_spike_mask] return new_data - + def _get_data(self): return self.data["result_one"] + compute_dummy = DummyResultExtension.function_factory() diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 3f702971d6..361a06ece0 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -163,7 +163,7 @@ def get_dataset(): def test_estimate_sparsity(): recording, sorting = get_dataset() num_units = sorting.unit_ids.size - + # small radius should give a very sparse = one channel per unit sparsity = estimate_sparsity( recording, @@ -195,9 +195,10 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + def test_compute_sparsity(): recording, sorting = get_dataset() - + sorting_result = start_sorting_result(sorting=sorting, recording=recording, sparse=False) sorting_result.select_random_spikes() sorting_result.compute("fast_templates", return_scaled=True) @@ -205,10 +206,9 @@ def test_compute_sparsity(): # this is needed for method="energy" sorting_result.compute("waveforms", return_scaled=True) - # using object SortingResult sparsity = compute_sparsity(sorting_result, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(sorting_result, method="radius", radius_um=50., peak_sign="neg") + sparsity = compute_sparsity(sorting_result, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_result, method="snr", threshold=5, peak_sign="neg") sparsity = compute_sparsity(sorting_result, method="ptp", threshold=5) sparsity = compute_sparsity(sorting_result, method="energy", threshold=5) @@ -218,7 +218,7 @@ def test_compute_sparsity(): templates = sorting_result.get_extension("fast_templates").get_data(outputs="Templates") noise_levels = sorting_result.get_extension("noise_levels").get_data() sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(templates, method="radius", radius_um=50., peak_sign="neg") + sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index 712f87f1e7..db15bbfbea 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -14,7 +14,10 @@ def get_sorting_result(): recording, sorting = generate_ground_truth_recording( - durations=[10.0, 5.0], sampling_frequency=10_000.0, num_channels=4, num_units=10, + durations=[10.0, 5.0], + sampling_frequency=10_000.0, + num_channels=4, + num_units=10, noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), seed=2205, ) @@ -28,13 +31,12 @@ def get_sorting_result(): return sorting_result + @pytest.fixture(scope="module") def sorting_result(): return get_sorting_result() - - def _get_templates_object_from_sorting_result(sorting_result): ext = sorting_result.get_extension("fast_templates") templates = Templates( diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 9afc3771aa..a5473ae89c 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -194,6 +194,7 @@ def test_estimate_templates_average(): # ax.plot(templates[unit_index, :, :].T.flatten()) # plt.show() + def test_estimate_templates(): recording, sorting = get_dataset() @@ -228,11 +229,7 @@ def test_estimate_templates(): # plt.show() - - - - if __name__ == "__main__": test_waveform_tools() test_estimate_templates_average() - test_estimate_templates() \ No newline at end of file + test_estimate_templates() diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index 2a602b1f38..f8f4b5af2d 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -13,8 +13,6 @@ from spikeinterface.core import extract_waveforms as old_extract_waveforms - - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" else: @@ -23,7 +21,10 @@ def get_dataset(): recording, sorting = generate_ground_truth_recording( - durations=[30.0, 20.], sampling_frequency=16000.0, num_channels=4, num_units=5, + durations=[30.0, 20.0], + sampling_frequency=16000.0, + num_channels=4, + num_units=5, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), generate_unit_locations_kwargs=dict( margin_um=5.0, @@ -50,7 +51,6 @@ def test_extract_waveforms(): we_kwargs = dict(sparse=True, max_spikes_per_unit=30) - we_old = old_extract_waveforms(recording, sorting, folder=folder, **we_kwargs) print(we_old) @@ -60,7 +60,7 @@ def test_extract_waveforms(): we_mock = mock_extract_waveforms(recording, sorting, folder=folder, **we_kwargs) print(we_mock) - + for we in (we_old, we_mock): selected_spikes = we.get_sampled_indices(unit_id=sorting.unit_ids[0]) @@ -75,8 +75,6 @@ def test_extract_waveforms(): templates = we.get_all_templates() # print(templates.shape) - - # test reading old WaveformsExtractor folder folder = cache_folder / "old_waveforms_extractor" sorting_result_from_we = load_waveforms_backwards(folder, output="SortingResult") @@ -85,7 +83,6 @@ def test_extract_waveforms(): print(mock_loaded_we_old) - # @pytest.mark.skip(): # def test_read_old_waveforms_extractor_binary(): # folder = "" diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index af40ea047a..04813921a6 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -705,7 +705,7 @@ def estimate_templates( unit_ids: list | np.ndarray, nbefore: int, nafter: int, - operator: str ="average", + operator: str = "average", return_scaled: bool = True, job_name=None, **job_kwargs, @@ -742,12 +742,25 @@ def estimate_templates( job_name = "estimate_templates" if operator == "average": - templates_array = estimate_templates_average(recording, spikes, unit_ids, nbefore, nafter, return_scaled=return_scaled, job_name=job_name, **job_kwargs) + templates_array = estimate_templates_average( + recording, spikes, unit_ids, nbefore, nafter, return_scaled=return_scaled, job_name=job_name, **job_kwargs + ) elif operator == "median": - all_waveforms, wf_array_info = extract_waveforms_to_single_buffer( recording, spikes, unit_ids, nbefore, nafter, - mode="shared_memory", return_scaled=return_scaled, copy=False,**job_kwargs,) - templates_array = np.zeros((len(unit_ids), all_waveforms.shape[1], all_waveforms.shape[2]), dtype=all_waveforms.dtype) - for unit_index , unit_id in enumerate(unit_ids): + all_waveforms, wf_array_info = extract_waveforms_to_single_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="shared_memory", + return_scaled=return_scaled, + copy=False, + **job_kwargs, + ) + templates_array = np.zeros( + (len(unit_ids), all_waveforms.shape[1], all_waveforms.shape[2]), dtype=all_waveforms.dtype + ) + for unit_index, unit_id in enumerate(unit_ids): wfs = all_waveforms[spikes["unit_index"] == unit_index] templates_array[unit_index, :, :] = np.median(wfs, axis=0) # release shared memory after the median @@ -757,7 +770,7 @@ def estimate_templates( raise ValueError(f"estimate_templates(..., operator={operator}) wrong operator must be average or median") return templates_array - + def estimate_templates_average( recording: BaseRecording, diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index af08385dcb..db8365d5b7 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -3,6 +3,7 @@ * load old WaveformsExtractor saved with folder or zarr (version <=0.100) into the SortingResult (version>0.100) * mock the function extract_waveforms() and the class SortingResult() but based SortingResult """ + from __future__ import annotations from typing import Literal, Optional @@ -89,10 +90,10 @@ def extract_waveforms( ms_before=ms_before, ms_after=ms_after, **other_kwargs, - **job_kwargs + **job_kwargs, ) - sorting_result = start_sorting_result(sorting, recording, format=format, folder=folder, - sparse=sparse, sparsity=sparsity, **sparsity_kwargs + sorting_result = start_sorting_result( + sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=sparsity, **sparsity_kwargs ) # TODO propagate job_kwargs @@ -107,14 +108,12 @@ def extract_waveforms( # this also done because some metrics need it sorting_result.compute("noise_levels") - we = MockWaveformExtractor(sorting_result) return we - class MockWaveformExtractor: def __init__(self, sorting_result): self.sorting_result = sorting_result @@ -126,7 +125,7 @@ def __repr__(self): def is_sparse(self) -> bool: return self.sorting_result.is_sparse() - + def has_waveforms(self) -> bool: return self.sorting_result.get_extension("waveforms") is not None @@ -136,7 +135,7 @@ def delete_waveforms(self) -> None: @property def recording(self) -> BaseRecording: return self.sorting_result.recording - + @property def sorting(self) -> BaseSorting: return self.sorting_result.sorting @@ -248,7 +247,7 @@ def get_waveforms( lazy: bool = True, sparsity=None, force_dense: bool = False, - ): + ): # lazy and cache are ingnored ext = self.sorting_result.get_extension("waveforms") unit_index = self.sorting.id_to_index(unit_id) @@ -258,7 +257,9 @@ def get_waveforms( wfs = ext.data["waveforms"][spike_mask, :, :] if sparsity is not None: - assert self.sorting_result.sparsity is None, "Waveforms are alreayd sparse! Cannot apply an additional sparsity." + assert ( + self.sorting_result.sparsity is None + ), "Waveforms are alreayd sparse! Cannot apply an additional sparsity." wfs = wfs[:, :, sparsity.mask[self.sorting.id_to_index(unit_id)]] if force_dense: @@ -288,7 +289,7 @@ def get_all_templates( key = f"pencentile_{percentile}" else: key = mode - + templates = ext.data.get(key) if templates is None: raise ValueError(f"{mode} is not computed") @@ -299,7 +300,6 @@ def get_all_templates( return templates - def get_template( self, unit_id, mode="average", sparsity=None, force_dense: bool = False, percentile: float | None = None ): @@ -308,14 +308,18 @@ def get_template( return templates[0] - -def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="MockWaveformExtractor", ): +def load_waveforms( + folder, + with_recording: bool = True, + sorting: Optional[BaseSorting] = None, + output="MockWaveformExtractor", +): """ This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingResult or MockWaveformExtractor. It also mimic the old load_waveforms by opening a Sortingresult folder and return a MockWaveformExtractor. This later behavior is usefull to no break old code like this in versio >=0.101 - + >>> # In this example we is a MockWaveformExtractor that behave the same as before >>> we = extract_waveforms(..., folder="/my_we") >>> we = load_waveforms("/my_we") @@ -349,7 +353,6 @@ def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSo return MockWaveformExtractor(sorting_result) - def _read_old_waveforms_extractor_binary(folder): params_file = folder / "params.json" if not params_file.exists(): @@ -401,18 +404,20 @@ def _read_old_waveforms_extractor_binary(folder): # need to concatenate sampled_index and order it waveform_folder = folder / "waveforms" if waveform_folder.exists(): - + spikes = sorting.to_spike_vector() random_spike_mask = np.zeros(spikes.size, dtype="bool") all_sampled_indices = [] # first readd all sampled_index to get the correct ordering for unit_index, unit_id in enumerate(sorting.unit_ids): - # unit_indices has dtype=[("spike_index", "int64"), ("segment_index", "int64")] + # unit_indices has dtype=[("spike_index", "int64"), ("segment_index", "int64")] unit_indices = np.load(waveform_folder / f"sampled_index_{unit_id}.npy") for segment_index in range(sorting.get_num_segments()): in_seg_selected = unit_indices[unit_indices["segment_index"] == segment_index]["spike_index"] - spikes_indices = np.flatnonzero((spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index)) + spikes_indices = np.flatnonzero( + (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index) + ) random_spike_mask[spikes_indices[in_seg_selected]] = True random_spikes_indices = np.flatnonzero(random_spike_mask) @@ -431,7 +436,7 @@ def _read_old_waveforms_extractor_binary(folder): for unit_index, unit_id in enumerate(sorting.unit_ids): wfs = np.load(waveform_folder / f"waveforms_{unit_id}.npy") mask = some_spikes["unit_index"] == unit_index - waveforms[:, :, :wfs.shape[2]][mask, :, :] = wfs + waveforms[:, :, : wfs.shape[2]][mask, :, :] = wfs sorting_result.random_spikes_indices = random_spikes_indices @@ -446,7 +451,7 @@ def _read_old_waveforms_extractor_binary(folder): for mode in ("average", "std", "median", "percentile"): template_file = folder / f"templates_{mode}.npy" if template_file.is_file(): - templates [mode] = np.load(template_file) + templates[mode] = np.load(template_file) if len(templates) > 0: ext = ComputeTemplates(sorting_result) ext.params = dict(operators=list(templates.keys())) @@ -454,21 +459,20 @@ def _read_old_waveforms_extractor_binary(folder): ext.data[mode] = arr sorting_result.extensions["templates"] = ext - # TODO : implement this when extension will be prted in the new API # old_extension_to_new_class : { - # old extensions with same names and equvalent data except similarity>template_similarity - # "spike_amplitudes": , - # "spike_locations": , - # "amplitude_scalings": , - # "template_metrics" : , - # "similarity": , - # "unit_locations": , - # "correlograms" : , - # isi_histograms: , - # "noise_levels": , - # "quality_metrics": , - # "principal_components" : , + # old extensions with same names and equvalent data except similarity>template_similarity + # "spike_amplitudes": , + # "spike_locations": , + # "amplitude_scalings": , + # "template_metrics" : , + # "similarity": , + # "unit_locations": , + # "correlograms" : , + # isi_histograms: , + # "noise_levels": , + # "quality_metrics": , + # "principal_components" : , # } # for ext_name, new_class in old_extension_to_new_class.items(): # ext_folder = folder / ext_name @@ -508,8 +512,4 @@ def _read_old_waveforms_extractor_binary(folder): # # TODO: this is for you # pass - - return sorting_result - - diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 85a44331f0..c77176b520 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -155,7 +155,9 @@ def get_potential_auto_merge( # STEP 3 : unit positions are estimated roughly with channel if "unit_positions" in steps: chan_loc = sorting_result.get_channel_locations() - unit_max_chan = get_template_extremum_channel(sorting_result, peak_sign=peak_sign, mode="extremum", outputs="index") + unit_max_chan = get_template_extremum_channel( + sorting_result, peak_sign=peak_sign, mode="extremum", outputs="index" + ) unit_max_chan = list(unit_max_chan.values()) unit_locations = chan_loc[unit_max_chan, :] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") @@ -196,7 +198,12 @@ def get_potential_auto_merge( # STEP 6 : validate the potential merges with CC increase the contamination quality metrics if "check_increase_score" in steps: pair_mask, pairs_decreased_score = check_improve_contaminations_score( - sorting_result, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms + sorting_result, + pair_mask, + contaminations, + firing_contamination_balance, + refractory_period_ms, + censored_period_ms, ) # FINAL STEP : create the final list from pair_mask boolean matrix @@ -421,7 +428,6 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair return templates_diff - def check_improve_contaminations_score( sorting_result, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms ): diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index af1163fb4b..f14e08c45a 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -13,6 +13,8 @@ job_kwargs = dict(n_jobs=-1) + + def make_sorting_result(sparse=True): recording, sorting = generate_ground_truth_recording( durations=[300.0], @@ -24,7 +26,7 @@ def make_sorting_result(sparse=True): seed=2205, ) - sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) sorting_result.select_random_spikes() sorting_result.compute("waveforms", **job_kwargs) sorting_result.compute("templates") diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 45b6ba370b..8886cf474f 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -9,7 +9,6 @@ from spikeinterface.curation import get_potential_auto_merge - from spikeinterface.curation.tests.common import make_sorting_result @@ -19,9 +18,8 @@ cache_folder = Path("cache_folder") / "curation" - def test_get_auto_merge_list(sorting_result_for_curation): - + sorting = sorting_result_for_curation.sorting recording = sorting_result_for_curation.recording num_unit_splited = 1 @@ -31,13 +29,10 @@ def test_get_auto_merge_list(sorting_result_for_curation): sorting, split_ids=sorting.unit_ids[:num_unit_splited], num_split=num_split, output_ids=True, seed=42 ) - - # print(sorting_with_split) # print(sorting_with_split.unit_ids) # print(other_ids) - job_kwargs = dict(n_jobs=-1) sorting_result = start_sorting_result(sorting_with_split, recording, format="memory") @@ -45,7 +40,6 @@ def test_get_auto_merge_list(sorting_result_for_curation): sorting_result.compute("waveforms", **job_kwargs) sorting_result.compute("templates") - potential_merges, outs = get_potential_auto_merge( sorting_result, minimum_spikes=1000, @@ -73,8 +67,6 @@ def test_get_auto_merge_list(sorting_result_for_curation): true_pair = tuple(true_pair) assert true_pair in potential_merges - - # import matplotlib.pyplot as plt # from spikeinterface.curation.auto_merge import normalize_correlogram # templates_diff = outs['templates_diff'] diff --git a/src/spikeinterface/curation/tests/test_remove_redundant.py b/src/spikeinterface/curation/tests/test_remove_redundant.py index b304ab19b9..a395c200a5 100644 --- a/src/spikeinterface/curation/tests/test_remove_redundant.py +++ b/src/spikeinterface/curation/tests/test_remove_redundant.py @@ -30,7 +30,6 @@ def test_remove_redundant_units(sorting_result_for_curation): sorting_result.compute("waveforms", **job_kwargs) sorting_result.compute("templates") - for remove_strategy in ("max_spikes", "minimum_shift", "highest_amplitude"): sorting_clean = remove_redundant_units(sorting_result, remove_strategy=remove_strategy) # print(sorting_clean) diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index bc4a636684..800124300f 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -33,7 +33,7 @@ def make_sorting_result(sparse=True, with_group=False): recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1]) sorting.set_property("group", [0, 0, 1, 1]) - sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) sorting_result.select_random_spikes() sorting_result.compute("waveforms") sorting_result.compute("templates") diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index b10e0c5a1d..bac8ebd75f 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -10,8 +10,7 @@ from spikeinterface.core import compute_sparsity from spikeinterface.exporters import export_to_phy -from spikeinterface.exporters.tests.common import (cache_folder, make_sorting_result) - +from spikeinterface.exporters.tests.common import cache_folder, make_sorting_result def test_export_to_phy(sorting_result_sparse_for_export): diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 03555c049e..af6bd69c17 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -13,14 +13,13 @@ BinaryRecordingExtractor, BinaryFolderRecording, ChannelSparsity, - SortingResult + SortingResult, ) from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs from spikeinterface.postprocessing import ( compute_spike_amplitudes, compute_template_similarity, compute_principal_components, - ) @@ -74,8 +73,7 @@ def export_to_phy( """ import pandas as pd - assert isinstance( - sorting_result, SortingResult), "sorting_result must be a SortingResult object" + assert isinstance(sorting_result, SortingResult), "sorting_result must be a SortingResult object" sorting = sorting_result.sorting assert ( diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py index 62f454c32a..dc8b2dbdd3 100644 --- a/src/spikeinterface/full.py +++ b/src/spikeinterface/full.py @@ -20,9 +20,10 @@ from .preprocessing import * from .postprocessing import * from .qualitymetrics import * + # TODO -# from .curation import * -# from .comparison import * -# from .widgets import * -# from .exporters import * +# from .curation import * +# from .comparison import * +# from .widgets import * +# from .exporters import * from .generation import * diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 20b5799f6e..6553115c43 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -18,6 +18,7 @@ # TODO extra sparsity and job_kwargs handling + class ComputeAmplitudeScalings(ResultExtension): """ Computes the amplitude scalings from a WaveformExtractor. @@ -59,7 +60,9 @@ class ComputeAmplitudeScalings(ResultExtension): """ extension_name = "amplitude_scalings" - depend_on = ["fast_templates|templates", ] + depend_on = [ + "fast_templates|templates", + ] need_recording = True use_nodepipeline = True nodepipeline_variables = ["amplitude_scalings", "collision_mask"] @@ -105,12 +108,11 @@ def _select_extension_data(self, unit_ids): new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask] return new_data - def _get_pipeline_nodes(self): recording = self.sorting_result.recording sorting = self.sorting_result.sorting - + # TODO return_scaled is not any more a property of SortingResult this is hard coded for now return_scaled = True @@ -126,7 +128,7 @@ def _get_pipeline_nodes(self): ), f"`ms_before` must be smaller than `ms_before` used in ComputeTemplates: {nbefore}" else: cut_out_before = nbefore - + if self.params["ms_after"] is not None: cut_out_after = int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) assert ( @@ -136,7 +138,9 @@ def _get_pipeline_nodes(self): cut_out_after = nafter peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" - extremum_channels_indices = get_template_extremum_channel(self.sorting_result, peak_sign=peak_sign, outputs="index") + extremum_channels_indices = get_template_extremum_channel( + self.sorting_result, peak_sign=peak_sign, outputs="index" + ) # collisions handle_collisions = self.params["handle_collisions"] @@ -190,7 +194,11 @@ def _run(self, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() amp_scalings, collision_mask = run_node_pipeline( - self.sorting_result.recording, nodes, job_kwargs=job_kwargs, job_name="amplitude_scalings", gather_mode="memory" + self.sorting_result.recording, + nodes, + job_kwargs=job_kwargs, + job_name="amplitude_scalings", + gather_mode="memory", ) self.data["amplitude_scalings"] = amp_scalings if self.params["handle_collisions"]: @@ -279,6 +287,7 @@ def _get_data(self): # compute_amplitude_scalings.__doc__.format(_shared_job_kwargs_doc) + class AmplitudeScalingNode(PipelineNode): def __init__( self, @@ -342,7 +351,7 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks): + def compute(self, traces, peaks): from scipy.stats import linregress # scale traces with margin to match scaling of templates @@ -365,9 +374,6 @@ def compute(self, traces, peaks): local_spikes_w_margin = peaks local_spikes = local_spikes_w_margin[~peaks["in_margin"]] - - - # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! @@ -439,9 +445,6 @@ def get_trace_margin(self): return self._margin - - - ### Collision handling ### def _are_unit_indices_overlapping(sparsity_mask, i, j): """ diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index f81556a883..260a2693e8 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -43,8 +43,9 @@ class ComputeCorrelograms(ResultExtension): 2D array with ISI histograms (num_units, num_bins) bins : np.array 1D array with bins in ms - + """ + extension_name = "correlograms" depend_on = [] need_recording = False @@ -79,6 +80,7 @@ def _get_data(self): register_result_extension(ComputeCorrelograms) compute_correlograms_sorting_result = ComputeCorrelograms.function_factory() + def compute_correlograms( sorting_result_or_sorting, window_ms: float = 50.0, @@ -86,13 +88,16 @@ def compute_correlograms( method: str = "auto", ): if isinstance(sorting_result_or_sorting, SortingResult): - return compute_correlograms_sorting_result(sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + return compute_correlograms_sorting_result( + sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method + ) else: - return compute_correlograms_on_sorting(sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) - -compute_correlograms.__doc__ = compute_correlograms_sorting_result.__doc__ + return compute_correlograms_on_sorting( + sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method + ) +compute_correlograms.__doc__ = compute_correlograms_sorting_result.__doc__ def _make_bins(sorting, window_ms, bin_ms): diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 611bde47ad..dbcdfc268b 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -69,7 +69,6 @@ def _get_data(self): compute_isi_histograms = ComputeISIHistograms.function_factory() - # def compute_isi_histograms( # waveform_or_sorting_extractor, # load_if_exists=False, diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index b39ece1a05..3b3f949b93 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -12,6 +12,7 @@ from spikeinterface.core.sortingresult import register_result_extension, ResultExtension from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs + # from spikeinterface.core.globals import get_global_tmp_folder _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"] @@ -19,6 +20,7 @@ # TODO handle extra sparsity + class ComputePrincipalComponents(ResultExtension): """ Compute PC scores from waveform extractor. The PCA projections are pre-computed only @@ -59,7 +61,9 @@ class ComputePrincipalComponents(ResultExtension): """ extension_name = "principal_components" - depend_on = ["waveforms", ] + depend_on = [ + "waveforms", + ] need_recording = False use_nodepipeline = False need_job_kwargs = True @@ -68,7 +72,12 @@ def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) def _set_params( - self, n_components=5, mode="by_channel_local", whiten=True, dtype="float32", sparsity=None, + self, + n_components=5, + mode="by_channel_local", + whiten=True, + dtype="float32", + sparsity=None, ): assert mode in _possible_modes, "Invalid mode!" @@ -102,7 +111,6 @@ def _select_extension_data(self, unit_ids): new_data[k] = v return new_data - def get_projections(self, unit_id, sparse=False): """ Returns the computed projections for the sampled waveforms of a unit id. @@ -209,10 +217,9 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True): """ pca_model = self.get_pca_model() - new_projections = self._transform_waveforms( new_spikes, new_waveforms, pca_model, progress_bar=progress_bar) + new_projections = self._transform_waveforms(new_spikes, new_waveforms, pca_model, progress_bar=progress_bar) return new_projections - def get_sparsity(self): if self.sorting_result.is_sparse(): return self.sorting_result.sparsity @@ -245,18 +252,16 @@ def _run(self, **job_kwargs): pca_model = self._fit_concatenated(progress_bar) self.data[f"pca_model_{mode}"] = pca_model - # transform waveforms_ext = self.sorting_result.get_extension("waveforms") some_waveforms = waveforms_ext.data["waveforms"] spikes = self.sorting_result.sorting.to_spike_vector() some_spikes = spikes[self.sorting_result.random_spikes_indices] - + pca_projection = self._transform_waveforms(some_spikes, some_waveforms, pca_model, progress_bar) self.data["pca_projection"] = pca_projection - def _get_data(self): return self.data["pca_projection"] @@ -297,7 +302,6 @@ def _get_data(self): # file_path = self.extension_folder / "all_pcs.npy" # file_path = Path(file_path) - # sparsity = self.get_sparsity() # if sparsity is None: # sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} @@ -358,7 +362,9 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): pca.partial_fit(wfs[:, :, wf_ind]) else: # parallel - items = [(chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)] + items = [ + (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds) + ] n_jobs = min(n_jobs, len(items)) with ProcessPoolExecutor(max_workers=n_jobs) as executor: @@ -394,9 +400,8 @@ def _fit_by_channel_global(self, progress_bar): wfs_concat = wfs.transpose(0, 2, 1).reshape(shape[0] * shape[2], shape[1]) pca_model.partial_fit(wfs_concat) - return pca_model - + def _fit_concatenated(self, progress_bar): p = self.params @@ -422,7 +427,6 @@ def _fit_concatenated(self, progress_bar): pca_model.partial_fit(wfs_flat) return pca_model - def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar): # transform a waveforms buffer @@ -460,7 +464,7 @@ def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar): pca_model = pca_models[chan_ind] try: proj = pca_model.transform(wfs[:, :, wf_ind]) - pca_projection[:, :, wf_ind][spike_mask, : ] = proj + pca_projection[:, :, wf_ind][spike_mask, :] = proj except NotFittedError as e: # this could happen if len(wfs) is less then n_comp for a channel project_on_non_fitted = True @@ -477,27 +481,27 @@ def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar): continue for wf_ind, chan_ind in enumerate(channel_inds): proj = pca_model.transform(wfs[:, :, wf_ind]) - pca_projection[:, :, wf_ind][spike_mask, : ] = proj + pca_projection[:, :, wf_ind][spike_mask, :] = proj elif mode == "concatenated": for unit_ind, unit_id in units_loop: wfs, channel_inds, spike_mask = self._get_slice_waveforms(unit_id, spikes, waveforms) wfs_flat = wfs.reshape(wfs.shape[0], -1) proj = pca_model.transform(wfs_flat) pca_projection[spike_mask, :] = proj - + return pca_projection def _get_slice_waveforms(self, unit_id, spikes, waveforms): - # slice by mask waveforms from one unit + # slice by mask waveforms from one unit unit_index = self.sorting_result.sorting.id_to_index(unit_id) spike_mask = spikes["unit_index"] == unit_index wfs = waveforms[spike_mask, :, :] - + sparsity = self.sorting_result.sparsity if sparsity is not None: channel_inds = sparsity.unit_id_to_channel_indices[unit_id] - wfs = wfs[:, :, :channel_inds.size] + wfs = wfs[:, :, : channel_inds.size] else: channel_inds = np.arange(self.sorting_result.channel_ids.size, dtype=int) @@ -605,8 +609,8 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte # with pca_file.open("wb") as f: # pickle.dump(pca_model, f) + def partial_fit_one_channel(args): chan_ind, pca_model, wf_chan = args pca_model.partial_fit(wf_chan) return chan_ind, pca_model - diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index ac1c2079d5..baa278643d 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -11,6 +11,7 @@ from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type from spikeinterface.core.sorting_tools import spike_vector_to_indices + class ComputeSpikeAmplitudes(ResultExtension): """ ResultExtension @@ -52,8 +53,11 @@ class ComputeSpikeAmplitudes(ResultExtension): All locations for all spikes and all units are concatenated """ + extension_name = "spike_amplitudes" - depend_on = ["fast_templates|templates", ] + depend_on = [ + "fast_templates|templates", + ] need_recording = True use_nodepipeline = True nodepipeline_variables = ["amplitudes"] @@ -79,7 +83,6 @@ def _select_extension_data(self, unit_ids): return new_data - def _get_pipeline_nodes(self): recording = self.sorting_result.recording @@ -88,7 +91,9 @@ def _get_pipeline_nodes(self): peak_sign = self.params["peak_sign"] return_scaled = self.params["return_scaled"] - extremum_channels_indices = get_template_extremum_channel(self.sorting_result, peak_sign=peak_sign, outputs="index") + extremum_channels_indices = get_template_extremum_channel( + self.sorting_result, peak_sign=peak_sign, outputs="index" + ) peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_result, peak_sign=peak_sign) if return_scaled: @@ -115,12 +120,16 @@ def _run(self, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() amps = run_node_pipeline( - self.sorting_result.recording, nodes, job_kwargs=job_kwargs, job_name="spike_amplitudes", gather_mode="memory" + self.sorting_result.recording, + nodes, + job_kwargs=job_kwargs, + job_name="spike_amplitudes", + gather_mode="memory", ) self.data["amplitudes"] = amps def _get_data(self, outputs="numpy"): - all_amplitudes = self.data["amplitudes"] + all_amplitudes = self.data["amplitudes"] if outputs == "numpy": return all_amplitudes elif outputs == "by_unit": @@ -137,13 +146,12 @@ def _get_data(self, outputs="numpy"): else: raise ValueError(f"Wrong .get_data(outputs={outputs})") + register_result_extension(ComputeSpikeAmplitudes) compute_spike_amplitudes = ComputeSpikeAmplitudes.function_factory() - - class SpikeAmplitudeNode(PipelineNode): def __init__( self, @@ -192,7 +200,7 @@ def compute(self, traces, peaks): # and scale if self._gains is not None: traces = traces.astype("float32") * self._gains + self._offsets - amplitudes = amplitudes.astype('float32', copy=True) + amplitudes = amplitudes.astype("float32", copy=True) amplitudes *= self._gains[chan_inds] amplitudes += self._offsets[chan_inds] diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 03a1e7d52a..32bd5f1455 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -14,7 +14,6 @@ # TODO job_kwargs - class ComputeSpikeLocations(ResultExtension): """ Localize spikes in 2D or 3D with several methods given the template. @@ -54,7 +53,9 @@ class ComputeSpikeLocations(ResultExtension): """ extension_name = "spike_locations" - depend_on = ["fast_templates|templates", ] + depend_on = [ + "fast_templates|templates", + ] need_recording = True use_nodepipeline = True nodepipeline_variables = ["spike_locations"] @@ -73,7 +74,7 @@ def _set_params( spike_retriver_kwargs=None, method="center_of_mass", method_kwargs={}, - ): + ): spike_retriver_kwargs_ = dict( channel_from_template=True, radius_um=50, @@ -82,8 +83,11 @@ def _set_params( if spike_retriver_kwargs is not None: spike_retriver_kwargs_.update(spike_retriver_kwargs) params = dict( - ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs_, method=method, - method_kwargs=method_kwargs + ms_before=ms_before, + ms_after=ms_after, + spike_retriver_kwargs=spike_retriver_kwargs_, + method=method, + method_kwargs=method_kwargs, ) return params @@ -100,8 +104,10 @@ def _get_pipeline_nodes(self): recording = self.sorting_result.recording sorting = self.sorting_result.sorting - peak_sign=self.params["spike_retriver_kwargs"]["peak_sign"] - extremum_channels_indices = get_template_extremum_channel(self.sorting_result, peak_sign=peak_sign, outputs="index") + peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] + extremum_channels_indices = get_template_extremum_channel( + self.sorting_result, peak_sign=peak_sign, outputs="index" + ) retriever = SpikeRetriever( recording, @@ -110,7 +116,12 @@ def _get_pipeline_nodes(self): extremum_channel_inds=extremum_channels_indices, ) nodes = get_localization_pipeline_nodes( - recording, retriever, method=self.params["method"], ms_before=self.params["ms_before"], ms_after=self.params["ms_after"], **self.params["method_kwargs"] + recording, + retriever, + method=self.params["method"], + ms_before=self.params["ms_before"], + ms_after=self.params["ms_after"], + **self.params["method_kwargs"], ) return nodes @@ -119,7 +130,11 @@ def _run(self, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() spike_locations = run_node_pipeline( - self.sorting_result.recording, nodes, job_kwargs=job_kwargs, job_name="spike_locations", gather_mode="memory" + self.sorting_result.recording, + nodes, + job_kwargs=job_kwargs, + job_name="spike_locations", + gather_mode="memory", ) self.data["spike_locations"] = spike_locations diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index c258c6afcd..5e9591996f 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -20,6 +20,7 @@ # TODO handle external sparsity + def get_single_channel_template_metric_names(): return deepcopy(list(_single_channel_metric_name_to_func.keys())) @@ -97,7 +98,9 @@ class ComputeTemplateMetrics(ResultExtension): """ extension_name = "template_metrics" - depend_on = ["fast_templates|templates", ] + depend_on = [ + "fast_templates|templates", + ] need_recording = True use_nodepipeline = False need_job_kwargs = False @@ -113,7 +116,7 @@ def _set_params( metrics_kwargs=None, include_multi_channel_metrics=False, ): - + if sparsity is not None: # TODO handle extra sparsity raise NotImplementedError @@ -130,12 +133,11 @@ def _set_params( self.sorting_result.get_channel_locations().shape[1] == 2 ), "If multi-channel metrics are computed, channel locations must be 2D." - if metric_names is None: metric_names = get_single_channel_template_metric_names() if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - + if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() else: @@ -286,8 +288,6 @@ def _get_data(self): ) - - def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index b86ec42c9a..2c4334dc15 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -5,6 +5,7 @@ from spikeinterface.core.sortingresult import register_result_extension, ResultExtension from ..core.template_tools import _get_dense_templates_array + class ComputeTemplateSimilarity(ResultExtension): """Compute similarity between templates with several methods. @@ -23,7 +24,9 @@ class ComputeTemplateSimilarity(ResultExtension): """ extension_name = "template_similarity" - depend_on = ["fast_templates|templates", ] + depend_on = [ + "fast_templates|templates", + ] need_recording = True use_nodepipeline = False need_job_kwargs = False @@ -43,16 +46,20 @@ def _select_extension_data(self, unit_ids): def _run(self): templates_array = _get_dense_templates_array(self.sorting_result, return_scaled=True) - similarity = compute_similarity_with_templates_array(templates_array, templates_array, method=self.params["method"]) + similarity = compute_similarity_with_templates_array( + templates_array, templates_array, method=self.params["method"] + ) self.data["similarity"] = similarity def _get_data(self): return self.data["similarity"] + # @alessio: compute_template_similarity() is now one inner SortingResult only register_result_extension(ComputeTemplateSimilarity) compute_template_similarity = ComputeTemplateSimilarity.function_factory() + def compute_similarity_with_templates_array(templates_array, other_templates_array, method): import sklearn.metrics.pairwise @@ -64,7 +71,7 @@ def compute_similarity_with_templates_array(templates_array, other_templates_arr else: raise ValueError(f"compute_template_similarity(method {method}) not exists") - + return similarity @@ -75,9 +82,6 @@ def compute_template_similarity_by_pair(sorting_result_1, sorting_result_2, meth return similmarity - - - # def _compute_template_similarity( # waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None # ): diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index abb760420d..8078b031e3 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -17,9 +17,13 @@ else: cache_folder = Path("cache_folder") / "postprocessing" + def get_dataset(): recording, sorting = generate_ground_truth_recording( - durations=[15.0, 5.0], sampling_frequency=24000.0, num_channels=6, num_units=3, + durations=[15.0, 5.0], + sampling_frequency=24000.0, + num_channels=6, + num_units=3, generate_sorting_kwargs=dict(firing_rates=3.0, refractory_period_ms=4.0), generate_unit_locations_kwargs=dict( margin_um=5.0, @@ -36,6 +40,7 @@ def get_dataset(): ) return recording, sorting + def get_sorting_result(recording, sorting, format="memory", sparsity=None, name=""): sparse = sparsity is not None if format == "memory": @@ -46,11 +51,12 @@ def get_sorting_result(recording, sorting, format="memory", sparsity=None, name= folder = cache_folder / f"test_{name}_sparse{sparse}_{format}.zarr" if folder and folder.exists(): shutil.rmtree(folder) - + sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity) return sortres + class ResultExtensionCommonTestSuite: """ Common tests with class approach to compute extension on several cases (3 format x 2 sparsity) @@ -61,6 +67,7 @@ class ResultExtensionCommonTestSuite: This also test the select_units() ability. """ + extension_class = None extension_function_params_list = None @@ -81,11 +88,13 @@ def setUpClass(cls): @property def extension_name(self): return self.extension_class.extension_name - + def _prepare_sorting_result(self, format, sparse): # prepare a SortingResult object with depencies already computed sparsity_ = self.sparsity if sparse else None - sorting_result = get_sorting_result(self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name) + sorting_result = get_sorting_result( + self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name + ) sorting_result.select_random_spikes(max_spikes_per_unit=50, seed=2205) for dependency_name in self.extension_class.depend_on: if "|" in dependency_name: @@ -100,20 +109,19 @@ def _check_one(self, sorting_result): job_kwargs = dict() for params in self.extension_function_params_list: - print(' params', params) + print(" params", params) ext = sorting_result.compute(self.extension_name, **params, **job_kwargs) assert len(ext.data) > 0 main_data = ext.get_data() ext = sorting_result.get_extension(self.extension_name) assert ext is not None - + some_unit_ids = sorting_result.unit_ids[::2] sliced = sorting_result.select_units(some_unit_ids, format="memory") assert np.array_equal(sliced.unit_ids, sorting_result.unit_ids[::2]) # print(sliced) - def test_extension(self): for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index c43419eb5a..eda88a5f1d 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -7,7 +7,6 @@ from spikeinterface.postprocessing import ComputeAmplitudeScalings - class AmplitudeScalingsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeAmplitudeScalings extension_function_params_list = [ @@ -22,7 +21,7 @@ def test_scaling_values(self): spikes = sorting_result.sorting.to_spike_vector() ext = sorting_result.get_extension("amplitude_scalings") - + for unit_index, unit_id in enumerate(sorting_result.unit_ids): mask = spikes["unit_index"] == unit_index scalings = ext.data["amplitude_scalings"][mask] diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 40a9a603b2..b9fbde18f8 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -4,6 +4,7 @@ try: import numba + HAVE_NUMBA = True except ModuleNotFoundError as err: HAVE_NUMBA = False @@ -15,7 +16,6 @@ from spikeinterface.postprocessing.correlograms import compute_correlograms_on_sorting, _make_bins - class ComputeCorrelogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeCorrelograms extension_function_params_list = [ @@ -26,7 +26,6 @@ class ComputeCorrelogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase) extension_function_params_list.append(dict(method="numba")) - def test_make_bins(): sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 618d4a6b06..89ed1257bc 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -51,4 +51,3 @@ def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): test.setUpClass() test.test_extension() test.test_compute_ISI() - diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 915d08acc7..fa6a0bfd9b 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -8,8 +8,6 @@ from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite - - # from spikeinterface import compute_sparsity # from spikeinterface.postprocessing import WaveformPrincipalComponent, compute_principal_components # from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite @@ -78,7 +76,6 @@ def test_mode_concatenated(self): # pc_unit = all_pc_sparse[all_spikes_seg0["unit_index"] == unit_index] # assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) - def test_project_new(self): from sklearn.decomposition import IncrementalPCA @@ -90,12 +87,11 @@ def test_project_new(self): sorting_result.compute("principal_components", mode="by_channel_local", n_components=n_components) ext_pca = sorting_result.get_extension(self.extension_name) - num_spike = 100 new_spikes = sorting_result.sorting.to_spike_vector()[:num_spike] new_waveforms = np.random.randn(num_spike, waveforms.shape[1], waveforms.shape[2]) new_proj = ext_pca.project_new(new_spikes, new_waveforms) - + assert new_proj.shape[0] == num_spike assert new_proj.shape[1] == n_components assert new_proj.shape[2] == ext_pca.data["pca_projection"].shape[2] @@ -109,7 +105,6 @@ def test_project_new(self): # test.test_compute_for_all_spikes() test.test_project_new() - # ext = test.sorting_results["sparseTrue_memory"].get_extension("principal_components") # pca = ext.data["pca_projection"] # import matplotlib.pyplot as plt diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 12b800a8cc..a7cca70363 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -12,6 +12,7 @@ class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCa dict(return_scaled=False), ] + if __name__ == "__main__": test = ComputeSpikeAmplitudesTest() test.setUpClass() diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index b2a5d6c9d5..c1f49bc849 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -5,20 +5,21 @@ from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite - - class SpikeLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeSpikeLocations extension_function_params_list = [ - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), # chunk_size=10000, n_jobs=1, + dict( + method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True) + ), # chunk_size=10000, n_jobs=1, dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), - dict(method="center_of_mass", ), - dict(method="monopolar_triangulation"), # , chunk_size=10000, n_jobs=1 - dict(method="grid_convolution"), # , chunk_size=10000, n_jobs=1 + dict( + method="center_of_mass", + ), + dict(method="monopolar_triangulation"), # , chunk_size=10000, n_jobs=1 + dict(method="grid_convolution"), # , chunk_size=10000, n_jobs=1 ] - if __name__ == "__main__": test = SpikeLocationsExtensionTest() test.setUpClass() diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index e5d5c73b8e..5954db646a 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -5,7 +5,6 @@ from spikeinterface.postprocessing import ComputeTemplateMetrics - class TemplateMetricsTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeTemplateMetrics extension_function_params_list = [ @@ -15,7 +14,6 @@ class TemplateMetricsTest(ResultExtensionCommonTestSuite, unittest.TestCase): ] - if __name__ == "__main__": test = TemplateMetricsTest() test.setUpClass() diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index b169d5fe49..e1e8f1231a 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -22,6 +22,7 @@ class SimilarityExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase) # waveforms1 = we.get_waveforms(unit_id1) # check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + # TODO check_equal_template_with_distribution_overlap if __name__ == "__main__": diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index 0aacbc7d4f..a46c743cb5 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -3,15 +3,12 @@ from spikeinterface.postprocessing import ComputeUnitLocations - class UnitLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeUnitLocations extension_function_params_list = [ dict(method="center_of_mass", radius_um=100), - - dict(method="grid_convolution", radius_um=50), + dict(method="grid_convolution", radius_um=50), dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), - dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), ] diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 38a714610e..6358b9fd31 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -34,7 +34,7 @@ class ComputeUnitLocations(ResultExtension): Parameters ---------- sorting_result: SortingResult - A SortingResult object + A SortingResult object method: "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The method to use for localization outputs: "numpy" | "by_unit", default: "numpy" @@ -49,7 +49,9 @@ class ComputeUnitLocations(ResultExtension): """ extension_name = "unit_locations" - depend_on = ["fast_templates|templates", ] + depend_on = [ + "fast_templates|templates", + ] need_recording = True use_nodepipeline = False need_job_kwargs = False @@ -89,6 +91,7 @@ def get_data(self, outputs="numpy"): locations_by_unit[unit_id] = self.data["unit_locations"][unit_ind] return locations_by_unit + register_result_extension(ComputeUnitLocations) compute_unit_locations = ComputeUnitLocations.function_factory() @@ -209,7 +212,7 @@ def compute_monopolar_triangulation( Parameters ---------- sorting_result: SortingResult - A SortingResult object + A SortingResult object method: "least_square" | "minimize_with_log_penality", default: "least_square" The optimizer to use radius_um: float, default: 75 @@ -239,12 +242,10 @@ def compute_monopolar_triangulation( contact_locations = sorting_result.get_channel_locations() - sparsity = compute_sparsity(sorting_result, method="radius", radius_um=radius_um) templates = _get_dense_templates_array(sorting_result) nbefore = _get_nbefore(sorting_result) - if enforce_decrease: neighbours_mask = np.zeros((templates.shape[0], templates.shape[2]), dtype=bool) for i, unit_id in enumerate(unit_ids): @@ -287,7 +288,7 @@ def compute_center_of_mass(sorting_result, peak_sign="neg", radius_um=75, featur Parameters ---------- sorting_result: SortingResult - A SortingResult object + A SortingResult object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um: float @@ -349,7 +350,7 @@ def compute_grid_convolution( Parameters ---------- sorting_result: SortingResult - A SortingResult object + A SortingResult object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um: float, default: 40.0 @@ -381,7 +382,6 @@ def compute_grid_convolution( nbefore = _get_nbefore(sorting_result) nafter = templates.shape[1] - nbefore - fs = sorting_result.sampling_frequency percentile = 100 - percentile assert 0 <= percentile <= 100, "Percentile should be in [0, 100]" @@ -398,7 +398,6 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - peak_channels = get_template_extremum_channel(sorting_result, peak_sign, outputs="index") weights_sparsity_mask = weights > 0 diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 0fadf3ccc6..eeb51917e4 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -107,11 +107,10 @@ def __init__( scale_amplitude=False, time_jitter=0, waveforms_kwargs=None, - ): + ): if waveforms_kwargs is not None: warnings("remove_artifacts() waveforms_kwargs is deprecated and ignored") - available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -172,13 +171,19 @@ def __init__( ms_before is not None and ms_after is not None ), f"ms_before/after should not be None for mode {mode}" sorting = NumpySorting.from_times_labels(list_triggers, list_labels, recording.get_sampling_frequency()) - + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) - templates = estimate_templates(recording=recording, spikes=sorting.to_spike_vector(), - unit_ids=sorting.unit_ids, nbefore=nbefore, nafter=nafter, - operator=mode, return_scaled=False) + templates = estimate_templates( + recording=recording, + spikes=sorting.to_spike_vector(), + unit_ids=sorting.unit_ids, + nbefore=nbefore, + nafter=nafter, + operator=mode, + return_scaled=False, + ) artifacts = {} for i, label in enumerate(sorting.unit_ids): artifacts[label] = templates[i, :, :] diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 5ca2f802b9..e90add4067 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -25,7 +25,6 @@ ) - try: import numba @@ -703,9 +702,7 @@ def compute_amplitude_cv_metrics( amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: firing_rate = num_spikes[unit_id] / total_duration - temporal_bin_size_samples = int( - (average_num_spikes_per_bin / firing_rate) * sorting_result.sampling_frequency - ) + temporal_bin_size_samples = int((average_num_spikes_per_bin / firing_rate) * sorting_result.sampling_frequency) amp_spreads = [] # bins and amplitude means are computed for each segment @@ -750,11 +747,11 @@ def _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign): all_amplitudes = ext.get_data() for unit_id in unit_ids: unit_index = sorting_result.sorting.id_to_index(unit_id) - spike_mask = spikes["unit_index"] ==unit_index + spike_mask = spikes["unit_index"] == unit_index amplitudes_by_units[unit_id] = all_amplitudes[spike_mask] elif sorting_result.has_extension("waveforms"): - waveforms_ext = sorting_result.get_extension("waveforms") + waveforms_ext = sorting_result.get_extension("waveforms") before = waveforms_ext.nbefore extremum_channels_ids = get_template_extremum_channel(sorting_result, peak_sign=peak_sign) for unit_id in unit_ids: @@ -765,9 +762,10 @@ def _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign): else: chan_ind = sorting_result.channel_ids_to_indices([chan_id])[0] amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - + return amplitudes_by_units + def compute_amplitude_cutoffs( sorting_result, peak_sign="neg", @@ -819,15 +817,17 @@ def compute_amplitude_cutoffs( if unit_ids is None: unit_ids = sorting_result.unit_ids - all_fraction_missing = {} if sorting_result.has_extension("spike_amplitudes") or sorting_result.has_extension("waveforms"): invert_amplitudes = False - if sorting_result.has_extension("spike_amplitudes") and sorting_result.get_extension("spike_amplitudes").params["peak_sign"] == "pos": + if ( + sorting_result.has_extension("spike_amplitudes") + and sorting_result.get_extension("spike_amplitudes").params["peak_sign"] == "pos" + ): invert_amplitudes = True elif sorting_result.has_extension("waveforms") and peak_sign == "pos": - invert_amplitudes = True + invert_amplitudes = True amplitudes_by_units = _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign) @@ -889,12 +889,13 @@ def compute_amplitude_medians(sorting_result, peak_sign="neg", unit_ids=None): for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) else: - warnings.warn("compute_amplitude_medians need 'spike_amplitudes' or 'waveforms' extension") - for unit_id in unit_ids: + warnings.warn("compute_amplitude_medians need 'spike_amplitudes' or 'waveforms' extension") + for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.nan return all_amplitude_medians + _default_params["amplitude_median"] = dict(peak_sign="neg") @@ -972,11 +973,9 @@ def compute_drift_metrics( spike_locations_by_unit = {} for unit_id in unit_ids: unit_index = sorting.id_to_index(unit_id) - spike_mask = spikes["unit_index"] ==unit_index + spike_mask = spikes["unit_index"] == unit_index spike_locations_by_unit[unit_id] = spike_locations[spike_mask] - - else: warnings.warn( "The drift metrics require the `spike_locations` waveform extension. " @@ -1439,7 +1438,7 @@ def compute_sd_ratio( sd_ratio = {} for unit_id in unit_ids: unit_index = sorting_result.sorting.id_to_index(unit_id) - + spk_amp = [] for segment_index in range(sorting_result.get_num_segments()): @@ -1472,13 +1471,12 @@ def compute_sd_ratio( best_channel = best_channels[unit_id] std_noise = noise_levels[best_channel] - if correct_for_template_itself: # template = sorting_result.get_template(unit_id, force_dense=True)[:, best_channel] - + template = tamplates_array[unit_index, :, :][:, best_channel] nsamples = template.shape[0] - + # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template. # TODO: Take into account that templates for different segments might differ. p = nsamples * n_spikes[unit_id] / sorting_result.get_total_samples() diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 561bcb6d5b..34a522273c 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -88,7 +88,6 @@ def calculate_pc_metrics( pca_ext = sorting_result.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" - sorting = sorting_result.sorting if metric_names is None: @@ -129,7 +128,6 @@ def calculate_pc_metrics( if run_in_parallel: parallel_functions = [] - # all_labels, all_pcs = pca.get_all_projections() # TODO: this is wring all_pcs used to be dense even when the waveform extractor was sparse all_pcs = pca_ext.data["pca_projection"] @@ -657,7 +655,6 @@ def nearest_neighbors_noise_overlap( templates_ext = sorting_result.get_extension("templates") assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'" - if n_spikes_all_units is None: n_spikes_all_units = compute_num_spikes(sorting_result) if fr_all_units is None: diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index db5165bc1c..59cd8b49d4 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -16,7 +16,8 @@ from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params -# TODO : +# TODO : + class ComputeQualityMetrics(ResultExtension): """ @@ -50,9 +51,7 @@ class ComputeQualityMetrics(ResultExtension): use_nodepipeline = False need_job_kwargs = True - def _set_params( - self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False - ): + def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list @@ -87,7 +86,7 @@ def _set_params( def _select_extension_data(self, unit_ids): new_metrics = self.data["metrics"].loc[np.array(unit_ids)] new_data = dict(metrics=new_metrics) - return new_data + return new_data def _run(self, verbose=False, **job_kwargs): """ @@ -173,7 +172,6 @@ def _get_data(self): compute_quality_metrics = ComputeQualityMetrics.function_factory() - def get_quality_metric_list(): """Get a list of the available quality metrics.""" diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 6d821a5115..27a84da440 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -3,7 +3,11 @@ from pathlib import Path import numpy as np from spikeinterface.core import ( - NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting, generate_ground_truth_recording, start_sorting_result + NumpySorting, + synthetize_spike_train_bad_isi, + add_synchrony_to_sorting, + generate_ground_truth_recording, + start_sorting_result, ) # from spikeinterface.extractors.toy_example import toy_example @@ -43,10 +47,14 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - def _sorting_result_simple(): recording, sorting = generate_ground_truth_recording( - durations=[50.0,], sampling_frequency=30_000.0, num_channels=6, num_units=10, + durations=[ + 50.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), seed=2205, @@ -63,10 +71,12 @@ def _sorting_result_simple(): return sorting_result + @pytest.fixture(scope="module") def sorting_result_simple(): return _sorting_result_simple() + def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 @@ -100,9 +110,11 @@ def _sorting_result_violations(): sorting = _sorting_violation() duration = (sorting.to_spike_vector()["sample_index"][-1] + 1) / sorting.sampling_frequency - + recording, sorting = generate_ground_truth_recording( - durations=[duration], sampling_frequency=sorting.sampling_frequency, num_channels=6, + durations=[duration], + sampling_frequency=sorting.sampling_frequency, + num_channels=6, sorting=sorting, noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), seed=2205, @@ -117,10 +129,6 @@ def sorting_result_violations(): return _sorting_result_violations() - - - - def test_mahalanobis_metrics(): all_pcs1, all_labels1 = create_ground_truth_pc_distributions([1, -1], [1000, 1000]) all_pcs2, all_labels2 = create_ground_truth_pc_distributions( @@ -289,7 +297,9 @@ def test_calculate_sliding_rp_violations(sorting_result_violations): def test_calculate_rp_violations(sorting_result_violations): sorting_result = sorting_result_violations - rp_contamination, counts = compute_refrac_period_violations(sorting_result, refractory_period_ms=1, censored_period_ms=0.0) + rp_contamination, counts = compute_refrac_period_violations( + sorting_result, refractory_period_ms=1, censored_period_ms=0.0 + ) print(rp_contamination, counts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -304,7 +314,9 @@ def test_calculate_rp_violations(sorting_result_violations): # we.sorting = sorting sorting_result2 = start_sorting_result(sorting, sorting_result.recording, format="memory", sparse=False) - rp_contamination, counts = compute_refrac_period_violations(sorting_result2, refractory_period_ms=1, censored_period_ms=0.0) + rp_contamination, counts = compute_refrac_period_violations( + sorting_result2, refractory_period_ms=1, censored_period_ms=0.0 + ) assert np.isnan(rp_contamination[1]) @@ -326,10 +338,7 @@ def test_synchrony_metrics(sorting_result_simple): sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) sorting_result_sync = start_sorting_result(sorting_sync, sorting_result.recording, format="memory") - - previous_synchrony_metrics = compute_synchrony_metrics( - previous_sorting_result, synchrony_sizes=synchrony_sizes - ) + previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_result, synchrony_sizes=synchrony_sizes) current_synchrony_metrics = compute_synchrony_metrics(sorting_result_sync, synchrony_sizes=synchrony_sizes) print(current_synchrony_metrics) # check that all values increased @@ -350,7 +359,9 @@ def test_calculate_drift_metrics(sorting_result_simple): sorting_result = sorting_result_simple sorting_result.compute("spike_locations", **job_kwargs) - drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics(sorting_result, interval_s=10, min_spikes_per_interval=10) + drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics( + sorting_result, interval_s=10, min_spikes_per_interval=10 + ) # print(drifts_ptps, drifts_stds, drift_mads) @@ -368,15 +379,13 @@ def test_calculate_sd_ratio(sorting_result_simple): sorting_result_simple, ) - assert np.all(list(sd_ratio.keys()) == sorting_result_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) - if __name__ == "__main__": - + sorting_result = _sorting_result_simple() print(sorting_result) @@ -392,14 +401,8 @@ def test_calculate_sd_ratio(sorting_result_simple): # test_calculate_amplitude_cv_metrics(sorting_result) test_calculate_sd_ratio(sorting_result) - - # sorting_result_violations = _sorting_result_violations() # print(sorting_result_violations) # test_calculate_isi_violations(sorting_result_violations) # test_calculate_sliding_rp_violations(sorting_result_violations) # test_calculate_rp_violations(sorting_result_violations) - - - - diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 80e398822e..2741d78ea7 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -4,7 +4,11 @@ import numpy as np import pandas as pd from spikeinterface.core import ( - NumpySorting, synthetize_spike_train_bad_isi, add_synchrony_to_sorting, generate_ground_truth_recording, start_sorting_result + NumpySorting, + synthetize_spike_train_bad_isi, + add_synchrony_to_sorting, + generate_ground_truth_recording, + start_sorting_result, ) # from spikeinterface.extractors.toy_example import toy_example @@ -13,17 +17,21 @@ from spikeinterface.qualitymetrics import ( calculate_pc_metrics, nearest_neighbors_isolation, - nearest_neighbors_noise_overlap + nearest_neighbors_noise_overlap, ) - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") def _sorting_result_simple(): recording, sorting = generate_ground_truth_recording( - durations=[50.0,], sampling_frequency=30_000.0, num_channels=6, num_units=10, + durations=[ + 50.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), seed=2205, @@ -40,6 +48,7 @@ def _sorting_result_simple(): return sorting_result + @pytest.fixture(scope="module") def sorting_result_simple(): return _sorting_result_simple() @@ -50,6 +59,7 @@ def test_calculate_pc_metrics(sorting_result_simple): res = calculate_pc_metrics(sorting_result) print(pd.DataFrame(res)) + def test_nearest_neighbors_isolation(sorting_result_simple): sorting_result = sorting_result_simple this_unit_id = sorting_result.unit_ids[0] @@ -61,6 +71,7 @@ def test_nearest_neighbors_noise_overlap(sorting_result_simple): this_unit_id = sorting_result.unit_ids[0] nearest_neighbors_noise_overlap(sorting_result, this_unit_id) + if __name__ == "__main__": sorting_result = _sorting_result_simple() test_calculate_pc_metrics(sorting_result) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 21c7285455..6e88375f3f 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -50,10 +50,16 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + def get_sorting_result(seed=2205): # we need high firing rate for amplitude_cutoff recording, sorting = generate_ground_truth_recording( - durations=[120.0,], sampling_frequency=30_000.0, num_channels=6, num_units=10, + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), generate_unit_locations_kwargs=dict( margin_um=5.0, @@ -89,14 +95,15 @@ def sorting_result_simple(): def test_compute_quality_metrics(sorting_result_simple): sorting_result = sorting_result_simple print(sorting_result) - + # without PCs - metrics = compute_quality_metrics(sorting_result, - metric_names=["snr"], - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=True, - seed=2205 - ) + metrics = compute_quality_metrics( + sorting_result, + metric_names=["snr"], + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=True, + seed=2205, + ) # print(metrics) qm = sorting_result.get_extension("quality_metrics") @@ -106,25 +113,27 @@ def test_compute_quality_metrics(sorting_result_simple): # with PCs sorting_result.compute("principal_components") - metrics = compute_quality_metrics(sorting_result, - metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=False, - seed=2205 - ) + metrics = compute_quality_metrics( + sorting_result, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) print(metrics.columns) assert "isolation_distance" in metrics.columns + def test_compute_quality_metrics_recordingless(sorting_result_simple): sorting_result = sorting_result_simple - metrics = compute_quality_metrics(sorting_result, - metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=False, - seed=2205 - ) - + metrics = compute_quality_metrics( + sorting_result, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) # make a copy and make it recordingless sorting_result_norec = sorting_result.save_as(format="memory") @@ -133,17 +142,18 @@ def test_compute_quality_metrics_recordingless(sorting_result_simple): print(sorting_result_norec) - metrics_norec = compute_quality_metrics(sorting_result_norec, - metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=False, - seed=2205 - ) + metrics_norec = compute_quality_metrics( + sorting_result_norec, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) for metric_name in metrics.columns: if metric_name == "sd_ratio": # this one need recording!!! - continue + continue assert np.allclose(metrics[metric_name].values, metrics_norec[metric_name].values, rtol=1e-02) @@ -165,19 +175,18 @@ def test_empty_units(sorting_result_simple): sorting_result_empty.compute("templates") sorting_result_empty.compute("spike_amplitudes", **job_kwargs) - metrics_empty = compute_quality_metrics(sorting_result_empty, - metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), - skip_pc_metrics=True, - seed=2205 - ) - + metrics_empty = compute_quality_metrics( + sorting_result_empty, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=True, + seed=2205, + ) for empty_unit_id in sorting_empty.get_empty_unit_ids(): assert np.all(np.isnan(metrics_empty.loc[empty_unit_id])) - # @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() # def test_amplitude_cutoff(self): @@ -319,4 +328,3 @@ def test_empty_units(sorting_result_simple): test_compute_quality_metrics(sorting_result) test_compute_quality_metrics_recordingless(sorting_result) test_empty_units(sorting_result) - diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index e4b89c999a..cade8d68f4 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,7 +4,7 @@ import numpy as np -from spikeinterface.core import WaveformExtractor, get_template_extremum_channel +from spikeinterface.core import WaveformExtractor, get_template_extremum_channel from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 595ade591b..f9b7014d35 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -23,12 +23,10 @@ class AllAmplitudesDistributionsWidget(BaseWidget): Dict of colors with key: unit, value: color, default None """ - def __init__( - self, sorting_result: SortingResult, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs - ): + def __init__(self, sorting_result: SortingResult, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs): self.check_extensions(sorting_result, "spike_amplitudes") - + amplitudes = sorting_result.get_extension("spike_amplitudes").get_data() num_segments = sorting_result.get_num_segments() @@ -46,7 +44,6 @@ def __init__( spike_mask = spikes["unit_index"] == unit_index amplitudes_by_units[unit_id] = amplitudes[spike_mask] - plot_data = dict( unit_ids=unit_ids, unit_colors=unit_colors, @@ -66,7 +63,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = self.ax - parts = ax.violinplot(list(dp.amplitudes_by_units.values()), showmeans=False, showmedians=False, showextrema=False) + parts = ax.violinplot( + list(dp.amplitudes_by_units.values()), showmeans=False, showmedians=False, showextrema=False + ) for i, pc in enumerate(parts["bodies"]): color = dp.unit_colors[dp.unit_ids[i]] diff --git a/src/spikeinterface/widgets/peak_activity.py b/src/spikeinterface/widgets/peak_activity.py index 6375d0ff29..2339166bfb 100644 --- a/src/spikeinterface/widgets/peak_activity.py +++ b/src/spikeinterface/widgets/peak_activity.py @@ -6,7 +6,6 @@ from .base import BaseWidget, to_attr - class PeakActivityMapWidget(BaseWidget): """ Plots spike rate (estimated with detect_peaks()) as 2D activity map. diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index aa952bc6ef..c35cc71e82 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -118,11 +118,21 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ).view v_unit_locations = UnitLocationsWidget( - sorting_result, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" + sorting_result, + unit_ids=unit_ids, + hide_unit_selector=True, + generate_url=False, + display=False, + backend="sortingview", ).view w = TemplateSimilarityWidget( - sorting_result, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" + sorting_result, + unit_ids=unit_ids, + immediate_plot=False, + generate_url=False, + display=False, + backend="sortingview", ) similarity = w.data_plot["similarity"] diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 651170bd3d..844e50924f 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -83,7 +83,7 @@ def __init__( backend=None, **backend_kwargs, ): - + self.check_extensions(sorting_result, "unit_locations") sorting: BaseSorting = sorting_result.sorting @@ -104,8 +104,9 @@ def __init__( extremum_channel_ids = get_template_extremum_channel(sorting_result) unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( - unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_result.unit_ids, - channel_ids=sorting_result.channel_ids + unit_id_to_channel_ids=unit_id_to_channel_ids, + unit_ids=sorting_result.unit_ids, + channel_ids=sorting_result.channel_ids, ) else: assert isinstance(sparsity, ChannelSparsity) @@ -211,8 +212,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if len(spike_frames_to_plot) > 0: vspacing = traces_widget.data_plot["vspacing"] traces = traces_widget.data_plot["list_traces"][0] - - # TODO find a better way + + # TODO find a better way nbefore = 30 nafter = 60 waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-nbefore, nafter) - frame_range[0] @@ -258,7 +259,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - self._traces_widget = TracesWidget(sorting_result.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self._traces_widget = TracesWidget( + sorting_result.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts + ) self.ax = self._traces_widget.ax self.axes = self._traces_widget.axes self.figure = self._traces_widget.figure diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index 9aaf071e3d..90e952ed2a 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -36,7 +36,7 @@ def __init__( **backend_kwargs, ): self.check_extensions(sorting_result, "template_metrics") - template_metrics= sorting_result.get_extension("template_metrics").get_data() + template_metrics = sorting_result.get_extension("template_metrics").get_data() sorting = sorting_result.sorting diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4628682ed5..4b6b50350e 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -61,7 +61,6 @@ def setUpClass(cls): cls.num_units = len(cls.sorting.get_unit_ids()) - extensions_to_compute = dict( waveforms=dict(), templates=dict(), @@ -69,7 +68,7 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names = ["snr", "isi_violation", "num_spikes"]), + quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes"]), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), @@ -90,11 +89,12 @@ def setUpClass(cls): cls.sparsity_best = compute_sparsity(cls.sorting_result_dense, method="best_channels", num_channels=5) # create sparse - cls.sorting_result_sparse = start_sorting_result(cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius) + cls.sorting_result_sparse = start_sorting_result( + cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius + ) cls.sorting_result_sparse.select_random_spikes() cls.sorting_result_sparse.compute(extensions_to_compute, **job_kwargs) - # cls.skip_backends = ["ipywidgets", "ephyviewer"] # TODO : delete this after debug cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] @@ -155,8 +155,6 @@ def test_plot_spikes_on_traces(self): if backend not in self.skip_backends: sw.plot_spikes_on_traces(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) - - def test_plot_unit_waveforms(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: @@ -300,11 +298,10 @@ def test_plot_unit_waveforms_density_map(self): for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] - + # on dense sw.plot_unit_waveforms_density_map( - self.sorting_result_dense, - unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) # on sparse sw.plot_unit_waveforms_density_map( @@ -321,7 +318,7 @@ def test_plot_unit_waveforms_density_map(self): **self.backend_kwargs[backend], ) - # on sparse with same_axis + # on sparse with same_axis sw.plot_unit_waveforms_density_map( self.sorting_result_sparse, sparsity=None, @@ -397,7 +394,9 @@ def test_plot_amplitudes(self): if backend not in self.skip_backends: sw.plot_amplitudes(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting_result_dense.unit_ids[:4] - sw.plot_amplitudes(self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) + sw.plot_amplitudes( + self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_amplitudes( self.sorting_result_dense, unit_ids=unit_ids, @@ -480,10 +479,16 @@ def test_plot_unit_summary(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.sorting_result_dense, self.sorting_result_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.sorting_result_dense, + self.sorting_result_dense.sorting.unit_ids[0], + backend=backend, + **self.backend_kwargs[backend], ) sw.plot_unit_summary( - self.sorting_result_sparse, self.sorting_result_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, + self.sorting_result_sparse.sorting.unit_ids[0], + backend=backend, + **self.backend_kwargs[backend], ) def test_plot_sorting_summary(self): @@ -493,7 +498,10 @@ def test_plot_sorting_summary(self): sw.plot_sorting_summary(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary( - self.sorting_result_sparse, sparsity=self.sparsity_strict, backend=backend, **self.backend_kwargs[backend] + self.sorting_result_sparse, + sparsity=self.sparsity_strict, + backend=backend, + **self.backend_kwargs[backend], ) def test_plot_agreement_matrix(self): @@ -554,8 +562,6 @@ def test_plot_multicomparison(self): if backend == "matplotlib": _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) - - if __name__ == "__main__": @@ -570,7 +576,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_traces() # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_waveforms() - # mytest.test_plot_unit_templates() + # mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() # mytest.test_plot_autocorrelograms() # mytest.test_plot_crosscorrelograms() @@ -584,7 +590,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() - # mytest.test_plot_rasters() + # mytest.test_plot_rasters() # mytest.test_plot_unit_probe_map() # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() @@ -593,5 +599,3 @@ def test_plot_multicomparison(self): plt.show() # TestWidgets.tearDownClass() - - diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 0827b00aed..a37ec46bd3 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -26,9 +26,7 @@ class UnitDepthsWidget(BaseWidget): Sign of peak for amplitudes """ - def __init__( - self, sorting_result, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs - ): + def __init__(self, sorting_result, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs): unit_ids = sorting_result.sorting.unit_ids if unit_colors is None: diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index fe6e9f3c03..20640e5e63 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -11,6 +11,7 @@ from ..core.sortingresult import SortingResult from ..core.template_tools import _get_dense_templates_array + class UnitProbeMapWidget(BaseWidget): """ Plots unit map. Amplitude is color coded on probe contact. @@ -120,7 +121,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def animate_func(frame): for i, unit_id in enumerate(self.unit_ids): - # template = we.get_template(unit_id) + # template = we.get_template(unit_id) template = templates[i, :, :] contacts_values = np.abs(template[frame, :]) poly_contact = all_poly_contact[i] diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 016738d393..eadf9f4037 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -91,7 +91,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( - sorting_result, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, backend="matplotlib", ax=ax1 + sorting_result, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + backend="matplotlib", + ax=ax1, ) unit_locations = sorting_result.get_extension("unit_locations").get_data(outputs="by_unit") diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 92edd677a6..5c182f1f3f 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -9,6 +9,7 @@ from ..core.basesorting import BaseSorting from ..core.template_tools import _get_dense_templates_array + class UnitWaveformsWidget(BaseWidget): """ Plots unit waveforms. @@ -103,7 +104,7 @@ def __init__( backend=None, **backend_kwargs, ): - + sorting: BaseSorting = sorting_result.sorting if unit_ids is None: @@ -132,7 +133,9 @@ def __init__( # in this case, we construct a dense sparsity unit_id_to_channel_ids = {u: sorting_result.channel_ids for u in sorting_result.unit_ids} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( - unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_result.unit_ids, channel_ids=sorting_result.channel_ids + unit_id_to_channel_ids=unit_id_to_channel_ids, + unit_ids=sorting_result.unit_ids, + channel_ids=sorting_result.channel_ids, ) else: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" @@ -142,8 +145,6 @@ def __init__( assert ext is not None, "plot_waveforms() need extension 'templates'" templates = ext.get_templates(unit_ids=unit_ids, operator="average") - - templates_shading = self._get_template_shadings(sorting_result, unit_ids, templates_percentile_shading) xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( @@ -433,7 +434,7 @@ def _get_template_shadings(self, sorting_result, unit_ids, templates_percentile_ templates_shading = [] for percentile in templates_percentile_shading: template_percentile = ext.get_templates(unit_ids=unit_ids, operator="percentile", percentile=percentile) - + templates_shading.append(template_percentile) return templates_shading diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 7617154b3e..6c82c2bd4d 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -60,7 +60,9 @@ def __init__( if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = get_template_extremum_channel(sorting_result, mode="extremum", peak_sign=peak_sign, outputs="index") + max_channels = get_template_extremum_channel( + sorting_result, mode="extremum", peak_sign=peak_sign, outputs="index" + ) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all if sorting_result.is_sparse(): From 6a432773b75bd5c7bc38fda58df0c86b1446eb8b Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 13 Feb 2024 08:30:26 +0100 Subject: [PATCH 055/192] WIP --- .../sorters/internal/spyking_circus2.py | 38 ++++--------------- .../sortingcomponents/matching/circus.py | 28 +++++--------- 2 files changed, 17 insertions(+), 49 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ce6e3c26ca..02c8ca998b 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -28,8 +28,6 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": { - "max_spikes_per_unit": 200, - "overwrite": True, "sparse": True, "method": "energy", "threshold": 0.25, @@ -205,47 +203,25 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(clustering_folder / "labels", labels) np.save(clustering_folder / "peaks", selected_peaks) - ## We get the templates our of such a clustering - waveforms_params = params["waveforms"].copy() - waveforms_params.update(job_kwargs) - - for k in ["ms_before", "ms_after"]: - waveforms_params[k] = params["general"][k] - - if params["shared_memory"] and not params["debug"]: - mode = "memory" - waveforms_folder = None - else: - sorting = sorting.save(folder=clustering_folder / "sorting") - mode = "folder" - waveforms_folder = sorter_output_folder / "waveforms" - - # we = extract_waveforms( - # recording_f, - # sorting, - # waveforms_folder, - # return_scaled=False, - # precompute_template=["median"], - # mode=mode, - # **waveforms_params, - # ) - nbefore = int(params["general"]["ms_before"] * sampling_frequency / 1000.0) nafter = int(params["general"]["ms_after"] * sampling_frequency / 1000.0) - templates_array = estimate_templates(recording, labeled_peaks, unit_ids, nbefore, nafter, + templates_array = estimate_templates(recording_f, labeled_peaks, unit_ids, nbefore, nafter, False, job_name=None, **job_kwargs) templates = Templates(templates_array, - sampling_frequency, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) + sampling_frequency, nbefore, None, recording_f.channel_ids, unit_ids, recording_f.get_probe()) + + if params["debug"]: + sorting = sorting.save(folder=clustering_folder / "sorting") ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_method = params["matching"]["method"] matching_params = params["matching"]["method_kwargs"].copy() + matching_params["templates"] = templates matching_job_params = {} matching_job_params.update(job_kwargs) - matching_params["templates"] = templates - + if matching_method == "circus-omp-svd": for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in matching_job_params: diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index a4c7ab4735..fc02f09634 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -516,25 +516,20 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): "max_failures": 20, "omp_min_sps": 0.1, "relative_error": 5e-5, - "waveform_extractor": None, + "templates": None, "rank": 5, - "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], "vicinity": 0, - "optimize_amplitudes": False, } @classmethod def _prepare_templates(cls, d): - waveform_extractor = d["waveform_extractor"] - num_templates = len(d["waveform_extractor"].sorting.unit_ids) + templates = d["templates"] + num_templates = len(d["templates"].unit_ids) assert d["stop_criteria"] in ["max_failures", "omp_min_sps", "relative_error"] - if not waveform_extractor.is_sparse(): - sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask - else: - sparsity = waveform_extractor.sparsity.mask + sparsity = templates.sparsity.mask d["sparsity_mask"] = sparsity units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) @@ -543,7 +538,7 @@ def _prepare_templates(cls, d): for i in range(num_templates): (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) - templates = waveform_extractor.get_all_templates(mode="median").copy() + templates_array = templates.templates_array.copy() # First, we set masked channels to 0 for count in range(num_templates): @@ -551,13 +546,13 @@ def _prepare_templates(cls, d): # Then we keep only the strongest components rank = d["rank"] - temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) + temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False) d["temporal"] = temporal[:, :, :rank] d["singular"] = singular[:, :rank] d["spatial"] = spatial[:, :rank, :] # We reconstruct the approximated templates - templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) + templates_array = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) d["templates"] = np.zeros(templates.shape, dtype=np.float32) d["norms"] = np.zeros(num_templates, dtype=np.float32) @@ -666,8 +661,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): @classmethod def serialize_method_kwargs(cls, kwargs): kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop("waveform_extractor") return kwargs @classmethod @@ -681,11 +674,10 @@ def get_margin(cls, recording, kwargs): @classmethod def main_function(cls, traces, d): - templates = d["templates"] num_templates = d["num_templates"] num_channels = d["num_channels"] num_samples = d["num_samples"] - overlaps = d["overlaps"] + overlaps_array = d["overlaps_array"] norms = d["norms"] nbefore = d["nbefore"] nafter = d["nafter"] @@ -766,7 +758,7 @@ def main_function(cls, traces, d): myline = neighbor_window + delta_t[idx] myindices = selection[0, idx] - local_overlaps = overlaps[best_cluster_ind] + local_overlaps = overlaps_array[best_cluster_ind] overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] table = d["unit_overlaps_tables"][best_cluster_ind] @@ -839,7 +831,7 @@ def main_function(cls, traces, d): tmp_best, tmp_peak = selection[:, i] diff_amp = diff_amplitudes[i] * norms[tmp_best] - local_overlaps = overlaps[tmp_best] + local_overlaps = overlaps_array[tmp_best] overlapping_templates = d["units_overlaps"][tmp_best] if not tmp_peak in neighbors.keys(): From 87e56a347f538c73006d2b583071376789c74f88 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 13 Feb 2024 08:53:38 +0100 Subject: [PATCH 056/192] WIP --- .../sorters/internal/spyking_circus2.py | 6 +- .../clustering/clustering_tools.py | 12 +- .../clustering/random_projections.py | 3 + .../sortingcomponents/matching/circus.py | 465 +----------------- 4 files changed, 25 insertions(+), 461 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 02c8ca998b..a355141f40 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -13,6 +13,7 @@ from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter from spikeinterface.sortingcomponents.tools import cache_preprocessing from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.sparsity import compute_sparsity try: import hdbscan @@ -41,7 +42,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "select_per_channel": False, }, "clustering": {"legacy": False}, - "matching": {"method": "wobble", "method_kwargs": {}}, + "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, "apply_preprocessing": True, "shared_memory": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -212,6 +213,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates = Templates(templates_array, sampling_frequency, nbefore, None, recording_f.channel_ids, unit_ids, recording_f.get_probe()) + sparsity = compute_sparsity(templates, method='radius') + templates.sparsity = sparsity + if params["debug"]: sorting = sorting.save(folder=clustering_folder / "sorting") diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 25d4f64456..f42d0e2e98 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -541,7 +541,7 @@ def remove_duplicates_via_matching( method_kwargs={}, job_kwargs={}, tmp_folder=None, - method="naive", + method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import BinaryRecordingExtractor @@ -561,10 +561,6 @@ def remove_duplicates_via_matching( fs = templates.sampling_frequency num_chans = len(templates.channel_ids) - #if waveform_extractor.is_sparse(): - # for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - # templates[count][:, ~sparsity[count]] = 0 - zdata = templates_array.reshape(nb_templates, -1) padding = 2 * duration @@ -595,11 +591,9 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() local_params.update( - {"templates": templates, "amplitudes": [0.975, 1.025], "optimize_amplitudes": False} + {"templates": templates, "amplitudes": [0.975, 1.025]} ) - - ignore_ids = [] similar_templates = [[], []] @@ -616,7 +610,7 @@ def remove_duplicates_via_matching( local_params.update( { "overlaps": computed["overlaps"], - "templates": computed["templates"], + "normed_templates": computed["normed_templates"], "norms": computed["norms"], "temporal": computed["temporal"], "spatial": computed["spatial"], diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index b91153a5b8..254c6d31df 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -25,6 +25,7 @@ from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.node_pipeline import ( run_node_pipeline, ExtractDenseWaveforms, @@ -209,6 +210,8 @@ def main_function(cls, recording, peaks, params): templates = Templates(templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) + sparsity = compute_sparsity(templates, method='radius') + templates.sparsity = sparsity cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index fc02f09634..67bc0d4158 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -37,99 +37,6 @@ from .main import BaseTemplateMatchingEngine -from scipy.fft._helper import _init_nd_shape_and_axes - -try: - from scipy.signal.signaltools import _init_freq_conv_axes, _apply_conv_mode -except Exception: - from scipy.signal._signaltools import _init_freq_conv_axes, _apply_conv_mode -from scipy import linalg, fft as sp_fft - - -def get_scipy_shape(in1, in2, mode="full", axes=None, calc_fast_len=True): - in1 = np.asarray(in1) - in2 = np.asarray(in2) - - if in1.ndim == in2.ndim == 0: # scalar inputs - return in1 * in2 - elif in1.ndim != in2.ndim: - raise ValueError("in1 and in2 should have the same dimensionality") - elif in1.size == 0 or in2.size == 0: # empty arrays - return np.array([]) - - in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False) - - s1 = in1.shape - s2 = in2.shape - - shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 for i in range(in1.ndim)] - - if not len(axes): - return in1 * in2 - - complex_result = in1.dtype.kind == "c" or in2.dtype.kind == "c" - - if calc_fast_len: - # Speed up FFT by padding to optimal size. - fshape = [sp_fft.next_fast_len(shape[a], not complex_result) for a in axes] - else: - fshape = shape - - return fshape, axes - - -def fftconvolve_with_cache(in1, in2, cache, mode="full", axes=None): - in1 = np.asarray(in1) - in2 = np.asarray(in2) - - if in1.ndim == in2.ndim == 0: # scalar inputs - return in1 * in2 - elif in1.ndim != in2.ndim: - raise ValueError("in1 and in2 should have the same dimensionality") - elif in1.size == 0 or in2.size == 0: # empty arrays - return np.array([]) - - in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False) - - s1 = in1.shape - s2 = in2.shape - - shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 for i in range(in1.ndim)] - - ret = _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True) - - return _apply_conv_mode(ret, s1, s2, mode, axes) - - -def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): - if not len(axes): - return in1 * in2 - - complex_result = in1.dtype.kind == "c" or in2.dtype.kind == "c" - - if calc_fast_len: - # Speed up FFT by padding to optimal size. - fshape = [sp_fft.next_fast_len(shape[a], not complex_result) for a in axes] - else: - fshape = shape - - if not complex_result: - fft, ifft = sp_fft.rfftn, sp_fft.irfftn - else: - fft, ifft = sp_fft.fftn, sp_fft.ifftn - - sp1 = cache["full"][cache["mask"]] - sp2 = cache["template"] - - # sp2 = fft(in2[cache['mask']], fshape, axes=axes) - ret = ifft(sp1 * sp2, fshape, axes=axes) - - if calc_fast_len: - fslice = tuple([slice(sz) for sz in shape]) - ret = ret[fslice] - - return ret - def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) @@ -163,320 +70,6 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): return new_overlaps -class CircusOMPPeeler(BaseTemplateMatchingEngine): - """ - Orthogonal Matching Pursuit inspired from Spyking Circus sorter - - https://elifesciences.org/articles/34518 - - This is an Orthogonal Template Matching algorithm. For speed and - memory optimization, templates are automatically sparsified. Signal - is convolved with the templates, and as long as some scalar products - are higher than a given threshold, we use a Cholesky decomposition - to compute the optimal amplitudes needed to reconstruct the signal. - - IMPORTANT NOTE: small chunks are more efficient for such Peeler, - consider using 100ms chunk - - Parameters - ---------- - amplitude: tuple - (Minimal, Maximal) amplitudes allowed for every template - omp_min_sps: float - Stopping criteria of the OMP algorithm, in percentage of the norm - noise_levels: array - The noise levels, for every channels. If None, they will be automatically - computed - random_chunk_kwargs: dict - Parameters for computing noise levels, if not provided (sub optimal) - sparse_kwargs: dict - Parameters to extract a sparsity mask from the waveform_extractor, if not - already sparse. - ----- - """ - - _default_params = { - "amplitudes": [0.6, 2], - "omp_min_sps": 0.1, - "waveform_extractor": None, - "templates": None, - "overlaps": None, - "norms": None, - "random_chunk_kwargs": {}, - "noise_levels": None, - "sparse_kwargs": {"method": "ptp", "threshold": 1}, - "ignored_ids": [], - "vicinity": 0, - } - - @classmethod - def _prepare_templates(cls, d): - waveform_extractor = d["waveform_extractor"] - num_templates = len(d["waveform_extractor"].sorting.unit_ids) - - if not waveform_extractor.is_sparse(): - sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask - else: - sparsity = waveform_extractor.sparsity.mask - - templates = waveform_extractor.get_all_templates(mode="median").copy() - - d["sparsities"] = {} - d["templates"] = {} - d["norms"] = np.zeros(num_templates, dtype=np.float32) - - for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - template = templates[count][:, sparsity[count]] - (d["sparsities"][count],) = np.nonzero(sparsity[count]) - d["norms"][count] = np.linalg.norm(template) - d["templates"][count] = template / d["norms"][count] - - return d - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls._default_params.copy() - d.update(kwargs) - - # assert isinstance(d['waveform_extractor'], WaveformExtractor) - - for v in ["omp_min_sps"]: - assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" - - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() - d["num_samples"] = d["waveform_extractor"].nsamples - d["nbefore"] = d["waveform_extractor"].nbefore - d["nafter"] = d["waveform_extractor"].nafter - d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() - d["vicinity"] *= d["num_samples"] - - if d["noise_levels"] is None: - print("CircusOMPPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - - if d["templates"] is None: - d = cls._prepare_templates(d) - else: - for key in ["norms", "sparsities"]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key - - d["num_templates"] = len(d["templates"]) - - if d["overlaps"] is None: - d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) - - d["ignored_ids"] = np.array(d["ignored_ids"]) - - omp_min_sps = d["omp_min_sps"] - # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) - d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop("waveform_extractor") - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) - return margin - - @classmethod - def main_function(cls, traces, d): - templates = d["templates"] - num_templates = d["num_templates"] - num_channels = d["num_channels"] - num_samples = d["num_samples"] - overlaps = d["overlaps"] - norms = d["norms"] - nbefore = d["nbefore"] - nafter = d["nafter"] - omp_tol = np.finfo(np.float32).eps - num_samples = d["nafter"] + d["nbefore"] - neighbor_window = num_samples - 1 - min_amplitude, max_amplitude = d["amplitudes"] - sparsities = d["sparsities"] - ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"] - vicinity = d["vicinity"] - - if "cached_fft_kernels" not in d: - d["cached_fft_kernels"] = {"fshape": 0} - - cached_fft_kernels = d["cached_fft_kernels"] - - num_timesteps = len(traces) - - num_peaks = num_timesteps - num_samples + 1 - - traces = traces.T - - dummy_filter = np.empty((num_channels, num_samples), dtype=np.float32) - dummy_traces = np.empty((num_channels, num_timesteps), dtype=np.float32) - - fshape, axes = get_scipy_shape(dummy_filter, traces, axes=1) - fft_cache = {"full": sp_fft.rfftn(traces, fshape, axes=axes)} - - scalar_products = np.empty((num_templates, num_peaks), dtype=np.float32) - - flagged_chunk = cached_fft_kernels["fshape"] != fshape[0] - - for i in range(num_templates): - if i not in ignored_ids: - if i not in cached_fft_kernels or flagged_chunk: - kernel_filter = np.ascontiguousarray(templates[i][::-1].T) - cached_fft_kernels.update({i: sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) - cached_fft_kernels["fshape"] = fshape[0] - - fft_cache.update({"mask": sparsities[i], "template": cached_fft_kernels[i]}) - - convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode="valid") - if len(convolution) > 0: - scalar_products[i] = convolution.sum(0) - else: - scalar_products[i] = 0 - - if len(ignored_ids) > 0: - scalar_products[ignored_ids] = -np.inf - - num_spikes = 0 - - spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - - M = np.zeros((100, 100), dtype=np.float32) - - all_selections = np.empty((2, scalar_products.size), dtype=np.int32) - final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) - num_selection = 0 - - full_sps = scalar_products.copy() - - neighbors = {} - cached_overlaps = {} - - is_valid = scalar_products > stop_criteria - all_amplitudes = np.zeros(0, dtype=np.float32) - is_in_vicinity = np.zeros(0, dtype=np.int32) - - while np.any(is_valid): - best_amplitude_ind = scalar_products[is_valid].argmax() - best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) - - if num_selection > 0: - delta_t = selection[1] - peak_index - idx = np.where((delta_t < neighbor_window) & (delta_t > -num_samples))[0] - myline = num_samples + delta_t[idx] - - if not best_cluster_ind in cached_overlaps: - cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() - - if num_selection == M.shape[0]: - Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) - Z[:num_selection, :num_selection] = M - M = Z - - M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] - - if vicinity == 0: - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) - else: - is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] - - if len(is_in_vicinity) > 0: - L = M[is_in_vicinity, :][:, is_in_vicinity] - - M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( - L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False - ) - - v = nrm2(M[num_selection, is_in_vicinity]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) - else: - M[num_selection, num_selection] = 1.0 - else: - M[0, 0] = 1 - - all_selections[:, num_selection] = [best_cluster_ind, peak_index] - num_selection += 1 - - selection = all_selections[:, :num_selection] - res_sps = full_sps[selection[0], selection[1]] - - if True: # vicinity == 0: - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - all_amplitudes /= norms[selection[0]] - else: - # This is not working, need to figure out why - is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) - all_amplitudes = np.append(all_amplitudes, np.float32(1)) - L = M[is_in_vicinity, :][:, is_in_vicinity] - all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) - all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] - - diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] - modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] - final_amplitudes[selection[0], selection[1]] = all_amplitudes - - for i in modified: - tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i] * norms[tmp_best] - - if not tmp_best in cached_overlaps: - cached_overlaps[tmp_best] = overlaps[tmp_best].toarray() - - if not tmp_peak in neighbors.keys(): - idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] - tdx = [num_samples + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak] - neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} - - idx = neighbors[tmp_peak]["idx"] - tdx = neighbors[tmp_peak]["tdx"] - - to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0] : tdx[1]] - scalar_products[:, idx[0] : idx[1]] -= to_add - - is_valid = scalar_products > stop_criteria - - is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) - valid_indices = np.where(is_valid) - - num_spikes = len(valid_indices[0]) - spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] - spikes["channel_index"][:num_spikes] = 0 - spikes["cluster_index"][:num_spikes] = valid_indices[0] - spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] - - spikes = spikes[:num_spikes] - order = np.argsort(spikes["sample_index"]) - spikes = spikes[order] - - return spikes - class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ @@ -542,7 +135,7 @@ def _prepare_templates(cls, d): # First, we set masked channels to 0 for count in range(num_templates): - templates[count][:, ~d["sparsity_mask"][count]] = 0 + templates_array[count][:, ~d["sparsity_mask"][count]] = 0 # Then we keep only the strongest components rank = d["rank"] @@ -554,39 +147,15 @@ def _prepare_templates(cls, d): # We reconstruct the approximated templates templates_array = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) - d["templates"] = np.zeros(templates.shape, dtype=np.float32) + d["normed_templates"] = np.zeros(templates_array.shape, dtype=np.float32) d["norms"] = np.zeros(num_templates, dtype=np.float32) # And get the norms, saving compressed templates for CC matrix for count in range(num_templates): - template = templates[count][:, d["sparsity_mask"][count]] + template = templates_array[count][:, d["sparsity_mask"][count]] d["norms"][count] = np.linalg.norm(template) - d["templates"][count][:, d["sparsity_mask"][count]] = template / d["norms"][count] - - if d["optimize_amplitudes"]: - noise = np.random.randn(200, d["num_samples"] * d["num_channels"]) - r = d["templates"].reshape(num_templates, -1).dot(noise.reshape(len(noise), -1).T) - s = r / d["norms"][:, np.newaxis] - mad = np.median(np.abs(s - np.median(s, 1)[:, np.newaxis]), 1) - a_min = np.median(s, 1) + 5 * mad - - means = np.zeros((num_templates, num_templates), dtype=np.float32) - stds = np.zeros((num_templates, num_templates), dtype=np.float32) - for count, unit_id in enumerate(waveform_extractor.unit_ids): - w = waveform_extractor.get_waveforms(unit_id, force_dense=True) - r = d["templates"].reshape(num_templates, -1).dot(w.reshape(len(w), -1).T) - s = r / d["norms"][:, np.newaxis] - means[count] = np.median(s, 1) - stds[count] = np.median(np.abs(s - np.median(s, 1)[:, np.newaxis]), 1) - - _, a_max = d["amplitudes"] - d["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) - - for count in range(num_templates): - indices = np.argsort(means[count]) - a = np.where(indices == count)[0][0] - d["amplitudes"][count][1] = 1 + 5 * stds[count, indices[a]] - d["amplitudes"][count][0] = max(a_min[count], 1 - 5 * stds[count, indices[a]]) + d["normed_templates"][count][:, d["sparsity_mask"][count]] = template / d["norms"][count] + d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] d["temporal"] = np.flip(d["temporal"], axis=1) @@ -619,7 +188,6 @@ def _prepare_templates(cls, d): d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) d["singular"] = d["singular"].T[:, :, np.newaxis] - return d @classmethod @@ -627,14 +195,14 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls._default_params.copy() d.update(kwargs) - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() - d["num_samples"] = d["waveform_extractor"].nsamples - d["nbefore"] = d["waveform_extractor"].nbefore - d["nafter"] = d["waveform_extractor"].nafter - d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["num_channels"] = recording.get_num_channels() + d["num_samples"] = d["templates"].num_samples + d["nbefore"] = d["templates"].nbefore + d["nafter"] = d["templates"].nafter + d["sampling_frequency"] = recording.get_sampling_frequency() d["vicinity"] *= d["num_samples"] - if "templates" not in d: + if "normed_templates" not in d: d = cls._prepare_templates(d) else: for key in [ @@ -648,7 +216,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ]: assert d[key] is not None, "If templates are provided, %d should also be there" % key - d["num_templates"] = len(d["templates"]) + d["num_templates"] = len(d["templates"].templates_array) d["ignored_ids"] = np.array(d["ignored_ids"]) d["unit_overlaps_tables"] = {} @@ -677,19 +245,14 @@ def main_function(cls, traces, d): num_templates = d["num_templates"] num_channels = d["num_channels"] num_samples = d["num_samples"] - overlaps_array = d["overlaps_array"] + overlaps_array = d["overlaps"] norms = d["norms"] nbefore = d["nbefore"] nafter = d["nafter"] omp_tol = np.finfo(np.float32).eps num_samples = d["nafter"] + d["nbefore"] neighbor_window = num_samples - 1 - if d["optimize_amplitudes"]: - min_amplitude, max_amplitude = d["amplitudes"][:, 0], d["amplitudes"][:, 1] - min_amplitude = min_amplitude[:, np.newaxis] - max_amplitude = max_amplitude[:, np.newaxis] - else: - min_amplitude, max_amplitude = d["amplitudes"] + min_amplitude, max_amplitude = d["amplitudes"] ignored_ids = d["ignored_ids"] vicinity = d["vicinity"] rank = d["rank"] From 3e7839c9333d39daba677c3110d41d478d97a165 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 13 Feb 2024 09:57:47 +0100 Subject: [PATCH 057/192] improve WaveformExtractor bacward compatibility --- ...forms_extractor_backwards_compatibility.py | 26 +++- ...forms_extractor_backwards_compatibility.py | 124 ++++++++++-------- 2 files changed, 91 insertions(+), 59 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index f8f4b5af2d..646789423f 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -3,11 +3,15 @@ import shutil +import numpy as np from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core.waveforms_extractor_backwards_compatibility import extract_waveforms as mock_extract_waveforms from spikeinterface.core.waveforms_extractor_backwards_compatibility import load_waveforms as load_waveforms_backwards +from spikeinterface.core.waveforms_extractor_backwards_compatibility import _read_old_waveforms_extractor_binary + +import spikeinterface.full as si # remove this when WaveformsExtractor will be removed from spikeinterface.core import extract_waveforms as old_extract_waveforms @@ -83,10 +87,24 @@ def test_extract_waveforms(): print(mock_loaded_we_old) -# @pytest.mark.skip(): -# def test_read_old_waveforms_extractor_binary(): -# folder = "" +@pytest.mark.skip() +def test_read_old_waveforms_extractor_binary(): + folder = "/data_local/DataSpikeSorting/waveform_extractor_backward_compatibility/waveforms_extractor_1" + sorting_result = _read_old_waveforms_extractor_binary(folder) + + print(sorting_result) + + for ext_name in sorting_result.get_loaded_extension_names(): + print() + print(ext_name) + keys = sorting_result.get_extension(ext_name).data.keys() + print(keys) + data = sorting_result.get_extension(ext_name).get_data() + if isinstance(data, np.ndarray): + print(data.shape) + if __name__ == "__main__": - test_extract_waveforms() + # test_extract_waveforms() + test_read_old_waveforms_extractor_binary() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index db8365d5b7..abf9edc86a 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -18,7 +18,7 @@ from .baserecording import BaseRecording from .basesorting import BaseSorting -from .sortingresult import start_sorting_result +from .sortingresult import start_sorting_result, get_extension_class from .job_tools import split_job_kwargs from .sparsity import ChannelSparsity from .sortingresult import SortingResult, load_sorting_result @@ -354,6 +354,7 @@ def load_waveforms( def _read_old_waveforms_extractor_binary(folder): + folder = Path(folder) params_file = folder / "params.json" if not params_file.exists(): raise ValueError(f"This folder is not a WaveformsExtractor folder {folder}") @@ -441,7 +442,12 @@ def _read_old_waveforms_extractor_binary(folder): sorting_result.random_spikes_indices = random_spikes_indices ext = ComputeWaveforms(sorting_result) - ext.params = params + ext.params = dict( + ms_before=params["ms_before"], + ms_after=params["ms_after"], + return_scaled=params["return_scaled"], + dtype=params["dtype"], + ) ext.data["waveforms"] = waveforms sorting_result.extensions["waveforms"] = ext @@ -454,62 +460,70 @@ def _read_old_waveforms_extractor_binary(folder): templates[mode] = np.load(template_file) if len(templates) > 0: ext = ComputeTemplates(sorting_result) - ext.params = dict(operators=list(templates.keys())) + ext.params = dict( + nbefore=nbefore, + nafter=nafter, + return_scaled=params["return_scaled"], + operators=list(templates.keys()) + ) for mode, arr in templates.items(): ext.data[mode] = arr sorting_result.extensions["templates"] = ext - # TODO : implement this when extension will be prted in the new API - # old_extension_to_new_class : { - # old extensions with same names and equvalent data except similarity>template_similarity - # "spike_amplitudes": , - # "spike_locations": , - # "amplitude_scalings": , - # "template_metrics" : , - # "similarity": , - # "unit_locations": , - # "correlograms" : , - # isi_histograms: , - # "noise_levels": , - # "quality_metrics": , - # "principal_components" : , - # } - # for ext_name, new_class in old_extension_to_new_class.items(): - # ext_folder = folder / ext_name - # ext = new_class(sorting_result) - # with open(ext_folder / "params.json", "r") as f: - # params = json.load(f) - # ext.params = params - # if ext_name == "spike_amplitudes": - # amplitudes = [] - # for segment_index in range(sorting.get_num_segments()): - # amplitudes.append(np.load(ext_folder / f"amplitude_segment_{segment_index}.npy")) - # amplitudes = np.concatenate(amplitudes) - # ext.data["amplitudes"] = amplitudes - # elif ext_name == "spike_locations": - # ext.data["spike_locations"] = np.load(ext_folder / "spike_locations.npy") - # elif ext_name == "amplitude_scalings": - # ext.data["amplitude_scalings"] = np.load(ext_folder / "amplitude_scalings.npy") - # elif ext_name == "template_metrics": - # import pandas as pd - # ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) - # elif ext_name == "similarity": - # ext.data["similarity"] = np.load(ext_folder / "similarity.npy") - # elif ext_name == "unit_locations": - # ext.data["unit_locations"] = np.load(ext_folder / "unit_locations.npy") - # elif ext_name == "correlograms": - # ext.data["ccgs"] = np.load(ext_folder / "ccgs.npy") - # ext.data["bins"] = np.load(ext_folder / "bins.npy") - # elif ext_name == "isi_histograms": - # ext.data["isi_histograms"] = np.load(ext_folder / "isi_histograms.npy") - # ext.data["bins"] = np.load(ext_folder / "bins.npy") - # elif ext_name == "noise_levels": - # ext.data["noise_levels"] = np.load(ext_folder / "noise_levels.npy") - # elif ext_name == "quality_metrics": - # import pandas as pd - # ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) - # elif ext_name == "principal_components": - # # TODO: this is for you - # pass + # old extensions with same names and equvalent data except similarity>template_similarity + old_extension_to_new_class = { + "spike_amplitudes": "spike_amplitudes", + "spike_locations": "spike_locations", + "amplitude_scalings": "amplitude_scalings", + "template_metrics" : "template_metrics", + "similarity": "template_similarity", + "unit_locations": "unit_locations", + "correlograms" : "correlograms", + "isi_histograms": "isi_histograms", + "noise_levels": "noise_levels", + "quality_metrics": "quality_metrics", + # "principal_components" : "principal_components", + } + for old_name, new_name in old_extension_to_new_class.items(): + ext_folder = folder / old_name + if not ext_folder.is_dir(): + continue + new_class = get_extension_class(new_name) + ext = new_class(sorting_result) + with open(ext_folder / "params.json", "r") as f: + params = json.load(f) + ext.params = params + if new_name == "spike_amplitudes": + amplitudes = [] + for segment_index in range(sorting.get_num_segments()): + amplitudes.append(np.load(ext_folder / f"amplitude_segment_{segment_index}.npy")) + amplitudes = np.concatenate(amplitudes) + ext.data["amplitudes"] = amplitudes + elif new_name == "spike_locations": + ext.data["spike_locations"] = np.load(ext_folder / "spike_locations.npy") + elif new_name == "amplitude_scalings": + ext.data["amplitude_scalings"] = np.load(ext_folder / "amplitude_scalings.npy") + elif new_name == "template_metrics": + import pandas as pd + ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) + elif new_name == "template_similarity": + ext.data["similarity"] = np.load(ext_folder / "similarity.npy") + elif new_name == "unit_locations": + ext.data["unit_locations"] = np.load(ext_folder / "unit_locations.npy") + elif new_name == "correlograms": + ext.data["ccgs"] = np.load(ext_folder / "ccgs.npy") + ext.data["bins"] = np.load(ext_folder / "bins.npy") + elif new_name == "isi_histograms": + ext.data["isi_histograms"] = np.load(ext_folder / "isi_histograms.npy") + ext.data["bins"] = np.load(ext_folder / "bins.npy") + elif new_name == "noise_levels": + ext.data["noise_levels"] = np.load(ext_folder / "noise_levels.npy") + elif new_name == "quality_metrics": + import pandas as pd + ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) + # elif new_name == "principal_components": + # # TODO: alessio this is for you + # pass + sorting_result.extensions[new_name] = ext return sorting_result From b3ea569e473d0ebd4dc56362d788eab1c37dbf75 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 13 Feb 2024 10:38:38 +0100 Subject: [PATCH 058/192] WIP --- src/spikeinterface/core/sparsity.py | 5 +- src/spikeinterface/core/template.py | 6 + .../sorters/internal/spyking_circus2.py | 2 +- .../clustering/clustering_tools.py | 3 +- .../clustering/random_projections.py | 2 +- .../sortingcomponents/matching/circus.py | 242 +++++++++--------- .../sortingcomponents/matching/method_list.py | 3 +- .../sortingcomponents/matching/naive.py | 2 +- .../sortingcomponents/matching/tdc.py | 2 +- .../sortingcomponents/matching/wobble.py | 2 +- 10 files changed, 139 insertions(+), 130 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index c4e703d911..29dcf35686 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -220,9 +220,10 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo return int(excess_zeros) == 0 - def sparisfy_templates(self, templates_array: np.ndarray) -> np.ndarray: + def sparsify_templates(self, templates_array: np.ndarray) -> np.ndarray: max_num_active_channels = self.max_num_active_channels - sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + num_samples = templates_array.shape[1] + sparisfied_shape = (self.num_units, num_samples, max_num_active_channels) sparse_templates = np.zeros(shape=sparisfied_shape, dtype=templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): template = templates_array[unit_index, ...] diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 99334022bb..aa0d967036 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -108,6 +108,12 @@ def __post_init__(self): if not self._are_passed_templates_sparse(): raise ValueError("Sparsity mask passed but the templates are not sparse") + def set_sparsity(self, sparsity): + assert isinstance(sparsity, ChannelSparsity), "sparsity should be of type ChannelSparsity" + self.sparsity = sparsity + self.templates_array = self.sparsity.sparsify_templates(self.get_dense_templates()) + + def get_one_template_dense(self, unit_index): if self.sparsity is None: template = self.templates_array[unit_index, :, :] diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a355141f40..4b68b84398 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -214,7 +214,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sampling_frequency, nbefore, None, recording_f.channel_ids, unit_ids, recording_f.get_probe()) sparsity = compute_sparsity(templates, method='radius') - templates.sparsity = sparsity + templates.set_sparsity(sparsity) if params["debug"]: sorting = sorting.save(folder=clustering_folder / "sorting") diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index f42d0e2e98..bb89bc14c8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -617,14 +617,13 @@ def remove_duplicates_via_matching( "singular": computed["singular"], "units_overlaps": computed["units_overlaps"], "unit_overlaps_indices": computed["unit_overlaps_indices"], - "sparsity_mask": computed["sparsity_mask"], } ) elif method == "circus-omp": local_params.update( { "overlaps": computed["overlaps"], - "templates": computed["templates"], + "circus_templates": computed["normed_templates"], "norms": computed["norms"], "sparsities": computed["sparsities"], } diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 254c6d31df..9e95968d6f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -211,7 +211,7 @@ def main_function(cls, recording, peaks, params): templates = Templates(templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) sparsity = compute_sparsity(templates, method='radius') - templates.sparsity = sparsity + templates.set_sparsity(sparsity) cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 67bc0d4158..9482ee561a 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -21,6 +21,7 @@ from spikeinterface.core import get_noise_levels, get_random_data_chunks, compute_sparsity from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel +from spikeinterface.core.template import Templates (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) @@ -131,7 +132,7 @@ def _prepare_templates(cls, d): for i in range(num_templates): (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) - templates_array = templates.templates_array.copy() + templates_array = templates.get_dense_templates().copy() # First, we set masked channels to 0 for count in range(num_templates): @@ -195,6 +196,11 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls._default_params.copy() d.update(kwargs) + assert isinstance(d["templates"], Templates), ( + f"The templates supplied is of type {type(d['templates'])} " + f"and must be a Templates" + ) + d["num_channels"] = recording.get_num_channels() d["num_samples"] = d["templates"].num_samples d["nbefore"] = d["templates"].nbefore @@ -202,7 +208,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["sampling_frequency"] = recording.get_sampling_frequency() d["vicinity"] *= d["num_samples"] - if "normed_templates" not in d: + if "overlaps" not in d: d = cls._prepare_templates(d) else: for key in [ @@ -211,7 +217,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): "spatial", "singular", "units_overlaps", - "sparsity_mask", "unit_overlaps_indices", ]: assert d[key] is not None, "If templates are provided, %d should also be there" % key @@ -500,13 +505,13 @@ class CircusPeeler(BaseTemplateMatchingEngine): "max_amplitude": 1.5, "min_amplitude": 0.5, "use_sparse_matrix_threshold": 0.25, - "waveform_extractor": None, + "templates": None, "sparse_kwargs": {"method": "ptp", "threshold": 1}, } @classmethod def _prepare_templates(cls, d): - waveform_extractor = d["waveform_extractor"] + templates = d["templates"] num_samples = d["num_samples"] num_channels = d["num_channels"] num_templates = d["num_templates"] @@ -514,163 +519,162 @@ def _prepare_templates(cls, d): d["norms"] = np.zeros(num_templates, dtype=np.float32) - all_units = list(d["waveform_extractor"].sorting.unit_ids) + all_units = d["templates"].unit_ids - if not waveform_extractor.is_sparse(): - sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask - else: - sparsity = waveform_extractor.sparsity.mask + sparsity = templates.sparsity.mask - templates = waveform_extractor.get_all_templates(mode="median").copy() + templates_array = templates.get_dense_templates() d["sparsities"] = {} - d["circus_templates"] = {} + d["normed_templates"] = {} for count, unit_id in enumerate(all_units): (d["sparsities"][count],) = np.nonzero(sparsity[count]) - templates[count][:, ~sparsity[count]] = 0 - d["norms"][count] = np.linalg.norm(templates[count]) - templates[count] /= d["norms"][count] - d["circus_templates"][count] = templates[count][:, sparsity[count]] + templates_array[count][:, ~sparsity[count]] = 0 + d["norms"][count] = np.linalg.norm(templates_array[count]) + templates_array[count] /= d["norms"][count] + d["normed_templates"][count] = templates_array[count][:, sparsity[count]] - templates = templates.reshape(num_templates, -1) + templates_array = templates_array.reshape(num_templates, -1) - nnz = np.sum(templates != 0) / (num_templates * num_samples * num_channels) + nnz = np.sum(templates_array != 0) / (num_templates * num_samples * num_channels) if nnz <= use_sparse_matrix_threshold: - templates = scipy.sparse.csr_matrix(templates) + templates_array = scipy.sparse.csr_matrix(templates_array) print(f"Templates are automatically sparsified (sparsity level is {nnz})") d["is_dense"] = False else: d["is_dense"] = True - d["templates"] = templates + d["circus_templates"] = templates_array return d - @classmethod - def _mcc_error(cls, bounds, good, bad): - fn = np.sum((good < bounds[0]) | (good > bounds[1])) - fp = np.sum((bounds[0] <= bad) & (bad <= bounds[1])) - tp = np.sum((bounds[0] <= good) & (good <= bounds[1])) - tn = np.sum((bad < bounds[0]) | (bad > bounds[1])) - denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) - if denom > 0: - mcc = 1 - (tp * tn - fp * fn) / np.sqrt(denom) - else: - mcc = 1 - return mcc - - @classmethod - def _cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha): - # We want a minimal error, with the larger bounds that are possible - cost = alpha * cls._mcc_error(bounds, good, bad) + (1 - alpha) * np.abs( - (1 - (bounds[1] - bounds[0]) / delta_amplitude) - ) - return cost - - @classmethod - def _optimize_amplitudes(cls, noise_snippets, d): - parameters = d - waveform_extractor = parameters["waveform_extractor"] - templates = parameters["templates"] - num_templates = parameters["num_templates"] - max_amplitude = parameters["max_amplitude"] - min_amplitude = parameters["min_amplitude"] - alpha = 0.5 - norms = parameters["norms"] - all_units = list(waveform_extractor.sorting.unit_ids) - - parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) - noise = templates.dot(noise_snippets) / norms[:, np.newaxis] - - all_amps = {} - for count, unit_id in enumerate(all_units): - waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) - snippets = waveform.reshape(waveform.shape[0], -1).T - amps = templates.dot(snippets) / norms[:, np.newaxis] - good = amps[count, :].flatten() - - sub_amps = amps[np.concatenate((np.arange(count), np.arange(count + 1, num_templates))), :] - bad = sub_amps[sub_amps >= good] - bad = np.concatenate((bad, noise[count])) - cost_kwargs = [good, bad, max_amplitude - min_amplitude, alpha] - cost_bounds = [(min_amplitude, 1), (1, max_amplitude)] - res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) - parameters["amplitudes"][count] = res.x - - return d + # @classmethod + # def _mcc_error(cls, bounds, good, bad): + # fn = np.sum((good < bounds[0]) | (good > bounds[1])) + # fp = np.sum((bounds[0] <= bad) & (bad <= bounds[1])) + # tp = np.sum((bounds[0] <= good) & (good <= bounds[1])) + # tn = np.sum((bad < bounds[0]) | (bad > bounds[1])) + # denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + # if denom > 0: + # mcc = 1 - (tp * tn - fp * fn) / np.sqrt(denom) + # else: + # mcc = 1 + # return mcc + + # @classmethod + # def _cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha): + # # We want a minimal error, with the larger bounds that are possible + # cost = alpha * cls._mcc_error(bounds, good, bad) + (1 - alpha) * np.abs( + # (1 - (bounds[1] - bounds[0]) / delta_amplitude) + # ) + # return cost + + # @classmethod + # def _optimize_amplitudes(cls, noise_snippets, d): + # parameters = d + # waveform_extractor = parameters["waveform_extractor"] + # templates = parameters["templates"] + # num_templates = parameters["num_templates"] + # max_amplitude = parameters["max_amplitude"] + # min_amplitude = parameters["min_amplitude"] + # alpha = 0.5 + # norms = parameters["norms"] + # all_units = list(waveform_extractor.sorting.unit_ids) + + # parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) + # noise = templates.dot(noise_snippets) / norms[:, np.newaxis] + + # all_amps = {} + # for count, unit_id in enumerate(all_units): + # waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) + # snippets = waveform.reshape(waveform.shape[0], -1).T + # amps = templates.dot(snippets) / norms[:, np.newaxis] + # good = amps[count, :].flatten() + + # sub_amps = amps[np.concatenate((np.arange(count), np.arange(count + 1, num_templates))), :] + # bad = sub_amps[sub_amps >= good] + # bad = np.concatenate((bad, noise[count])) + # cost_kwargs = [good, bad, max_amplitude - min_amplitude, alpha] + # cost_bounds = [(min_amplitude, 1), (1, max_amplitude)] + # res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) + # parameters["amplitudes"][count] = res.x + + # return d @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work" - default_parameters = cls._default_params.copy() - default_parameters.update(kwargs) + d = cls._default_params.copy() + d.update(kwargs) # assert isinstance(d['waveform_extractor'], WaveformExtractor) for v in ["use_sparse_matrix_threshold"]: - assert (default_parameters[v] >= 0) and (default_parameters[v] <= 1), f"{v} should be in [0, 1]" + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" - default_parameters["num_channels"] = default_parameters["waveform_extractor"].recording.get_num_channels() - default_parameters["num_samples"] = default_parameters["waveform_extractor"].nsamples - default_parameters["num_templates"] = len(default_parameters["waveform_extractor"].sorting.unit_ids) + d["num_channels"] = recording.get_num_channels() + d["num_samples"] = d["templates"].num_samples + d["num_templates"] = len(d["templates"].unit_ids) - if default_parameters["noise_levels"] is None: + if d["noise_levels"] is None: print("CircusPeeler : noise should be computed outside") - default_parameters["noise_levels"] = get_noise_levels( - recording, **default_parameters["random_chunk_kwargs"], return_scaled=False + d["noise_levels"] = get_noise_levels( + recording, **d["random_chunk_kwargs"], return_scaled=False ) - default_parameters["abs_threholds"] = ( - default_parameters["noise_levels"] * default_parameters["detect_threshold"] + d["abs_threholds"] = ( + d["noise_levels"] * d["detect_threshold"] ) - default_parameters = cls._prepare_templates(default_parameters) + if not "circus_templates" in d: + d = cls._prepare_templates(d) - default_parameters["overlaps"] = compute_overlaps( - default_parameters["circus_templates"], - default_parameters["num_samples"], - default_parameters["num_channels"], - default_parameters["sparsities"], + d["overlaps"] = compute_overlaps( + d["normed_templates"], + d["num_samples"], + d["num_channels"], + d["sparsities"], ) - default_parameters["exclude_sweep_size"] = int( - default_parameters["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 + d["exclude_sweep_size"] = int( + d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 ) - default_parameters["nbefore"] = default_parameters["waveform_extractor"].nbefore - default_parameters["nafter"] = default_parameters["waveform_extractor"].nafter - default_parameters["patch_sizes"] = ( - default_parameters["waveform_extractor"].nsamples, - default_parameters["num_channels"], + d["nbefore"] = d["templates"].nbefore + d["nafter"] = d["templates"].nafter + d["patch_sizes"] = ( + d["templates"].num_samples, + d["num_channels"], ) - default_parameters["sym_patch"] = default_parameters["nbefore"] == default_parameters["nafter"] - default_parameters["jitter"] = int( - default_parameters["jitter_ms"] * recording.get_sampling_frequency() / 1000.0 + d["sym_patch"] = d["nbefore"] == d["nafter"] + d["jitter"] = int( + d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0 ) - num_segments = recording.get_num_segments() - if default_parameters["waveform_extractor"]._params["max_spikes_per_unit"] is None: - num_snippets = 1000 - else: - num_snippets = 2 * default_parameters["waveform_extractor"]._params["max_spikes_per_unit"] + d["amplitudes"] = np.zeros((d["num_templates"], 2), dtype=np.float32) + d["amplitudes"][:, 0] = d["min_amplitude"] + d["amplitudes"][:, 1] = d["max_amplitude"] + # num_segments = recording.get_num_segments() + # if d["waveform_extractor"]._params["max_spikes_per_unit"] is None: + # num_snippets = 1000 + # else: + # num_snippets = 2 * d["waveform_extractor"]._params["max_spikes_per_unit"] + + # num_chunks = num_snippets // num_segments + # noise_snippets = get_random_data_chunks( + # recording, num_chunks_per_segment=num_chunks, chunk_size=d["num_samples"], seed=42 + # ) + # noise_snippets = ( + # noise_snippets.reshape(num_chunks, d["num_samples"], d["num_channels"]) + # .reshape(num_chunks, -1) + # .T + # ) + #parameters = cls._optimize_amplitudes(noise_snippets, d) - num_chunks = num_snippets // num_segments - noise_snippets = get_random_data_chunks( - recording, num_chunks_per_segment=num_chunks, chunk_size=default_parameters["num_samples"], seed=42 - ) - noise_snippets = ( - noise_snippets.reshape(num_chunks, default_parameters["num_samples"], default_parameters["num_channels"]) - .reshape(num_chunks, -1) - .T - ) - parameters = cls._optimize_amplitudes(noise_snippets, default_parameters) - - return parameters + return d @classmethod def serialize_method_kwargs(cls, kwargs): kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop("waveform_extractor") return kwargs @classmethod @@ -687,7 +691,7 @@ def main_function(cls, traces, d): peak_sign = d["peak_sign"] abs_threholds = d["abs_threholds"] exclude_sweep_size = d["exclude_sweep_size"] - templates = d["templates"] + templates = d["circus_templates"] num_templates = d["num_templates"] num_channels = d["num_channels"] overlaps = d["overlaps"] diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index 4a27fcd8c2..bd8dfd21bc 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -2,14 +2,13 @@ from .naive import NaiveMatching from .tdc import TridesclousPeeler -from .circus import CircusPeeler, CircusOMPPeeler, CircusOMPSVDPeeler +from .circus import CircusPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { "naive": NaiveMatching, "tridesclous": TridesclousPeeler, "circus": CircusPeeler, - "circus-omp": CircusOMPPeeler, "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, } diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index f64f4d8176..edb69ef6bb 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -47,7 +47,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d.update(kwargs) assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['waveform_extractor'])} " + f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" ) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 25f8129b3d..92777ed6ae 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -83,7 +83,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["nbefore"] = templates.nbefore d["nafter"] = templates.nafter - templates_array = templates.templates_array + templates_array = templates.get_dense_templates() nbefore_short = int(d["ms_before"] * sr / 1000.0) nafter_short = int(d["ms_before"] * sr / 1000.0) diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 0bbc147dd8..07a73cb9e3 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -350,7 +350,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" ) - templates_array = templates.templates_array.astype(np.float32, casting="safe") + templates_array = templates.get_dense_templates().astype(np.float32, casting="safe") # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) From 36e85e384a4d4de3d5d93de4b821805eec43c67e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 13 Feb 2024 10:42:54 +0100 Subject: [PATCH 059/192] full.py --- src/spikeinterface/full.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py index dc8b2dbdd3..0cd0fb0fb5 100644 --- a/src/spikeinterface/full.py +++ b/src/spikeinterface/full.py @@ -20,10 +20,8 @@ from .preprocessing import * from .postprocessing import * from .qualitymetrics import * - -# TODO -# from .curation import * -# from .comparison import * -# from .widgets import * -# from .exporters import * +from .curation import * +from .comparison import * +from .widgets import * +from .exporters import * from .generation import * From c7b5fc83363d356b6cb8d778e2bd08409c09b03c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 13 Feb 2024 11:09:12 +0100 Subject: [PATCH 060/192] More compatibility for plotting. --- .../widgets/all_amplitudes_distributions.py | 3 ++- src/spikeinterface/widgets/amplitudes.py | 3 +++ src/spikeinterface/widgets/base.py | 12 ++++++++++++ src/spikeinterface/widgets/crosscorrelograms.py | 2 ++ src/spikeinterface/widgets/quality_metrics.py | 1 + src/spikeinterface/widgets/sorting_summary.py | 1 + src/spikeinterface/widgets/spike_locations.py | 4 +++- src/spikeinterface/widgets/spikes_on_traces.py | 2 +- src/spikeinterface/widgets/template_metrics.py | 1 + .../widgets/template_similarity.py | 4 +++- src/spikeinterface/widgets/unit_depths.py | 3 +++ src/spikeinterface/widgets/unit_locations.py | 2 ++ src/spikeinterface/widgets/unit_probe_map.py | 2 ++ src/spikeinterface/widgets/unit_summary.py | 2 ++ src/spikeinterface/widgets/unit_waveforms.py | 16 +++++++++++----- .../widgets/unit_waveforms_density_map.py | 3 ++- 16 files changed, 51 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index f9b7014d35..4ba9661d66 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -24,7 +24,8 @@ class AllAmplitudesDistributionsWidget(BaseWidget): """ def __init__(self, sorting_result: SortingResult, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs): - + + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "spike_amplitudes") amplitudes = sorting_result.get_extension("spike_amplitudes").get_data() diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 67432d7b66..1867cae7da 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -50,6 +50,9 @@ def __init__( backend=None, **backend_kwargs, ): + + sorting_result = self.ensure_sorting_result(sorting_result) + sorting = sorting_result.sorting self.check_extensions(sorting_result, "spike_amplitudes") diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index c207fca1f2..26cfd0fa23 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -5,6 +5,9 @@ global default_backend_ default_backend_ = "matplotlib" +from ..core import SortingResult +from ..core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor + def get_default_plotter_backend(): """Return the default backend for spikeinterface widgets. @@ -102,6 +105,15 @@ def do_plot(self): func = getattr(self, f"plot_{self.backend}") func(self.data_plot, **self.backend_kwargs) + @classmethod + def ensure_sorting_result(cls, sorting_result_or_waveform_extractor): + if isinstance(sorting_result_or_waveform_extractor, SortingResult): + return sorting_result_or_waveform_extractor + elif isinstance(sorting_result_or_waveform_extractor, MockWaveformExtractor): + return sorting_result_or_waveform_extractor.sorting_result + else: + return sorting_result_or_waveform_extractor + @staticmethod def check_extensions(sorting_result, extensions): if isinstance(extensions, str): diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 445d7b5a6b..087a97df9e 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -46,6 +46,8 @@ def __init__( backend=None, **backend_kwargs, ): + sorting_result_or_sorting = self.ensure_sorting_result(sorting_result_or_sorting) + if min_similarity_for_correlograms is None: min_similarity_for_correlograms = 0 similarity = None diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index bf63b0d494..5a9b77dfcd 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -35,6 +35,7 @@ def __init__( backend=None, **backend_kwargs, ): + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "quality_metrics") quality_metrics = sorting_result.get_extension("quality_metrics").get_data() diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index c35cc71e82..1c01e99a69 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -58,6 +58,7 @@ def __init__( backend=None, **backend_kwargs, ): + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, ["correlograms", "spike_amplitudes", "unit_locations", "similarity"]) sorting = sorting_result.sorting diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 77c6537ea4..9427c62f46 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -52,8 +52,10 @@ def __init__( hide_axis=False, backend=None, **backend_kwargs, - ): + ): + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "spike_locations") + spike_locations_by_units = sorting_result.get_extension("spike_locations").get_data(outputs="by_unit") sorting = sorting_result.sorting diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 844e50924f..7515bc5d64 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -83,7 +83,7 @@ def __init__( backend=None, **backend_kwargs, ): - + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "unit_locations") sorting: BaseSorting = sorting_result.sorting diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index 90e952ed2a..c3b7d7f3e8 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -35,6 +35,7 @@ def __init__( backend=None, **backend_kwargs, ): + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "template_metrics") template_metrics = sorting_result.get_extension("template_metrics").get_data() diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 9c36b22309..6800a55b51 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -37,8 +37,10 @@ def __init__( show_colorbar=True, backend=None, **backend_kwargs, - ): + ): + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "template_similarity") + tsc = sorting_result.get_extension("template_similarity") similarity = tsc.get_data().copy() diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index a37ec46bd3..cfc141f396 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -27,6 +27,9 @@ class UnitDepthsWidget(BaseWidget): """ def __init__(self, sorting_result, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs): + + sorting_result = self.ensure_sorting_result(sorting_result) + unit_ids = sorting_result.sorting.unit_ids if unit_colors is None: diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index f91f7291aa..7fc635e419 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -48,6 +48,8 @@ def __init__( backend=None, **backend_kwargs, ): + sorting_result = self.ensure_sorting_result(sorting_result) + self.check_extensions(sorting_result, "unit_locations") ulc = sorting_result.get_extension("unit_locations") unit_locations = ulc.get_data(outputs="by_unit") diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 20640e5e63..14da93079d 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -42,6 +42,8 @@ def __init__( backend=None, **backend_kwargs, ): + sorting_result = self.ensure_sorting_result(sorting_result) + if unit_ids is None: unit_ids = sorting_result.unit_ids self.unit_ids = unit_ids diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index eadf9f4037..27026f645e 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -45,6 +45,8 @@ def __init__( **backend_kwargs, ): + sorting_result = self.ensure_sorting_result(sorting_result) + if unit_colors is None: unit_colors = get_unit_colors(sorting_result.sorting) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 5c182f1f3f..ca98888428 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -105,6 +105,7 @@ def __init__( **backend_kwargs, ): + sorting_result = self.ensure_sorting_result(sorting_result) sorting: BaseSorting = sorting_result.sorting if unit_ids is None: @@ -345,7 +346,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - self.we = we = data_plot["sorting_result"] + self.sorting_result = data_plot["sorting_result"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -449,18 +450,23 @@ def _update_plot(self, change): hide_axis = self.hide_axis_button.value do_shading = self.template_shading_button.value + wf_ext = self.sorting_result.get_extension("waveforms") + templates_ext = self.sorting_result.get_extension("templates") + templates = templates_ext.get_templates(unit_ids=unit_ids, operator="average") + + # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids - data_plot["templates"] = self.we.get_all_templates(unit_ids=unit_ids) - templates_shadings = self._get_template_shadings(self.we, unit_ids, data_plot["templates_percentile_shading"]) + data_plot["templates"] = templates + templates_shadings = self._get_template_shadings(self.sorting_result, unit_ids, data_plot["templates_percentile_shading"]) data_plot["templates_shading"] = templates_shadings data_plot["same_axis"] = same_axis data_plot["plot_templates"] = plot_templates data_plot["do_shading"] = do_shading data_plot["scale"] = self.scaler.value if data_plot["plot_waveforms"]: - data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + data_plot["wfs_by_ids"] = {unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids} # TODO option for plot_legend @@ -484,7 +490,7 @@ def _update_plot(self, change): ax.axis("off") # update probe plot - channel_locations = self.we.get_channel_locations() + channel_locations = self.sorting_result.get_channel_locations() self.ax_probe.plot( channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 ) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 6c82c2bd4d..659ba33f80 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -48,7 +48,8 @@ def __init__( backend=None, **backend_kwargs, ): - + sorting_result = self.ensure_sorting_result(sorting_result) + if channel_ids is None: channel_ids = sorting_result.channel_ids From 9f0c65256dd13673a9a412572ba791a54abd1cf1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 10:14:54 +0000 Subject: [PATCH 061/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dev_pool.py | 2 +- src/spikeinterface/core/template.py | 1 - .../sorters/internal/spyking_circus2.py | 20 +++++++++++----- .../clustering/clustering_tools.py | 4 +--- .../clustering/random_projections.py | 12 ++++++---- .../sortingcomponents/matching/circus.py | 24 +++++-------------- .../sortingcomponents/matching/naive.py | 3 +-- .../sortingcomponents/matching/tdc.py | 8 +++---- .../sortingcomponents/matching/wobble.py | 5 +--- 9 files changed, 35 insertions(+), 44 deletions(-) diff --git a/dev_pool.py b/dev_pool.py index 52f5a7572a..9a9b2ca0f2 100644 --- a/dev_pool.py +++ b/dev_pool.py @@ -71,4 +71,4 @@ def init_worker(lock, array_pid): # print(multiprocessing.current_process()) # p = multiprocessing.current_process() -# print(p._identity) \ No newline at end of file +# print(p._identity) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index aa0d967036..9102b1a20e 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -113,7 +113,6 @@ def set_sparsity(self, sparsity): self.sparsity = sparsity self.templates_array = self.sparsity.sparsify_templates(self.get_dense_templates()) - def get_one_template_dense(self, unit_index): if self.sparsity is None: template = self.templates_array[unit_index, :, :] diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4b68b84398..4ae5e301a1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -207,13 +207,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore = int(params["general"]["ms_before"] * sampling_frequency / 1000.0) nafter = int(params["general"]["ms_after"] * sampling_frequency / 1000.0) - templates_array = estimate_templates(recording_f, labeled_peaks, unit_ids, nbefore, nafter, - False, job_name=None, **job_kwargs) + templates_array = estimate_templates( + recording_f, labeled_peaks, unit_ids, nbefore, nafter, False, job_name=None, **job_kwargs + ) - templates = Templates(templates_array, - sampling_frequency, nbefore, None, recording_f.channel_ids, unit_ids, recording_f.get_probe()) + templates = Templates( + templates_array, + sampling_frequency, + nbefore, + None, + recording_f.channel_ids, + unit_ids, + recording_f.get_probe(), + ) - sparsity = compute_sparsity(templates, method='radius') + sparsity = compute_sparsity(templates, method="radius") templates.set_sparsity(sparsity) if params["debug"]: @@ -225,7 +233,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params["templates"] = templates matching_job_params = {} matching_job_params.update(job_kwargs) - + if matching_method == "circus-omp-svd": for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in matching_job_params: diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index bb89bc14c8..90a1731f6f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -590,9 +590,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update( - {"templates": templates, "amplitudes": [0.975, 1.025]} - ) + local_params.update({"templates": templates, "amplitudes": [0.975, 1.025]}) ignore_ids = [] similar_templates = [[], []] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 9e95968d6f..459b7f981e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -205,12 +205,14 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) - templates_array = estimate_templates(recording, spikes, unit_ids, nbefore, nafter, - False, job_name=None, **job_kwargs) + templates_array = estimate_templates( + recording, spikes, unit_ids, nbefore, nafter, False, job_name=None, **job_kwargs + ) - templates = Templates(templates_array, - fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) - sparsity = compute_sparsity(templates, method='radius') + templates = Templates( + templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe() + ) + sparsity = compute_sparsity(templates, method="radius") templates.set_sparsity(sparsity) cleaning_matching_params = params["job_kwargs"].copy() diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 9482ee561a..25e6c59dfa 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -38,7 +38,6 @@ from .main import BaseTemplateMatchingEngine - def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) @@ -71,7 +70,6 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): return new_overlaps - class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -157,7 +155,6 @@ def _prepare_templates(cls, d): d["norms"][count] = np.linalg.norm(template) d["normed_templates"][count][:, d["sparsity_mask"][count]] = template / d["norms"][count] - d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] d["temporal"] = np.flip(d["temporal"], axis=1) @@ -197,8 +194,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d.update(kwargs) assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " - f"and must be a Templates" + f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" ) d["num_channels"] = recording.get_num_channels() @@ -617,13 +613,9 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if d["noise_levels"] is None: print("CircusPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels( - recording, **d["random_chunk_kwargs"], return_scaled=False - ) + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - d["abs_threholds"] = ( - d["noise_levels"] * d["detect_threshold"] - ) + d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] if not "circus_templates" in d: d = cls._prepare_templates(d) @@ -635,9 +627,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["sparsities"], ) - d["exclude_sweep_size"] = int( - d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 - ) + d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) d["nbefore"] = d["templates"].nbefore d["nafter"] = d["templates"].nafter @@ -646,9 +636,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["num_channels"], ) d["sym_patch"] = d["nbefore"] == d["nafter"] - d["jitter"] = int( - d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0 - ) + d["jitter"] = int(d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0) d["amplitudes"] = np.zeros((d["num_templates"], 2), dtype=np.float32) d["amplitudes"][:, 0] = d["min_amplitude"] @@ -668,7 +656,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): # .reshape(num_chunks, -1) # .T # ) - #parameters = cls._optimize_amplitudes(noise_snippets, d) + # parameters = cls._optimize_amplitudes(noise_snippets, d) return d diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index edb69ef6bb..8103ccd011 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -47,8 +47,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d.update(kwargs) assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " - f"and must be a Templates" + f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" ) templates = d["templates"] diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 92777ed6ae..55818c471c 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -70,8 +70,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d.update(kwargs) assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " - f"and must be a Templates" + f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" ) templates = d["templates"] @@ -80,7 +79,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): sr = templates.sampling_frequency - d["nbefore"] = templates.nbefore d["nafter"] = templates.nafter templates_array = templates.get_dense_templates() @@ -109,7 +107,9 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["radius_um"] - sparsity = compute_sparsity(templates, method="best_channels")#, peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) + sparsity = compute_sparsity( + templates, method="best_channels" + ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) template_sparsity_inds = sparsity.unit_id_to_channel_indices template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") for unit_index, unit_id in enumerate(unit_ids): diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 07a73cb9e3..c4953d7aa3 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -342,13 +342,10 @@ def initialize_and_check_kwargs(cls, recording, kwargs): for required_key in required_kwargs_keys: assert required_key in kwargs, f"`{required_key}` is a required key in the kwargs" - - parameters = kwargs.get("parameters", {}) templates = kwargs["templates"] assert isinstance(templates, Templates), ( - f"The templates supplied is of type {type(d['templates'])} " - f"and must be a Templates" + f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" ) templates_array = templates.get_dense_templates().astype(np.float32, casting="safe") From 51a7a39af72b4b2b750dd2ae282e3ae0fc38854d Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 13 Feb 2024 12:50:13 +0100 Subject: [PATCH 062/192] WIP --- src/spikeinterface/core/sparsity.py | 8 ++-- src/spikeinterface/core/template.py | 15 +++++-- .../sorters/internal/spyking_circus2.py | 21 ++++------ .../clustering/random_projections.py | 42 +++---------------- 4 files changed, 31 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 8687d14330..4883c64507 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -223,11 +223,13 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo def sparsify_templates(self, templates_array: np.ndarray) -> np.ndarray: max_num_active_channels = self.max_num_active_channels num_samples = templates_array.shape[1] - sparisfied_shape = (self.num_units, num_samples, max_num_active_channels) - sparse_templates = np.zeros(shape=sparisfied_shape, dtype=templates_array.dtype) + sparsified_shape = (self.num_units, num_samples, max_num_active_channels) + sparse_templates = np.zeros(shape=sparsified_shape, dtype=templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): template = templates_array[unit_index, ...] - sparse_templates[unit_index, ...] = self.sparsify_waveforms(waveforms=template, unit_id=unit_id) + sparse_template = self.sparsify_waveforms(waveforms=template[np.newaxis, :, :], unit_id=unit_id) + sparse_templates[unit_index, :, :sparse_template.shape[2]] = sparse_template + return sparse_templates diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index aa0d967036..ccacfab8a1 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -108,10 +108,19 @@ def __post_init__(self): if not self._are_passed_templates_sparse(): raise ValueError("Sparsity mask passed but the templates are not sparse") - def set_sparsity(self, sparsity): + def to_sparse(self, sparsity): assert isinstance(sparsity, ChannelSparsity), "sparsity should be of type ChannelSparsity" - self.sparsity = sparsity - self.templates_array = self.sparsity.sparsify_templates(self.get_dense_templates()) + assert self.sparsity_mask is None, "Templates should be dense" + + return Templates( + templates_array = sparsity.sparsify_templates(self.templates_array), + sampling_frequency=self.sampling_frequency, + nbefore=self.nbefore, + sparsity_mask=sparsity.mask, + channel_ids=self.channel_ids, + unit_ids=self.unit_ids, + probe=self.probe, + check_for_consistent_sparsity=self.check_for_consistent_sparsity) def get_one_template_dense(self, unit_index): diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4b68b84398..4817e1ae82 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -28,11 +28,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": { - "sparse": True, - "method": "energy", - "threshold": 0.25, - }, + "sparsity": {"method": "ptp", "threshold": 1}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { @@ -56,8 +52,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _params_description = { "general": "A dictionary to describe how templates should be computed. User can define ms_before and ms_after (in ms) \ and also the radius_um used to be considered during clustering", - "waveforms": "A dictionary to be passed to all the calls to extract_waveforms that will be performed internally. Default is \ - to consider sparse waveforms", + "sparsity": "A dictionary to be passed to all the calls to sparsify the templates", "filtering": "A dictionary for the high_pass filter to be used during preprocessing", "detection": "A dictionary for the peak detection node (locally_exclusive)", "selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\ @@ -115,9 +110,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = recording recording_f.annotate(is_filtered=True) - # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") - noise_levels = np.ones(num_channels, dtype=np.float32) + noise_levels = get_noise_levels(recording_f) if recording_f.check_serializability("json"): recording_f.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -158,7 +152,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params["waveforms"] = params["waveforms"].copy() + clustering_params["waveforms"] = {} + clustering_params["sparsity"] = params["sparsity"] for k in ["ms_before", "ms_after"]: clustering_params["waveforms"][k] = params["general"][k] @@ -208,13 +203,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nafter = int(params["general"]["ms_after"] * sampling_frequency / 1000.0) templates_array = estimate_templates(recording_f, labeled_peaks, unit_ids, nbefore, nafter, - False, job_name=None, **job_kwargs) + return_scaled=False, job_name=None, **job_kwargs) templates = Templates(templates_array, sampling_frequency, nbefore, None, recording_f.channel_ids, unit_ids, recording_f.get_probe()) - sparsity = compute_sparsity(templates, method='radius') - templates.set_sparsity(sparsity) + sparsity = compute_sparsity(templates, noise_levels, **params['sparsity']) + templates = templates.to_sparse(sparsity) if params["debug"]: sorting = sorting.save(folder=clustering_folder / "sorting") diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 9e95968d6f..f3fff94d97 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -21,6 +21,7 @@ from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms +from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature @@ -47,7 +48,8 @@ class RandomProjectionClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, - "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, + "waveforms": {"ms_before": 2, "ms_after": 2}, + "sparsity" : {"method": "ptp", "threshold": 1}, "radius_um": 100, "selection_method": "closest_to_centroid", "nb_projections": 10, @@ -150,38 +152,6 @@ def main_function(cls, recording, peaks, params): labels = np.unique(peak_labels) labels = labels[labels >= 0] - # best_spikes = {} - # nb_spikes = 0 - - # all_indices = np.arange(0, peak_labels.size) - - # max_spikes = params["waveforms"]["max_spikes_per_unit"] - # selection_method = params["selection_method"] - - # for unit_ind in labels: - # mask = peak_labels == unit_ind - # if selection_method == "closest_to_centroid": - # data = hdbscan_data[mask] - # centroid = np.median(data, axis=0) - # distances = sklearn.metrics.pairwise_distances(centroid[np.newaxis, :], data)[0] - # best_spikes[unit_ind] = all_indices[mask][np.argsort(distances)[:max_spikes]] - # elif selection_method == "random": - # best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes] - # nb_spikes += best_spikes[unit_ind].size - - # spikes = np.zeros(nb_spikes, dtype=minimum_spike_dtype) - - # mask = np.zeros(0, dtype=np.int32) - # for unit_ind in labels: - # mask = np.concatenate((mask, best_spikes[unit_ind])) - - # idx = np.argsort(mask) - # mask = mask[idx] - - # spikes["sample_index"] = peaks[mask]["sample_index"] - # spikes["segment_index"] = peaks[mask]["segment_index"] - # spikes["unit_index"] = peak_labels[mask] - spikes = np.zeros(len(peaks), dtype=minimum_spike_dtype) spikes["sample_index"] = peaks["sample_index"] spikes["segment_index"] = peaks["segment_index"] @@ -206,12 +176,12 @@ def main_function(cls, recording, peaks, params): nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) templates_array = estimate_templates(recording, spikes, unit_ids, nbefore, nafter, - False, job_name=None, **job_kwargs) + return_scaled=False, job_name=None, **job_kwargs) templates = Templates(templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) - sparsity = compute_sparsity(templates, method='radius') - templates.set_sparsity(sparsity) + sparsity = compute_sparsity(templates, get_noise_levels(recording), **params["sparsity"]) + templates = templates.to_sparse(sparsity) cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: From faf2619539f33bfc7c339235f86df10cd0356faf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 11:53:44 +0000 Subject: [PATCH 063/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sparsity.py | 3 +-- src/spikeinterface/core/template.py | 5 +++-- .../sorters/internal/spyking_circus2.py | 7 ++++--- .../clustering/random_projections.py | 12 +++++++----- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 4883c64507..9de9856aa1 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -228,8 +228,7 @@ def sparsify_templates(self, templates_array: np.ndarray) -> np.ndarray: for unit_index, unit_id in enumerate(self.unit_ids): template = templates_array[unit_index, ...] sparse_template = self.sparsify_waveforms(waveforms=template[np.newaxis, :, :], unit_id=unit_id) - sparse_templates[unit_index, :, :sparse_template.shape[2]] = sparse_template - + sparse_templates[unit_index, :, : sparse_template.shape[2]] = sparse_template return sparse_templates diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 4b6f2852ed..c5c9a4a0cf 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -113,14 +113,15 @@ def to_sparse(self, sparsity): assert self.sparsity_mask is None, "Templates should be dense" return Templates( - templates_array = sparsity.sparsify_templates(self.templates_array), + templates_array=sparsity.sparsify_templates(self.templates_array), sampling_frequency=self.sampling_frequency, nbefore=self.nbefore, sparsity_mask=sparsity.mask, channel_ids=self.channel_ids, unit_ids=self.unit_ids, probe=self.probe, - check_for_consistent_sparsity=self.check_for_consistent_sparsity) + check_for_consistent_sparsity=self.check_for_consistent_sparsity, + ) def get_one_template_dense(self, unit_index): if self.sparsity is None: diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2e687431da..4540122752 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -202,8 +202,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore = int(params["general"]["ms_before"] * sampling_frequency / 1000.0) nafter = int(params["general"]["ms_after"] * sampling_frequency / 1000.0) - templates_array = estimate_templates(recording_f, labeled_peaks, unit_ids, nbefore, nafter, - return_scaled=False, job_name=None, **job_kwargs) + templates_array = estimate_templates( + recording_f, labeled_peaks, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + ) templates = Templates( templates_array, @@ -215,7 +216,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f.get_probe(), ) - sparsity = compute_sparsity(templates, noise_levels, **params['sparsity']) + sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"]) templates = templates.to_sparse(sparsity) if params["debug"]: diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index f3fff94d97..4f51ac169c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -49,7 +49,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity" : {"method": "ptp", "threshold": 1}, + "sparsity": {"method": "ptp", "threshold": 1}, "radius_um": 100, "selection_method": "closest_to_centroid", "nb_projections": 10, @@ -175,11 +175,13 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) - templates_array = estimate_templates(recording, spikes, unit_ids, nbefore, nafter, - return_scaled=False, job_name=None, **job_kwargs) + templates_array = estimate_templates( + recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + ) - templates = Templates(templates_array, - fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()) + templates = Templates( + templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe() + ) sparsity = compute_sparsity(templates, get_noise_levels(recording), **params["sparsity"]) templates = templates.to_sparse(sparsity) From 28faaf269aaaeba7e2dbd14fff1966b28f73afa1 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 13 Feb 2024 13:23:56 +0100 Subject: [PATCH 064/192] Wobble can now use sparsity from Templates --- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/matching/circus.py | 11 ++---- .../sortingcomponents/matching/naive.py | 2 +- .../sortingcomponents/matching/tdc.py | 19 +++++----- .../sortingcomponents/matching/wobble.py | 35 +++++++++++++++++-- 5 files changed, 47 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4540122752..44b09c7668 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -28,7 +28,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 1}, + "sparsity": {"method": "ptp", "threshold": 0.5}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 25e6c59dfa..41f92f78c2 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -123,7 +123,6 @@ def _prepare_templates(cls, d): sparsity = templates.sparsity.mask - d["sparsity_mask"] = sparsity units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) d["units_overlaps"] = units_overlaps > 0 d["unit_overlaps_indices"] = {} @@ -132,10 +131,6 @@ def _prepare_templates(cls, d): templates_array = templates.get_dense_templates().copy() - # First, we set masked channels to 0 - for count in range(num_templates): - templates_array[count][:, ~d["sparsity_mask"][count]] = 0 - # Then we keep only the strongest components rank = d["rank"] temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False) @@ -151,9 +146,9 @@ def _prepare_templates(cls, d): # And get the norms, saving compressed templates for CC matrix for count in range(num_templates): - template = templates_array[count][:, d["sparsity_mask"][count]] + template = templates_array[count][:, sparsity[count]] d["norms"][count] = np.linalg.norm(template) - d["normed_templates"][count][:, d["sparsity_mask"][count]] = template / d["norms"][count] + d["normed_templates"][count][:, sparsity[count]] = template / d["norms"][count] d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] d["temporal"] = np.flip(d["temporal"], axis=1) @@ -171,7 +166,7 @@ def _prepare_templates(cls, d): unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) for count, j in enumerate(overlapping_units): - overlapped_channels = d["sparsity_mask"][j] + overlapped_channels = sparsity[j] visible_i = template_i[:, overlapped_channels] spatial_filters = d["spatial"][j, :, overlapped_channels] diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index d71e6d2084..951c61b5cb 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -87,7 +87,7 @@ def main_function(cls, traces, method_kwargs): abs_threholds = method_kwargs["abs_threholds"] exclude_sweep_size = method_kwargs["exclude_sweep_size"] neighbours_mask = method_kwargs["neighbours_mask"] - templates_array = method_kwargs["templates"].templates_array + templates_array = method_kwargs["templates"].get_dense_templates() nbefore = method_kwargs["nbefore"] nafter = method_kwargs["nafter"] diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 55818c471c..b6c935b318 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -107,15 +107,16 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["radius_um"] - sparsity = compute_sparsity( - templates, method="best_channels" - ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) - template_sparsity_inds = sparsity.unit_id_to_channel_indices - template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - for unit_index, unit_id in enumerate(unit_ids): - chan_inds = template_sparsity_inds[unit_id] - template_sparsity[unit_index, chan_inds] = True - + #sparsity = compute_sparsity( + # templates, method="best_channels" + #) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) + #template_sparsity_inds = sparsity.unit_id_to_channel_indices + #template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") + #for unit_index, unit_id in enumerate(unit_ids): + # chan_inds = template_sparsity_inds[unit_id] + # template_sparsity[unit_index, chan_inds] = True + + template_sparsity = templates.sparsity.mask d["template_sparsity"] = template_sparsity extremum_channel = get_template_extremum_channel(templates, peak_sign=d["peak_sign"], outputs="index") diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index c4953d7aa3..b0b0118e91 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -240,6 +240,30 @@ def from_parameters_and_templates(cls, params, templates): sparsity = cls(visible_channels=visible_channels, unit_overlap=unit_overlap) return sparsity + @classmethod + def from_templates(cls, params, templates): + """Aggregate variables relevant to sparse representation of templates. + + Parameters + ---------- + params : WobbleParameters + Dataclass object for aggregating the parameters together. + templates : Templates object + + Returns + ------- + sparsity : Sparsity + Dataclass object for aggregating channel sparsity variables together. + """ + visible_channels = templates.sparsity.mask + unit_overlap = np.sum( + np.logical_and(visible_channels[:, np.newaxis, :], visible_channels[np.newaxis, :, :]), axis=2 + ) + unit_overlap = unit_overlap > 0 + unit_overlap = np.repeat(unit_overlap, params.jitter_factor, axis=0) + sparsity = cls(visible_channels=visible_channels, unit_overlap=unit_overlap) + return sparsity + @dataclass class TemplateData: @@ -352,9 +376,14 @@ def initialize_and_check_kwargs(cls, recording, kwargs): # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) template_meta = TemplateMetadata.from_parameters_and_templates(params, templates_array) - sparsity = Sparsity.from_parameters_and_templates( - params, templates_array - ) # TODO: replace with spikeinterface sparsity + if not templates.are_templates_sparse(): + sparsity = Sparsity.from_parameters_and_templates( + params, templates_array + ) + else: + sparsity = Sparsity.from_templates( + params, templates + ) # Perform initial computations on templates necessary for computing the objective sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates_array, 0) From bb8676acb9a9e77d4f2f6c582cbd564f1f13d02c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 12:24:18 +0000 Subject: [PATCH 065/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/tdc.py | 10 +++++----- .../sortingcomponents/matching/wobble.py | 8 ++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index b6c935b318..f7066845e6 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -107,12 +107,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["radius_um"] - #sparsity = compute_sparsity( + # sparsity = compute_sparsity( # templates, method="best_channels" - #) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) - #template_sparsity_inds = sparsity.unit_id_to_channel_indices - #template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - #for unit_index, unit_id in enumerate(unit_ids): + # ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) + # template_sparsity_inds = sparsity.unit_id_to_channel_indices + # template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") + # for unit_index, unit_id in enumerate(unit_ids): # chan_inds = template_sparsity_inds[unit_id] # template_sparsity[unit_index, chan_inds] = True diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index b0b0118e91..8196df4dec 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -377,13 +377,9 @@ def initialize_and_check_kwargs(cls, recording, kwargs): params = WobbleParameters(**parameters) template_meta = TemplateMetadata.from_parameters_and_templates(params, templates_array) if not templates.are_templates_sparse(): - sparsity = Sparsity.from_parameters_and_templates( - params, templates_array - ) + sparsity = Sparsity.from_parameters_and_templates(params, templates_array) else: - sparsity = Sparsity.from_templates( - params, templates - ) + sparsity = Sparsity.from_templates(params, templates) # Perform initial computations on templates necessary for computing the objective sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates_array, 0) From d4892d7db7337cfc268b426da5079fe8a855ab94 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 13 Feb 2024 13:29:13 +0100 Subject: [PATCH 066/192] WIP --- src/spikeinterface/sortingcomponents/matching/tdc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index b6c935b318..498586dff6 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -102,7 +102,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): print("TridesclousPeeler : noise should be computed outside") d["noise_levels"] = get_noise_levels(recording) - d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] + d["abs_thresholds"] = d["noise_levels"] * d["detect_threshold"] channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["radius_um"] @@ -219,14 +219,14 @@ def _tdc_find_spikes(traces, d, level=0): peak_sign = d["peak_sign"] templates = d["templates"] templates_short = d["templates_short"] - templates_array = templates.templates_array + templates_array = templates.get_dense_templates() margin = d["margin"] possible_clusters_by_channel = d["possible_clusters_by_channel"] peak_traces = traces[margin // 2 : -margin // 2, :] peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, peak_sign, d["abs_threholds"], d["peak_shift"], d["neighbours_mask"] + peak_traces, peak_sign, d["abs_thresholds"], d["peak_shift"], d["neighbours_mask"] ) peak_sample_ind += margin // 2 From 61c1485bcec3b3fa03f8997e907e01af74d7705b Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 13 Feb 2024 17:00:20 +0100 Subject: [PATCH 067/192] WIP --- .../sortingcomponents/clustering/circus.py | 91 ++++++------------- .../clustering/random_projections.py | 23 +---- 2 files changed, 27 insertions(+), 87 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index b6793e3d37..f98b1effcc 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -15,14 +15,19 @@ import random, string, os from spikeinterface.core import get_global_tmp_folder, get_channel_distances +from spikeinterface.core.basesorting import minimum_spike_dtype from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler -from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers +from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers, estimate_templates from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting +from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core import extract_waveforms from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from sklearn.decomposition import TruncatedSVD +from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import compute_sparsity import pickle, json from spikeinterface.core.node_pipeline import ( run_node_pipeline, @@ -46,14 +51,14 @@ class CircusClustering: "cluster_selection_method": "eom", }, "cleaning_kwargs": {}, - "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, + "waveforms": {"ms_before": 2, "ms_after": 2}, + "sparsity": {"method": "ptp", "threshold": 1}, "radius_um": 100, - "selection_method": "closest_to_centroid", "n_svd": [5, 10], "ms_before": 1, "ms_after": 1, "random_seed": 42, - "shared_memory": True, + "debug": False, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @@ -70,6 +75,8 @@ def main_function(cls, recording, peaks, params): if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() + job_kwargs = fix_job_kwargs(params["job_kwargs"]) + d = params verbose = d["job_kwargs"]["verbose"] @@ -165,65 +172,28 @@ def main_function(cls, recording, peaks, params): labels = np.unique(peak_labels) labels = labels[labels >= 0] - best_spikes = {} - nb_spikes = 0 - - import sklearn - - all_indices = np.arange(0, peak_labels.size) - - max_spikes = params["waveforms"]["max_spikes_per_unit"] - selection_method = params["selection_method"] - - for unit_ind in labels: - mask = peak_labels == unit_ind - if selection_method == "closest_to_centroid": - data = all_pc_data[mask].reshape(np.sum(mask), -1) - centroid = np.median(data, axis=0) - distances = sklearn.metrics.pairwise_distances(centroid[np.newaxis, :], data)[0] - best_spikes[unit_ind] = all_indices[mask][np.argsort(distances)[:max_spikes]] - elif selection_method == "random": - best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes] - nb_spikes += best_spikes[unit_ind].size - - spikes = np.zeros(nb_spikes, dtype=peak_dtype) - - mask = np.zeros(0, dtype=np.int32) - for unit_ind in labels: - mask = np.concatenate((mask, best_spikes[unit_ind])) - - idx = np.argsort(mask) - mask = mask[idx] - - spikes["sample_index"] = peaks[mask]["sample_index"] - spikes["segment_index"] = peaks[mask]["segment_index"] - spikes["unit_index"] = peak_labels[mask] + spikes = np.zeros(len(peaks), dtype=minimum_spike_dtype) + spikes["sample_index"] = peaks["sample_index"] + spikes["segment_index"] = peaks["segment_index"] + spikes["unit_index"] = peak_labels if verbose: print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) - sorting_folder = tmp_folder / "sorting" unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) - sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + + nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) + nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) - if params["shared_memory"]: - waveform_folder = None - mode = "memory" - else: - waveform_folder = tmp_folder / "waveforms" - mode = "folder" - sorting = sorting.save(folder=sorting_folder) + templates_array = estimate_templates( + recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + ) - we = extract_waveforms( - recording, - sorting, - waveform_folder, - return_scaled=False, - precompute_template=["median"], - mode=mode, - **params["job_kwargs"], - **params["waveforms"], + templates = Templates( + templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe() ) + sparsity = compute_sparsity(templates, get_noise_levels(recording), **params["sparsity"]) + templates = templates.to_sparse(sparsity) cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: @@ -238,18 +208,9 @@ def main_function(cls, recording, peaks, params): cleaning_params["tmp_folder"] = tmp_folder labels, peak_labels = remove_duplicates_via_matching( - we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) - del we, sorting - - if params["tmp_folder"] is None: - shutil.rmtree(tmp_folder) - else: - if not params["shared_memory"]: - shutil.rmtree(tmp_folder / "waveforms") - shutil.rmtree(tmp_folder / "sorting") - if verbose: print("We kept %d non-duplicated clusters..." % len(labels)) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 4f51ac169c..0e5b5ca3fe 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -51,15 +51,13 @@ class RandomProjectionClustering: "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 1}, "radius_um": 100, - "selection_method": "closest_to_centroid", "nb_projections": 10, "ms_before": 1, "ms_after": 1, "random_seed": 42, "smoothing_kwargs": {"window_length_ms": 0.25}, - "shared_memory": True, - "tmp_folder": None, "debug": False, + "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @@ -160,17 +158,7 @@ def main_function(cls, recording, peaks, params): if verbose: print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) - sorting_folder = tmp_folder / "sorting" unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) - sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) - - if params["shared_memory"]: - waveform_folder = None - mode = "memory" - else: - waveform_folder = tmp_folder / "waveforms" - mode = "folder" - sorting = sorting.save(folder=sorting_folder) nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) @@ -201,15 +189,6 @@ def main_function(cls, recording, peaks, params): templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) - del sorting - - if params["tmp_folder"] is None: - shutil.rmtree(tmp_folder) - else: - if not params["shared_memory"]: - shutil.rmtree(tmp_folder / "waveforms") - shutil.rmtree(tmp_folder / "sorting") - if verbose: print("We kept %d non-duplicated clusters..." % len(labels)) From da0d113966e9057bfb53720fb2065015bd6ff81d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 16:00:42 +0000 Subject: [PATCH 068/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index f98b1effcc..291328413f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -181,7 +181,7 @@ def main_function(cls, recording, peaks, params): print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) - + nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) From dc71e7899e8c02080ae00d0c6588fd4def568633 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 13 Feb 2024 18:14:03 +0100 Subject: [PATCH 069/192] small fixes --- src/spikeinterface/curation/tests/test_auto_merge.py | 2 +- .../curation/tests/test_remove_redundant.py | 2 +- src/spikeinterface/exporters/tests/common.py | 6 +++--- src/spikeinterface/exporters/tests/test_export_to_phy.py | 8 +++++++- src/spikeinterface/exporters/tests/test_report.py | 7 ++++--- src/spikeinterface/widgets/tests/test_widgets.py | 5 ++--- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 8886cf474f..66f1d6602f 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -9,7 +9,7 @@ from spikeinterface.curation import get_potential_auto_merge -from spikeinterface.curation.tests.common import make_sorting_result +from spikeinterface.curation.tests.common import make_sorting_result, sorting_result_for_curation if hasattr(pytest, "global_test_folder"): diff --git a/src/spikeinterface/curation/tests/test_remove_redundant.py b/src/spikeinterface/curation/tests/test_remove_redundant.py index a395c200a5..5a0f15f6e4 100644 --- a/src/spikeinterface/curation/tests/test_remove_redundant.py +++ b/src/spikeinterface/curation/tests/test_remove_redundant.py @@ -10,7 +10,7 @@ from spikeinterface.core.generate import inject_some_duplicate_units -from spikeinterface.curation.tests.common import make_sorting_result +from spikeinterface.curation.tests.common import make_sorting_result, sorting_result_for_curation from spikeinterface.curation import remove_redundant_units diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index 800124300f..e179171ca3 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -45,17 +45,17 @@ def make_sorting_result(sparse=True, with_group=False): return sorting_result -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sorting_result_dense_for_export(): return make_sorting_result(sparse=False) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sorting_result_with_group_for_export(): return make_sorting_result(sparse=False, with_group=True) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sorting_result_sparse_for_export(): return make_sorting_result(sparse=True) diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index bac8ebd75f..a7d05335a7 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -10,7 +10,13 @@ from spikeinterface.core import compute_sparsity from spikeinterface.exporters import export_to_phy -from spikeinterface.exporters.tests.common import cache_folder, make_sorting_result +from spikeinterface.exporters.tests.common import ( + cache_folder, + make_sorting_result, + sorting_result_sparse_for_export, + sorting_result_with_group_for_export, + sorting_result_dense_for_export, +) def test_export_to_phy(sorting_result_sparse_for_export): diff --git a/src/spikeinterface/exporters/tests/test_report.py b/src/spikeinterface/exporters/tests/test_report.py index 5ad01f7609..cc3a0b2a64 100644 --- a/src/spikeinterface/exporters/tests/test_report.py +++ b/src/spikeinterface/exporters/tests/test_report.py @@ -8,18 +8,19 @@ from spikeinterface.exporters.tests.common import ( cache_folder, make_sorting_result, + sorting_result_sparse_for_export ) -def test_export_report(waveforms_extractor_sparse_for_export): +def test_export_report(sorting_result_sparse_for_export): report_folder = cache_folder / "report" if report_folder.exists(): shutil.rmtree(report_folder) - we = waveforms_extractor_sparse_for_export + sorting_result = sorting_result_sparse_for_export job_kwargs = dict(n_jobs=1, chunk_size=30000, progress_bar=True) - export_report(we, report_folder, force_computation=True, **job_kwargs) + export_report(sorting_result, report_folder, force_computation=True, **job_kwargs) if __name__ == "__main__": diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4b6b50350e..877383cd4b 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -95,9 +95,8 @@ def setUpClass(cls): cls.sorting_result_sparse.select_random_spikes() cls.sorting_result_sparse.compute(extensions_to_compute, **job_kwargs) - # cls.skip_backends = ["ipywidgets", "ephyviewer"] - # TODO : delete this after debug - cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] + cls.skip_backends = ["ipywidgets", "ephyviewer"] + # cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] if ON_GITHUB and not KACHERY_CLOUD_SET: cls.skip_backends.append("sortingview") From d07ef00a0616cf3590a7cf74c524a68b251c1d9d Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 14 Feb 2024 09:20:59 +0100 Subject: [PATCH 070/192] WIP --- src/spikeinterface/core/job_tools.py | 7 ++++--- .../sorters/internal/spyking_circus2.py | 13 +++++++------ .../sortingcomponents/clustering/circus.py | 4 ++-- .../clustering/random_projections.py | 18 +++--------------- 4 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 0eec3e6b85..96770bbe5f 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -62,9 +62,10 @@ def fix_job_kwargs(runtime_job_kwargs): ) # remove None runtime_job_kwargs_exclude_none = runtime_job_kwargs.copy() - for job_key, job_value in runtime_job_kwargs.items(): - if job_value is None: - del runtime_job_kwargs_exclude_none[job_key] + # Whe should remove these lines, otherwise, we can not reset values for total_memory/chunk_size/... on the fly + #for job_key, job_value in runtime_job_kwargs.items(): + # if job_value is None: + # del runtime_job_kwargs_exclude_none[job_key] job_kwargs.update(runtime_job_kwargs_exclude_none) # if n_jobs is -1, set to os.cpu_count() (n_jobs is always in global job_kwargs) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 44b09c7668..72deed118b 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -38,7 +38,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "select_per_channel": False, }, "clustering": {"legacy": False}, - "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, + "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, "shared_memory": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -220,19 +220,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates = templates.to_sparse(sparsity) if params["debug"]: + templates.to_zarr(folder_path=clustering_folder / "templates") sorting = sorting.save(folder=clustering_folder / "sorting") ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces - matching_method = params["matching"]["method"] - matching_params = params["matching"]["method_kwargs"].copy() + matching_method = params["matching"].pop("method") + matching_params = params["matching"].copy() matching_params["templates"] = templates - matching_job_params = {} - matching_job_params.update(job_kwargs) + matching_job_params = job_kwargs.copy() if matching_method == "circus-omp-svd": + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in matching_job_params: - matching_job_params.pop(value) + matching_job_params[value] = None matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 291328413f..2b6fa9346e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -55,8 +55,8 @@ class CircusClustering: "sparsity": {"method": "ptp", "threshold": 1}, "radius_um": 100, "n_svd": [5, 10], - "ms_before": 1, - "ms_after": 1, + "ms_before": 0.5, + "ms_after": 0.5, "random_seed": 42, "debug": False, "tmp_folder": None, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 0e5b5ca3fe..dd558a1dd4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -52,8 +52,8 @@ class RandomProjectionClustering: "sparsity": {"method": "ptp", "threshold": 1}, "radius_um": 100, "nb_projections": 10, - "ms_before": 1, - "ms_after": 1, + "ms_before": 0.5, + "ms_after": 0.5, "random_seed": 42, "smoothing_kwargs": {"window_length_ms": 0.25}, "debug": False, @@ -135,18 +135,6 @@ def main_function(cls, recording, peaks, params): clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) peak_labels = clustering[0] - # peak_labels = -1 * np.ones(len(peaks), dtype=int) - # nb_clusters = 0 - # for c in np.unique(peaks['channel_index']): - # mask = peaks['channel_index'] == c - # clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs']) - # local_labels = clustering[0] - # valid_clusters = local_labels > -1 - # if np.sum(valid_clusters) > 0: - # local_labels[valid_clusters] += nb_clusters - # peak_labels[mask] = local_labels - # nb_clusters += len(np.unique(local_labels[valid_clusters])) - labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -176,7 +164,7 @@ def main_function(cls, recording, peaks, params): cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in cleaning_matching_params: - cleaning_matching_params.pop(value) + cleaning_matching_params[value] = None cleaning_matching_params["chunk_duration"] = "100ms" cleaning_matching_params["n_jobs"] = 1 cleaning_matching_params["verbose"] = False From a98fb2ea90822d227ed43bc6e48599b693dc8a05 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Feb 2024 08:23:34 +0000 Subject: [PATCH 071/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/job_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 96770bbe5f..310af84b30 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -63,7 +63,7 @@ def fix_job_kwargs(runtime_job_kwargs): # remove None runtime_job_kwargs_exclude_none = runtime_job_kwargs.copy() # Whe should remove these lines, otherwise, we can not reset values for total_memory/chunk_size/... on the fly - #for job_key, job_value in runtime_job_kwargs.items(): + # for job_key, job_value in runtime_job_kwargs.items(): # if job_value is None: # del runtime_job_kwargs_exclude_none[job_key] job_kwargs.update(runtime_job_kwargs_exclude_none) From 7d4ec099202203fc6904120b4717d5c6cc9d5d31 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 14 Feb 2024 10:00:27 +0100 Subject: [PATCH 072/192] remove empty templates during export to sparse --- src/spikeinterface/core/template.py | 46 ++++++++++++++----- .../sorters/internal/spyking_circus2.py | 4 +- .../clustering/clustering_tools.py | 4 +- .../clustering/random_projections.py | 2 +- 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index c5c9a4a0cf..2fc34e02d1 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -108,20 +108,39 @@ def __post_init__(self): if not self._are_passed_templates_sparse(): raise ValueError("Sparsity mask passed but the templates are not sparse") - def to_sparse(self, sparsity): + def to_sparse(self, sparsity, remove_empty=True): + # Turn a dense representation of templates into a sparse one, given some sparsity. + # Templates that are empty after sparsification can be removed via the remove_empty flag assert isinstance(sparsity, ChannelSparsity), "sparsity should be of type ChannelSparsity" assert self.sparsity_mask is None, "Templates should be dense" - return Templates( - templates_array=sparsity.sparsify_templates(self.templates_array), - sampling_frequency=self.sampling_frequency, - nbefore=self.nbefore, - sparsity_mask=sparsity.mask, - channel_ids=self.channel_ids, - unit_ids=self.unit_ids, - probe=self.probe, - check_for_consistent_sparsity=self.check_for_consistent_sparsity, - ) + if not remove_empty: + return Templates( + templates_array=sparsity.sparsify_templates(self.templates_array), + sampling_frequency=self.sampling_frequency, + nbefore=self.nbefore, + sparsity_mask=sparsity.mask, + channel_ids=self.channel_ids, + unit_ids=self.unit_ids, + probe=self.probe, + check_for_consistent_sparsity=self.check_for_consistent_sparsity, + ) + + else: + templates_array = sparsity.sparsify_templates(self.templates_array) + norms = np.linalg.norm(templates_array, axis=(1, 2)) + not_empty = norms > 0 + new_sparsity = ChannelSparsity(sparsity.mask[not_empty], sparsity.unit_ids[not_empty], sparsity.channel_ids) + return Templates( + templates_array=new_sparsity.sparsify_templates(self.templates_array[not_empty]), + sampling_frequency=self.sampling_frequency, + nbefore=self.nbefore, + sparsity_mask=new_sparsity.mask, + channel_ids=self.channel_ids, + unit_ids=self.unit_ids[not_empty], + probe=self.probe, + check_for_consistent_sparsity=self.check_for_consistent_sparsity, + ) def get_one_template_dense(self, unit_index): if self.sparsity is None: @@ -362,3 +381,8 @@ def get_channel_locations(self): assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" channel_locations = self.probe.contact_positions return channel_locations + + +def get_norms_from_templates(templates): + assert isinstance(templates, Templates) + return np.linalg.norm(templates.get_dense_templates(), axis=(1,2)) \ No newline at end of file diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 72deed118b..a205c49bc0 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -28,7 +28,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 0.5}, + "sparsity": {"method": "ptp", "threshold": 1}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { @@ -217,7 +217,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"]) - templates = templates.to_sparse(sparsity) + templates = templates.to_sparse(sparsity, remove_empty=True) if params["debug"]: templates.to_zarr(folder_path=clustering_folder / "templates") diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 90a1731f6f..00108edfc9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -553,7 +553,7 @@ def remove_duplicates_via_matching( job_kwargs = fix_job_kwargs(job_kwargs) - templates_array = templates.templates_array + templates_array = templates.get_dense_templates() nb_templates = len(templates_array) duration = templates.nbefore + templates.nafter @@ -750,4 +750,4 @@ def remove_duplicates_via_dip(wfs_arrays, peak_labels, dip_threshold=1, cosine_t labels = np.unique(new_labels) labels = labels[labels >= 0] - return labels, new_labels + return labels, new_labels \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index dd558a1dd4..e73c7ccb65 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -159,7 +159,7 @@ def main_function(cls, recording, peaks, params): templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe() ) sparsity = compute_sparsity(templates, get_noise_levels(recording), **params["sparsity"]) - templates = templates.to_sparse(sparsity) + templates = templates.to_sparse(sparsity, remove_empty=True) cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: From 7988003bf013fc99cb20dcda1b004a0299c3e4b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Feb 2024 09:03:15 +0000 Subject: [PATCH 073/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/template.py | 4 ++-- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 2fc34e02d1..4b789b184b 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -109,7 +109,7 @@ def __post_init__(self): raise ValueError("Sparsity mask passed but the templates are not sparse") def to_sparse(self, sparsity, remove_empty=True): - # Turn a dense representation of templates into a sparse one, given some sparsity. + # Turn a dense representation of templates into a sparse one, given some sparsity. # Templates that are empty after sparsification can be removed via the remove_empty flag assert isinstance(sparsity, ChannelSparsity), "sparsity should be of type ChannelSparsity" assert self.sparsity_mask is None, "Templates should be dense" @@ -385,4 +385,4 @@ def get_channel_locations(self): def get_norms_from_templates(templates): assert isinstance(templates, Templates) - return np.linalg.norm(templates.get_dense_templates(), axis=(1,2)) \ No newline at end of file + return np.linalg.norm(templates.get_dense_templates(), axis=(1, 2)) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 00108edfc9..e620f009ef 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -750,4 +750,4 @@ def remove_duplicates_via_dip(wfs_arrays, peak_labels, dip_threshold=1, cosine_t labels = np.unique(new_labels) labels = labels[labels >= 0] - return labels, new_labels \ No newline at end of file + return labels, new_labels From 579940cd1551e83ef61ff1f6368e95ce02be5fed Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 14 Feb 2024 10:05:11 +0100 Subject: [PATCH 074/192] Noise levels passed as arguments --- src/spikeinterface/sorters/internal/spyking_circus2.py | 3 ++- src/spikeinterface/sortingcomponents/clustering/circus.py | 5 ++++- .../sortingcomponents/clustering/random_projections.py | 5 ++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a205c49bc0..9419a7c77e 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -111,7 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f.annotate(is_filtered=True) recording_f = zscore(recording_f, dtype="float32") - noise_levels = get_noise_levels(recording_f) + noise_levels = np.ones(recording_f.get_num_channels(), dtype=np.float32) if recording_f.check_serializability("json"): recording_f.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -160,6 +160,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs + clustering_params["noise_levels"] = noise_levels clustering_params["tmp_folder"] = sorter_output_folder / "clustering" if "legacy" in clustering_params: diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 2b6fa9346e..1f4c6ccd36 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -58,6 +58,7 @@ class CircusClustering: "ms_before": 0.5, "ms_after": 0.5, "random_seed": 42, + "noise_levels":None, "debug": False, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, @@ -192,7 +193,9 @@ def main_function(cls, recording, peaks, params): templates = Templates( templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe() ) - sparsity = compute_sparsity(templates, get_noise_levels(recording), **params["sparsity"]) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording) + sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) cleaning_matching_params = params["job_kwargs"].copy() diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index e73c7ccb65..9dc3c88ad1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -55,6 +55,7 @@ class RandomProjectionClustering: "ms_before": 0.5, "ms_after": 0.5, "random_seed": 42, + "noise_levels" : None, "smoothing_kwargs": {"window_length_ms": 0.25}, "debug": False, "tmp_folder": None, @@ -158,7 +159,9 @@ def main_function(cls, recording, peaks, params): templates = Templates( templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe() ) - sparsity = compute_sparsity(templates, get_noise_levels(recording), **params["sparsity"]) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording) + sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity, remove_empty=True) cleaning_matching_params = params["job_kwargs"].copy() From f4819be33961be0895efe6e715b9acc105ea5c99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Feb 2024 09:07:56 +0000 Subject: [PATCH 075/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 1f4c6ccd36..9bdb3b0585 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -58,7 +58,7 @@ class CircusClustering: "ms_before": 0.5, "ms_after": 0.5, "random_seed": 42, - "noise_levels":None, + "noise_levels": None, "debug": False, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 9dc3c88ad1..6b08cf82f7 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -55,7 +55,7 @@ class RandomProjectionClustering: "ms_before": 0.5, "ms_after": 0.5, "random_seed": 42, - "noise_levels" : None, + "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, "debug": False, "tmp_folder": None, From df628473d9ae16abb4358a04791b108948d6dc1c Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 14 Feb 2024 10:25:51 +0100 Subject: [PATCH 076/192] WIP --- .../clustering/clustering_tools.py | 39 +++++++------------ .../sortingcomponents/matching/circus.py | 20 +++++----- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index e620f009ef..2edb1b95ba 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -540,8 +540,7 @@ def remove_duplicates_via_matching( peak_labels, method_kwargs={}, job_kwargs={}, - tmp_folder=None, - method="circus-omp-svd", + tmp_folder=None ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import BinaryRecordingExtractor @@ -602,30 +601,20 @@ def remove_duplicates_via_matching( sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) local_params.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs + sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs + ) + local_params.update( + { + "overlaps": computed["overlaps"], + "normed_templates": computed["normed_templates"], + "norms": computed["norms"], + "temporal": computed["temporal"], + "spatial": computed["spatial"], + "singular": computed["singular"], + "units_overlaps": computed["units_overlaps"], + "unit_overlaps_indices": computed["unit_overlaps_indices"], + } ) - if method == "circus-omp-svd": - local_params.update( - { - "overlaps": computed["overlaps"], - "normed_templates": computed["normed_templates"], - "norms": computed["norms"], - "temporal": computed["temporal"], - "spatial": computed["spatial"], - "singular": computed["singular"], - "units_overlaps": computed["units_overlaps"], - "unit_overlaps_indices": computed["unit_overlaps_indices"], - } - ) - elif method == "circus-omp": - local_params.update( - { - "overlaps": computed["overlaps"], - "circus_templates": computed["normed_templates"], - "norms": computed["norms"], - "sparsities": computed["sparsities"], - } - ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) if np.sum(valid) > 0: if np.sum(valid) == 1: diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 41f92f78c2..3b0135322c 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -497,7 +497,6 @@ class CircusPeeler(BaseTemplateMatchingEngine): "min_amplitude": 0.5, "use_sparse_matrix_threshold": 0.25, "templates": None, - "sparse_kwargs": {"method": "ptp", "threshold": 1}, } @classmethod @@ -520,7 +519,6 @@ def _prepare_templates(cls, d): for count, unit_id in enumerate(all_units): (d["sparsities"][count],) = np.nonzero(sparsity[count]) - templates_array[count][:, ~sparsity[count]] = 0 d["norms"][count] = np.linalg.norm(templates_array[count]) templates_array[count] /= d["norms"][count] d["normed_templates"][count] = templates_array[count][:, sparsity[count]] @@ -612,15 +610,19 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] - if not "circus_templates" in d: + if "overlaps" not in d: d = cls._prepare_templates(d) - - d["overlaps"] = compute_overlaps( - d["normed_templates"], - d["num_samples"], - d["num_channels"], - d["sparsities"], + d["overlaps"] = compute_overlaps( + d["normed_templates"], + d["num_samples"], + d["num_channels"], + d["sparsities"], ) + else: + for key in [ + "circus_templates", "norms" + ]: + assert d[key] is not None, "If templates are provided, %d should also be there" % key d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) From fa4b53d8ecc92ef784f11e70b4dcb8466e9ca220 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Feb 2024 09:28:27 +0000 Subject: [PATCH 077/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/clustering_tools.py | 8 +------- src/spikeinterface/sortingcomponents/matching/circus.py | 6 ++---- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 2edb1b95ba..17b8dea89a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -535,13 +535,7 @@ def remove_duplicates( return labels, new_labels -def remove_duplicates_via_matching( - templates, - peak_labels, - method_kwargs={}, - job_kwargs={}, - tmp_folder=None -): +def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core import NumpySorting diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 3b0135322c..596ad84e64 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -617,11 +617,9 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["num_samples"], d["num_channels"], d["sparsities"], - ) + ) else: - for key in [ - "circus_templates", "norms" - ]: + for key in ["circus_templates", "norms"]: assert d[key] is not None, "If templates are provided, %d should also be there" % key d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) From b44f48dfcd137ebced6d0038cdc9a8f0d3232c13 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 10:40:59 +0100 Subject: [PATCH 078/192] Fix principal component get_projection and pca_metrics. --- src/spikeinterface/core/job_tools.py | 2 +- src/spikeinterface/core/sortingresult.py | 29 +- .../core/tests/test_sortingresult.py | 6 + src/spikeinterface/core/waveform_tools.py | 2 +- .../postprocessing/amplitude_scalings.py | 6 +- .../postprocessing/principal_component.py | 353 ++++++++++-------- .../tests/common_extension_tests.py | 13 +- .../tests/test_principal_component.py | 129 ++++--- .../qualitymetrics/pca_metrics.py | 45 +-- .../quality_metric_calculator.py | 2 - .../qualitymetrics/tests/test_pca_metrics.py | 13 +- .../tests/test_quality_metric_calculator.py | 24 +- 12 files changed, 341 insertions(+), 283 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 0eec3e6b85..917aaae48c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -243,7 +243,7 @@ class ChunkRecordingExecutor: * in parallel with ProcessPoolExecutor (higher speed) The initializer ("init_func") allows to set a global context to avoid heavy serialization - (for examples, see implementation in `core.WaveformExtractor`). + (for examples, see implementation in `core.waveform_tools`). Parameters ---------- diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index a95bec94a5..a2a1553181 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -30,6 +30,9 @@ from .node_pipeline import run_node_pipeline +# TODO make some_spikes a method of SortingResult + + # high level function def start_sorting_result( sorting, recording, format="memory", folder=None, sparse=True, sparsity=None, **sparsity_kwargs @@ -581,10 +584,30 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> else: raise ValueError("SortingResult.save: wrong format") + + # propagate random_spikes_indices is already done + if self.random_spikes_indices is not None: + if unit_ids is None: + new_sortres.random_spikes_indices = self.random_spikes_indices.copy() + else: + # more tricky + spikes = self.sorting.to_spike_vector() + + keep_unit_indices = np.flatnonzero(np.isin(self.unit_ids, unit_ids)) + keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + + selected_mask = np.zeros(spikes.size, dtype=bool) + selected_mask[self.random_spikes_indices] = True + + new_sortres.random_spikes_indices = np.flatnonzero(selected_mask[keep_spike_mask]) + + # save it + new_sortres._save_random_spikes_indices() + # make a copy of extensions - # note that the copy of extension handle itself the slicing of units when necessary + # note that the copy of extension handle itself the slicing of units when necessary and also the saveing for extension_name, extension in self.extensions.items(): - new_sortres.extensions[extension_name] = extension.copy(new_sortres, unit_ids=unit_ids) + new_ext = new_sortres.extensions[extension_name] = extension.copy(new_sortres, unit_ids=unit_ids) return new_sortres @@ -1031,7 +1054,9 @@ def select_random_spikes(self, **random_kwargs): self.random_spikes_indices = random_spikes_selection( self.sorting, self.rec_attributes["num_samples"], **random_kwargs ) + self._save_random_spikes_indices() + def _save_random_spikes_indices(self): if self.format == "binary_folder": np.save(self.folder / "random_spikes_indices.npy", self.random_spikes_indices) elif self.format == "zarr": diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index c7ef3e3776..0fa5427dd0 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -137,6 +137,12 @@ def _check_sorting_results(sortres, original_sorting): sortres.compute("dummy") keep_unit_ids = original_sorting.unit_ids[::2] sortres2 = sortres.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) + + # check that random_spikes_indices are remmaped + assert sortres2.random_spikes_indices is not None + some_spikes = sortres2.sorting.to_spike_vector()[sortres2.random_spikes_indices] + assert np.array_equal(np.unique(some_spikes["unit_index"]), np.arange(keep_unit_ids.size)) + # check propagation of result data and correct sligin assert np.array_equal(keep_unit_ids, sortres2.unit_ids) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 04813921a6..b50c991744 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -1,7 +1,7 @@ """ This module contains low-level functions to extract snippets of traces (aka "spike waveforms"). -This is internally used by WaveformExtractor, but can also be used as a sorting component. +This is internally used by SortingResult, but can also be used as a sorting component. It is a 2-step approach: 1. allocate buffers (shared file or memory) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 6553115c43..511fd10636 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -21,7 +21,7 @@ class ComputeAmplitudeScalings(ResultExtension): """ - Computes the amplitude scalings from a WaveformExtractor. + Computes the amplitude scalings from a SortingResult. Parameters ---------- @@ -35,10 +35,10 @@ class ComputeAmplitudeScalings(ResultExtension): dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. ms_before : float or None, default: None The cut out to apply before the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_before is used. + If None, the SortingResult ms_before is used. ms_after : float or None, default: None The cut out to apply after the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_after is used. + If None, the SortingResult ms_after is used. handle_collisions: bool, default: True Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 3b3f949b93..cb58b732ed 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -13,14 +13,9 @@ from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs -# from spikeinterface.core.globals import get_global_tmp_folder - _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"] -# TODO handle extra sparsity - - class ComputePrincipalComponents(ResultExtension): """ Compute PC scores from waveform extractor. The PCA projections are pre-computed only @@ -51,7 +46,9 @@ class ComputePrincipalComponents(ResultExtension): >>> sorting_result.compute("principal_components", n_components=3, mode='by_channel_local') >>> ext_pca = sorting_result.get_extension("principal_components") >>> # get pre-computed projections for unit_id=1 - >>> projections = ext_pca.get_projections(unit_id=1) + >>> unit_projections = ext_pca.get_projections_one_unit(unit_id=1, sparse=False) + >>> # get pre-computed projections for some units on some channels + >>> some_projections, spike_unit_indices = ext_pca.get_some_projections(channel_ids=None, unit_ids=None) >>> # retrieve fitted pca model(s) >>> pca_model = ext_pca.get_pca_model() >>> # compute projections on new waveforms @@ -77,22 +74,15 @@ def _set_params( mode="by_channel_local", whiten=True, dtype="float32", - sparsity=None, ): assert mode in _possible_modes, "Invalid mode!" - if sparsity is not None: - # TODO alessio: implement local sparsity or not ?? - raise NotImplementedError - # the sparsity in params is ONLY the injected sparsity and not the sorting_result one params = dict( n_components=n_components, mode=mode, whiten=whiten, dtype=np.dtype(dtype), - # sparsity=sparsity, - # tmp_folder=tmp_folder, ) return params @@ -111,7 +101,27 @@ def _select_extension_data(self, unit_ids): new_data[k] = v return new_data - def get_projections(self, unit_id, sparse=False): + + def get_pca_model(self): + """ + Returns the scikit-learn PCA model objects. + + Returns + ------- + pca_models: PCA object(s) + * if mode is "by_channel_local", "pca_model" is a list of PCA model by channel + * if mode is "by_channel_global" or "concatenated", "pca_model" is a single PCA model + """ + mode = self.params["mode"] + if mode == "by_channel_local": + pca_models = [] + for chan_id in self.sorting_result.channel_ids: + pca_models.append(self.data[f"pca_model_{mode}_{chan_id}"]) + else: + pca_models = self.data[f"pca_model_{mode}"] + return pca_models + + def get_projections_one_unit(self, unit_id, sparse=False): """ Returns the computed projections for the sampled waveforms of a unit id. @@ -120,84 +130,114 @@ def get_projections(self, unit_id, sparse=False): unit_id : int or str The unit id to return PCA projections for sparse: bool, default: False - If True, and sparsity is not None, only projections on sparse channels are returned. + If True, and SortingResult must be sparse then only projections on sparse channels are returned. + Channel indices are also returned. Returns ------- projections: np.array The PCA projections (num_waveforms, num_components, num_channels). In case sparsity is used, only the projections on sparse channels are returned. + channel_indices: np.array + """ - projections = self.data[f"pca_{unit_id}"] - mode = self.params["mode"] - if mode in ("by_channel_local", "by_channel_global") and sparse: - sparsity = self.get_sparsity() - if sparsity is not None: - projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] - return projections + sparsity = self.sorting_result.sparsity + sorting = self.sorting_result.sorting + + if sparse: + assert self.params["mode"] != "concatenated", "mode concatenated cannot retrieve sparse projection" + assert sparsity is not None, "sparse projection need SortingResult to be sparse" - def get_pca_model(self): + spikes = sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + + unit_index = sorting.id_to_index(unit_id) + spike_mask = some_spikes["unit_index"] == unit_index + projections = self.data["pca_projection"][spike_mask] + + if sparsity is None: + return projections + else: + channel_indices = sparsity.unit_id_to_channel_indices[unit_id] + projections = projections[:, :, :channel_indices.size] + if sparse: + return projections, channel_indices + else: + num_chans = self.sorting_result.get_num_channels() + projections_ = np.zeros((projections.shape[0], projections.shape[1], num_chans), dtype=projections.dtype) + projections_[:, :, channel_indices] = projections + return projections_ + + def get_some_projections(self, channel_ids=None, unit_ids=None): """ - Returns the scikit-learn PCA model objects. + Returns the computed projections for the sampled waveforms of some units and some channels. + + When internally sparse, this function realign projection on given channel_ids set. + + Parameters + ---------- + channel_ids : list, default: None + List of channel ids on which projections must aligned + unit_ids : list, default: None + List of unit ids to return projections for Returns ------- - pca_models: PCA object(s) - * if mode is "by_channel_local", "pca_model" is a list of PCA model by channel - * if mode is "by_channel_global" or "concatenated", "pca_model" is a single PCA model + some_projections: np.array + The PCA projections (num_spikes, num_components, num_sparse_channels) + spike_unit_indices: np.array + Array a copy of with some_spikes["unit_index"] of returned PCA projections of shape (num_spikes, ) """ - mode = self.params["mode"] - if mode == "by_channel_local": - pca_models = [] - for chan_id in self.sorting_result.channel_ids: - pca_models.append(self.data[f"pca_model_{mode}_{chan_id}"]) + sorting = self.sorting_result.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + + if channel_ids is None: + channel_ids = self.sorting_result.channel_ids + + channel_indices = self.sorting_result.channel_ids_to_indices(channel_ids) + + # note : internally when sparse PCA are not aligned!! Exactly like waveforms. + all_projections = self.data["pca_projection"] + num_components = all_projections.shape[1] + dtype = all_projections.dtype + + sparsity = self.sorting_result.sparsity + + spikes = sorting.to_spike_vector() + some_spikes = spikes[self.sorting_result.random_spikes_indices] + + unit_indices = sorting.ids_to_indices(unit_ids) + selected_inds = np.flatnonzero(np.isin(some_spikes["unit_index"], unit_indices)) + + print(selected_inds.size, unit_indices, some_spikes["unit_index"].size) + print(np.min(selected_inds), np.max(selected_inds)) + spike_unit_indices = some_spikes["unit_index"][selected_inds] + + if sparsity is None: + some_projections = all_projections[selected_inds, :, :][:, :, channel_indices] else: - pca_models = self.data[f"pca_model_{mode}"] - return pca_models + # need re-alignement + some_projections = np.zeros((selected_inds.size, num_components, channel_indices.size), dtype=dtype) + + for unit_id in unit_ids: + unit_index = sorting.id_to_index(unit_id) + sparse_projection, local_chan_inds = self.get_projections_one_unit(unit_id, sparse=True) + + # keep only requested channels + channel_mask = np.isin(local_chan_inds, channel_indices) + sparse_projection = sparse_projection[:, :, channel_mask] + local_chan_inds = local_chan_inds[channel_mask] + + spike_mask = np.flatnonzero(spike_unit_indices == unit_index) + proj = np.zeros((spike_mask.size, num_components, channel_indices.size), dtype=dtype) + # inject in requested channels + channel_mask = np.isin(channel_indices, local_chan_inds) + proj[:, :, channel_mask] = sparse_projection + some_projections[spike_mask, :, :] = proj + + return some_projections, spike_unit_indices - # def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): - # """ - # Returns the computed projections for the sampled waveforms of all units. - - # Parameters - # ---------- - # channel_ids : list, default: None - # List of channel ids on which projections are computed - # unit_ids : list, default: None - # List of unit ids to return projections for - # outputs: str - # * "id": "all_labels" contain unit ids - # * "index": "all_labels" contain unit indices - - # Returns - # ------- - # all_labels: np.array - # Array with labels (ids or indices based on "outputs") of returned PCA projections - # all_projections: np.array - # The PCA projections (num_all_waveforms, num_components, num_channels) - # """ - # if unit_ids is None: - # unit_ids = self.sorting_result.sorting.unit_ids - - # all_labels = [] #  can be unit_id or unit_index - # all_projections = [] - # for unit_index, unit_id in enumerate(unit_ids): - # proj = self.get_projections(unit_id, sparse=False) - # if channel_ids is not None: - # chan_inds = self.sorting_result.chanpca_projectionnel_ids_to_indices(channel_ids) - # proj = proj[:, :, chan_inds] - # n = proj.shape[0] - # if outputs == "id": - # labels = np.array([unit_id] * n) - # elif outputs == "index": - # labels = np.ones(n, dtype="int64") - # labels[:] = unit_index - # all_labels.append(labels) - # all_projections.append(proj) - # all_labels = np.concatenate(all_labels, axis=0) - # all_projections = np.concatenate(all_projections, axis=0) - - # return all_labels, all_projections def project_new(self, new_spikes, new_waveforms, progress_bar=True): """ @@ -220,11 +260,6 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True): new_projections = self._transform_waveforms(new_spikes, new_waveforms, pca_model, progress_bar=progress_bar) return new_projections - def get_sparsity(self): - if self.sorting_result.is_sparse(): - return self.sorting_result.sparsity - return self.params["sparsity"] - def _run(self, **job_kwargs): """ Compute the PCs on waveforms extacted within the by ComputeWaveforms. @@ -265,76 +300,74 @@ def _run(self, **job_kwargs): def _get_data(self): return self.data["pca_projection"] - # @staticmethod - # def get_extension_function(): - # return compute_principal_components - - # def run_for_all_spikes(self, file_path=None, **job_kwargs): - # """ - # Project all spikes from the sorting on the PCA model. - # This is a long computation because waveform need to be extracted from each spikes. - - # Used mainly for `export_to_phy()` - - # PCs are exported to a .npy single file. - - # Parameters - # ---------- - # file_path : str or Path or None - # Path to npy file that will store the PCA projections. - # If None, output is saved in principal_components/all_pcs.npy - # {} - # """ - - # job_kwargs = fix_job_kwargs(job_kwargs) - # p = self.params - # we = self.sorting_result - # sorting = we.sorting - # assert ( - # we.has_recording() - # ), "To compute PCA projections for all spikes, the waveform extractor needs the recording" - # recording = we.recording - - # assert sorting.get_num_segments() == 1 - # assert p["mode"] in ("by_channel_local", "by_channel_global") - - # if file_path is None: - # file_path = self.extension_folder / "all_pcs.npy" - # file_path = Path(file_path) - - # sparsity = self.get_sparsity() - # if sparsity is None: - # sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} - # max_channels_per_template = we.get_num_channels() - # else: - # sparse_channels_indices = sparsity.unit_id_to_channel_indices - # max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) - - # unit_channels = [sparse_channels_indices[unit_id] for unit_id in sorting.unit_ids] - - # pca_model = self.get_pca_model() - # if p["mode"] in ["by_channel_global", "concatenated"]: - # pca_model = [pca_model] * recording.get_num_channels() - - # num_spikes = sorting.to_spike_vector().size - # shape = (num_spikes, p["n_components"], max_channels_per_template) - # all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) - # all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) - - # # and run - # func = _all_pc_extractor_chunk - # init_func = _init_work_all_pc_extractor - # init_args = ( - # recording, - # sorting.to_multiprocessing(job_kwargs["n_jobs"]), - # all_pcs_args, - # we.nbefore, - # we.nafter, - # unit_channels, - # pca_model, - # ) - # processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs) - # processor.run() + def run_for_all_spikes(self, file_path=None, **job_kwargs): + """ + Project all spikes from the sorting on the PCA model. + This is a long computation because waveform need to be extracted from each spikes. + + Used mainly for `export_to_phy()` + + PCs are exported to a .npy single file. + + Parameters + ---------- + file_path : str or Path or None + Path to npy file that will store the PCA projections. + {} + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + p = self.params + we = self.sorting_result + sorting = we.sorting + assert ( + we.has_recording() + ), "To compute PCA projections for all spikes, the waveform extractor needs the recording" + recording = we.recording + + # assert sorting.get_num_segments() == 1 + assert p["mode"] in ("by_channel_local", "by_channel_global") + + assert file_path is not None + file_path = Path(file_path) + + sparsity = self.sorting_result.sparsity + if sparsity is None: + sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} + max_channels_per_template = we.get_num_channels() + else: + sparse_channels_indices = sparsity.unit_id_to_channel_indices + max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) + + unit_channels = [sparse_channels_indices[unit_id] for unit_id in sorting.unit_ids] + + pca_model = self.get_pca_model() + if p["mode"] in ["by_channel_global", "concatenated"]: + pca_model = [pca_model] * recording.get_num_channels() + + num_spikes = sorting.to_spike_vector().size + shape = (num_spikes, p["n_components"], max_channels_per_template) + all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) + all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) + + + waveforms_ext = self.sorting_result.get_extension("waveforms") + + + # and run + func = _all_pc_extractor_chunk + init_func = _init_work_all_pc_extractor + init_args = ( + recording, + sorting.to_multiprocessing(job_kwargs["n_jobs"]), + all_pcs_args, + waveforms_ext.nbefore, + waveforms_ext.nafter, + unit_channels, + pca_model, + ) + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs) + processor.run() def _fit_by_channel_local(self, n_jobs, progress_bar): from sklearn.decomposition import IncrementalPCA @@ -368,7 +401,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): n_jobs = min(n_jobs, len(items)) with ProcessPoolExecutor(max_workers=n_jobs) as executor: - results = executor.map(partial_fit_one_channel, items) + results = executor.map(_partial_fit_one_channel, items) for chan_ind, pca_model_updated in results: pca_models[chan_ind] = pca_model_updated @@ -593,24 +626,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte return worker_ctx -# WaveformPrincipalComponent.run_for_all_spikes.__doc__ = WaveformPrincipalComponent.run_for_all_spikes.__doc__.format( -# _shared_job_kwargs_doc -# ) - register_result_extension(ComputePrincipalComponents) compute_principal_components = ComputePrincipalComponents.function_factory() -# def partial_fit_one_channel(args): -# pca_file, wf_chan = args -# with open(pca_file, "rb") as fid: -# pca_model = pickle.load(fid) -# pca_model.partial_fit(wf_chan) -# with pca_file.open("wb") as f: -# pickle.dump(pca_model, f) - -def partial_fit_one_channel(args): +def _partial_fit_one_channel(args): chan_ind, pca_model, wf_chan = args pca_model.partial_fit(wf_chan) return chan_ind, pca_model diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 8078b031e3..6736a06ccf 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -17,6 +17,7 @@ else: cache_folder = Path("cache_folder") / "postprocessing" +cache_folder.mkdir(exist_ok=True, parents=True) def get_dataset(): recording, sorting = generate_ground_truth_recording( @@ -74,16 +75,8 @@ class ResultExtensionCommonTestSuite: @classmethod def setUpClass(cls): cls.recording, cls.sorting = get_dataset() - # sparsity is computed once for all cases to save processing time - cls.sparsity = estimate_sparsity(cls.recording, cls.sorting) - - # def tearDown(self): - # for k in list(self.sorting_results.keys()): - # sorting_result = self.sorting_results.pop(k) - # if sorting_result.format != "memory": - # folder = sorting_result.folder - # del sorting_result - # shutil.rmtree(folder) + # sparsity is computed once for all cases to save processing time and force a small radius + cls.sparsity = estimate_sparsity(cls.recording, cls.sorting, method="radius", radius_um=20) @property def extension_name(self): diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index fa6a0bfd9b..857cc340bc 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -5,18 +5,9 @@ import numpy as np from spikeinterface.postprocessing import ComputePrincipalComponents, compute_principal_components -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite, cache_folder -# from spikeinterface import compute_sparsity -# from spikeinterface.postprocessing import WaveformPrincipalComponent, compute_principal_components -# from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite - -# if hasattr(pytest, "global_test_folder"): -# cache_folder = pytest.global_test_folder / "postprocessing" -# else: -# cache_folder = Path("cache_folder") / "postprocessing" - DEBUG = False @@ -38,43 +29,87 @@ def test_mode_concatenated(self): n_components = 3 sorting_result.compute("principal_components", mode="concatenated", n_components=n_components) - ext = sorting_result.get_extension(self.extension_name) + ext = sorting_result.get_extension("principal_components") assert ext is not None assert len(ext.data) > 0 pca = ext.data["pca_projection"] assert pca.ndim == 2 assert pca.shape[1] == n_components - - # def test_compute_for_all_spikes(self): - # sorting_result = self._prepare_sorting_result(format="memory", sparse=False) - - # n_components = 3 - # sorting_result.compute("principal_components", mode="by_channel_local", n_components=n_components) - # ext = sorting_result.get_extension(self.extension_name) - # ext.run_for_all_spikes() - - # pc_file1 = pc.extension_folder / "all_pc1.npy" - # pc.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) - # all_pc1 = np.load(pc_file1) - - # pc_file2 = pc.extension_folder / "all_pc2.npy" - # pc.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) - # all_pc2 = np.load(pc_file2) - - # assert np.array_equal(all_pc1, all_pc2) - - # # test with sparsity - # sparsity = compute_sparsity(we, method="radius", radius_um=50) - # we_copy = we.save(folder=cache_folder / "we_copy") - # pc_sparse = self.extension_class.get_extension_function()(we_copy, sparsity=sparsity, load_if_exists=False) - # pc_file_sparse = pc.extension_folder / "all_pc_sparse.npy" - # pc_sparse.run_for_all_spikes(pc_file_sparse, chunk_size=10000, n_jobs=1) - # all_pc_sparse = np.load(pc_file_sparse) - # all_spikes_seg0 = we_copy.sorting.to_spike_vector(concatenated=False)[0] - # for unit_index, unit_id in enumerate(we.unit_ids): - # sparse_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] - # pc_unit = all_pc_sparse[all_spikes_seg0["unit_index"] == unit_index] - # assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) + + def test_get_projections(self): + + for sparse in (False, True): + + sorting_result = self._prepare_sorting_result(format="memory", sparse=sparse) + num_chans = sorting_result.get_num_channels() + n_components = 2 + + sorting_result.compute("principal_components", mode="by_channel_global", n_components=n_components) + ext = sorting_result.get_extension("principal_components") + + for unit_id in sorting_result.unit_ids: + if not sparse: + one_proj = ext.get_projections_one_unit(unit_id, sparse=False) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] == num_chans + else: + one_proj = ext.get_projections_one_unit(unit_id, sparse=False) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] == num_chans + + one_proj, chan_inds = ext.get_projections_one_unit(unit_id, sparse=True) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] < num_chans + assert one_proj.shape[2] == chan_inds.size + + some_unit_ids = sorting_result.unit_ids[::2] + some_channel_ids = sorting_result.channel_ids[::2] + + + # this should be all spikes all channels + some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] == sorting_result.random_spikes_indices.size + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == num_chans + + # this should be some spikes all channels + some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] < sorting_result.random_spikes_indices.size + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == num_chans + assert 1 not in spike_unit_index + + # this should be some spikes some channels + some_projections, spike_unit_index = ext.get_some_projections(channel_ids=some_channel_ids, unit_ids=some_unit_ids) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] < sorting_result.random_spikes_indices.size + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == some_channel_ids.size + assert 1 not in spike_unit_index + + def test_compute_for_all_spikes(self): + + for sparse in (True, False): + sorting_result = self._prepare_sorting_result(format="memory", sparse=sparse) + + num_spikes = sorting_result.sorting.to_spike_vector().size + + n_components = 3 + sorting_result.compute("principal_components", mode="by_channel_local", n_components=n_components) + ext = sorting_result.get_extension("principal_components") + + pc_file1 = cache_folder / "all_pc1.npy" + ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) + all_pc1 = np.load(pc_file1) + assert all_pc1.shape[0] == num_spikes + + pc_file2 = cache_folder / "all_pc2.npy" + ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) + all_pc2 = np.load(pc_file2) + + assert np.array_equal(all_pc1, all_pc2) def test_project_new(self): from sklearn.decomposition import IncrementalPCA @@ -100,10 +135,12 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUpClass() - test.test_extension() - test.test_mode_concatenated() - # test.test_compute_for_all_spikes() - test.test_project_new() + # test.test_extension() + # test.test_mode_concatenated() + # test.test_get_projections() + test.test_compute_for_all_spikes() + # test.test_project_new() + # ext = test.sorting_results["sparseTrue_memory"].get_extension("principal_components") # pca = ext.data["pca_projection"] diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 34a522273c..d4d45cfd20 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -121,40 +121,29 @@ def calculate_pc_metrics( run_in_parallel = n_jobs > 1 - units_loop = enumerate(unit_ids) - if progress_bar and not run_in_parallel: - units_loop = tqdm(units_loop, desc="Computing PCA metrics", total=len(unit_ids)) + if run_in_parallel: parallel_functions = [] - # all_labels, all_pcs = pca.get_all_projections() - # TODO: this is wring all_pcs used to be dense even when the waveform extractor was sparse - all_pcs = pca_ext.data["pca_projection"] - spikes = sorting.to_spike_vector() - some_spikes = spikes[sorting_result.random_spikes_indices] - all_labels = sorting.unit_ids[some_spikes["unit_index"]] + # this get dense projection for selected unit_ids + dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids) + all_labels = sorting.unit_ids[spike_unit_indices] items = [] for unit_id in unit_ids: - print(sorting_result.is_sparse()) if sorting_result.is_sparse(): neighbor_channel_ids = sorting_result.sparsity.unit_id_to_channel_ids[unit_id] neighbor_unit_ids = [ other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids ] - # elif sparsity is not None: - # neighbor_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] - # neighbor_unit_ids = [ - # other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids - # ] else: neighbor_channel_ids = channel_ids neighbor_unit_ids = unit_ids neighbor_channel_indices = sorting_result.channel_ids_to_indices(neighbor_channel_ids) labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -165,28 +154,30 @@ def calculate_pc_metrics( unit_ids, qm_params, seed, - # we.folder, n_spikes_all_units, fr_all_units, ) items.append(func_args) if not run_in_parallel: + units_loop = enumerate(unit_ids) + if progress_bar: + units_loop = tqdm(units_loop, desc="calculate_pc_metrics", total=len(unit_ids)) + for unit_ind, unit_id in units_loop: pca_metrics_unit = pca_metrics_one_unit(items[unit_ind]) for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric else: - raise NotImplementedError - # with ProcessPoolExecutor(n_jobs) as executor: - # results = executor.map(pca_metrics_one_unit, items) - # if progress_bar: - # results = tqdm(results, total=len(unit_ids)) - - # for ui, pca_metrics_unit in enumerate(results): - # unit_id = unit_ids[ui] - # for metric_name, metric in pca_metrics_unit.items(): - # pc_metrics[metric_name][unit_id] = metric + with ProcessPoolExecutor(n_jobs) as executor: + results = executor.map(pca_metrics_one_unit, items) + if progress_bar: + results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics") + + for ui, pca_metrics_unit in enumerate(results): + unit_id = unit_ids[ui] + for metric_name, metric in pca_metrics_unit.items(): + pc_metrics[metric_name][unit_id] = metric return pc_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 59cd8b49d4..0a7f9559e2 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -16,8 +16,6 @@ from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params -# TODO : - class ComputeQualityMetrics(ResultExtension): """ diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 2741d78ea7..c1cc5524f8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -56,8 +56,17 @@ def sorting_result_simple(): def test_calculate_pc_metrics(sorting_result_simple): sorting_result = sorting_result_simple - res = calculate_pc_metrics(sorting_result) - print(pd.DataFrame(res)) + res1 = calculate_pc_metrics(sorting_result, n_jobs=1, progress_bar=True) + res1 = pd.DataFrame(res1) + + res2 = calculate_pc_metrics(sorting_result, n_jobs=2, progress_bar=True) + res2 = pd.DataFrame(res2) + + for k in res1.columns: + mask = ~np.isnan(res1[k].values) + if np.any(mask): + assert np.array_equal(res1[k].values[mask], res2[k].values[mask]) + def test_nearest_neighbors_isolation(sorting_result_simple): diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 6e88375f3f..51d768e7b5 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -17,29 +17,6 @@ compute_quality_metrics, ) -# generate_ground_truth_recording -# WaveformExtractor, -# NumpySorting, -# compute_sparsity, -# load_extractor, -# extract_waveforms, -# split_recording, -# select_segment_sorting, -# load_waveforms, -# aggregate_units, -# ) -# from spikeinterface.extractors import toy_example - -# from spikeinterface.postprocessing import ( -# compute_principal_components, -# compute_spike_amplitudes, -# compute_spike_locations, -# compute_noise_levels, -# ) -# from spikeinterface.preprocessing import scale -# from spikeinterface.qualitymetrics import QualityMetricCalculator, get_default_qm_params - -# from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite if hasattr(pytest, "global_test_folder"): @@ -137,6 +114,7 @@ def test_compute_quality_metrics_recordingless(sorting_result_simple): # make a copy and make it recordingless sorting_result_norec = sorting_result.save_as(format="memory") + sorting_result_norec.delete_extension("quality_metrics") sorting_result_norec._recording = None assert not sorting_result_norec.has_recording() From 884005c6ba0c0b8476067d95f599c2b710b1892e Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 14 Feb 2024 10:47:35 +0100 Subject: [PATCH 079/192] Cleaning params --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9419a7c77e..15c7d12975 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -40,7 +40,6 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "clustering": {"legacy": False}, "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, - "shared_memory": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.8}, @@ -158,7 +157,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): for k in ["ms_before", "ms_after"]: clustering_params["waveforms"][k] = params["general"][k] - clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs clustering_params["noise_levels"] = noise_levels clustering_params["tmp_folder"] = sorter_output_folder / "clustering" From 7bc93c5697780118b6e5851907307c32537fd460 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 11:13:19 +0100 Subject: [PATCH 080/192] export to phy working --- src/spikeinterface/exporters/tests/common.py | 10 +- .../exporters/tests/test_export_to_phy.py | 106 +++++------------- src/spikeinterface/exporters/to_phy.py | 22 ++-- 3 files changed, 46 insertions(+), 92 deletions(-) diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index e179171ca3..a3a431384c 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -3,7 +3,7 @@ import pytest from pathlib import Path -from spikeinterface.core import generate_ground_truth_recording, start_sorting_result +from spikeinterface.core import generate_ground_truth_recording, start_sorting_result, compute_sparsity if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "exporters" @@ -33,7 +33,13 @@ def make_sorting_result(sparse=True, with_group=False): recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1]) sorting.set_property("group", [0, 0, 1, 1]) - sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_result_unused = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=None) + sparsity_group = compute_sparsity(sorting_result_unused, method="by_property", by_property="group") + + sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=sparsity_group) + else: + sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_result.select_random_spikes() sorting_result.compute("waveforms") sorting_result.compute("templates") diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index a7d05335a7..c0d5c0587d 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -19,7 +19,28 @@ ) -def test_export_to_phy(sorting_result_sparse_for_export): + + +def test_export_to_phy_dense(sorting_result_dense_for_export): + output_folder1 = cache_folder / "phy_output_dense" + for f in (output_folder1,): + if f.is_dir(): + shutil.rmtree(f) + + sorting_result = sorting_result_dense_for_export + + export_to_phy( + sorting_result, + output_folder1, + compute_pc_features=True, + compute_amplitudes=True, + n_jobs=1, + chunk_size=10000, + progress_bar=True, + ) + + +def test_export_to_phy_sparse(sorting_result_sparse_for_export): output_folder1 = cache_folder / "phy_output_1" output_folder2 = cache_folder / "phy_output_2" for f in (output_folder1, output_folder2): @@ -52,22 +73,20 @@ def test_export_to_phy(sorting_result_sparse_for_export): def test_export_to_phy_by_property(sorting_result_with_group_for_export): - output_folder = cache_folder / "phy_output" - output_folder_rm = cache_folder / "phy_output_rm" + output_folder = cache_folder / "phy_output_property" - for f in (output_folder, output_folder_rm): + for f in (output_folder, ): if f.is_dir(): shutil.rmtree(f) sorting_result = sorting_result_with_group_for_export + print(sorting_result.sparsity) - sparsity_group = compute_sparsity(sorting_result, method="by_property", by_property="group") export_to_phy( sorting_result, output_folder, compute_pc_features=True, compute_amplitudes=True, - sparsity=sparsity_group, n_jobs=1, chunk_size=10000, progress_bar=True, @@ -76,82 +95,13 @@ def test_export_to_phy_by_property(sorting_result_with_group_for_export): template_inds = np.load(output_folder / "template_ind.npy") assert template_inds.shape == (sorting_result.unit_ids.size, 4) - # Remove one channel - # recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - # sorting_result_rm = start_sorting_result(sorting, recording_rm, , sparse=False) - # sparsity_group = compute_sparsity(sorting_result_rm, method="by_property", by_property="group") - - # export_to_phy( - # sorting_result_rm, - # output_folder_rm, - # compute_pc_features=True, - # compute_amplitudes=True, - # sparsity=sparsity_group, - # n_jobs=1, - # chunk_size=10000, - # progress_bar=True, - # ) - - # template_inds = np.load(output_folder_rm / "template_ind.npy") - # assert template_inds.shape == (num_units, 4) - # assert len(np.where(template_inds == -1)[0]) > 0 - - -def test_export_to_phy_by_sparsity(sorting_result_dense_for_export): - output_folder_radius = cache_folder / "phy_output_radius" - output_folder_multi_sparse = cache_folder / "phy_output_multi_sparse" - for f in (output_folder_radius, output_folder_multi_sparse): - if f.is_dir(): - shutil.rmtree(f) - - sorting_result = sorting_result_dense_for_export - - sparsity_radius = compute_sparsity(sorting_result, method="radius", radius_um=50.0) - export_to_phy( - sorting_result, - output_folder_radius, - compute_pc_features=True, - compute_amplitudes=True, - sparsity=sparsity_radius, - n_jobs=1, - chunk_size=10000, - progress_bar=True, - ) - - template_ind = np.load(output_folder_radius / "template_ind.npy") - pc_ind = np.load(output_folder_radius / "pc_feature_ind.npy") - # templates have different shapes! - assert -1 in template_ind - assert -1 in pc_ind - - # pre-compute PC with another sparsity - sparsity_radius_small = compute_sparsity(sorting_result, method="radius", radius_um=30.0) - pc = compute_principal_components(sorting_result, sparsity=sparsity_radius_small) - export_to_phy( - sorting_result, - output_folder_multi_sparse, - compute_pc_features=True, - compute_amplitudes=True, - sparsity=sparsity_radius, - n_jobs=1, - chunk_size=10000, - progress_bar=True, - ) - - template_ind = np.load(output_folder_multi_sparse / "template_ind.npy") - pc_ind = np.load(output_folder_multi_sparse / "pc_feature_ind.npy") - # templates have different shapes! - assert -1 in template_ind - assert -1 in pc_ind - # PC sparsity is more stringent than teplate sparsity - assert pc_ind.shape[1] < template_ind.shape[1] - if __name__ == "__main__": sorting_result_sparse = make_sorting_result(sparse=True) sorting_result_group = make_sorting_result(sparse=False, with_group=True) sorting_result_dense = make_sorting_result(sparse=False) - test_export_to_phy(sorting_result_sparse) + test_export_to_phy_dense(sorting_result_dense) + test_export_to_phy_sparse(sorting_result_sparse) test_export_to_phy_by_property(sorting_result_group) - test_export_to_phy_by_sparsity(sorting_result_dense) + diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index af6bd69c17..4e179e2b4e 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -224,19 +224,17 @@ def export_to_phy( if compute_pc_features: if not sorting_result.has_extension("principal_components"): sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + + pca_extension = sorting_result.get_extension("principal_components") - # pc_sparsity = pc.get_sparsity() - # if pc_sparsity is None: - # pc_sparsity = used_sparsity - # max_num_channels_pc = max(len(chan_inds) for chan_inds in pc_sparsity.unit_id_to_channel_indices.values()) - raise NotImplementedError() - # pc.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs) - - # pc_feature_ind = -np.ones((len(unit_ids), max_num_channels_pc), dtype="int64") - # for unit_ind, unit_id in enumerate(unit_ids): - # chan_inds = pc_sparsity.unit_id_to_channel_indices[unit_id] - # pc_feature_ind[unit_ind, : len(chan_inds)] = chan_inds - # np.save(str(output_folder / "pc_feature_ind.npy"), pc_feature_ind) + pca_extension.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs) + + max_num_channels_pc = max(len(chan_inds) for chan_inds in used_sparsity.unit_id_to_channel_indices.values()) + pc_feature_ind = -np.ones((len(unit_ids), max_num_channels_pc), dtype="int64") + for unit_ind, unit_id in enumerate(unit_ids): + chan_inds = used_sparsity.unit_id_to_channel_indices[unit_id] + pc_feature_ind[unit_ind, : len(chan_inds)] = chan_inds + np.save(str(output_folder / "pc_feature_ind.npy"), pc_feature_ind) # Save .tsv metadata cluster_group = pd.DataFrame( From d4eab0f4897e5fbe40cb5aa41972aefce8e35041 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 11:27:11 +0100 Subject: [PATCH 081/192] More clean --- src/spikeinterface/core/sortingresult.py | 5 +- .../postprocessing/amplitude_scalings.py | 74 +------------------ .../postprocessing/correlograms.py | 44 ----------- src/spikeinterface/postprocessing/isi.py | 41 ---------- .../postprocessing/principal_component.py | 2 +- .../postprocessing/template_similarity.py | 60 --------------- 6 files changed, 6 insertions(+), 220 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index a2a1553181..75a7d00dad 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -971,11 +971,12 @@ def get_extension(self, extension_name: str): if extension_name in self.extensions: return self.extensions[extension_name] - if self.has_extension(extension_name): + elif self.format != "memory" and self.has_extension(extension_name): self.load_extension(extension_name) return self.extensions[extension_name] - return None + else: + return None def load_extension(self, extension_name: str): """ diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 511fd10636..44a8b8d241 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -217,76 +217,6 @@ def _get_data(self): register_result_extension(ComputeAmplitudeScalings) compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory() -# def compute_amplitude_scalings( -# waveform_extractor, -# sparsity=None, -# max_dense_channels=16, -# ms_before=None, -# ms_after=None, -# handle_collisions=True, -# delta_collision_ms=2, -# load_if_exists=False, -# outputs="concatenated", -# **job_kwargs, -# ): -# """ -# Computes the amplitude scalings from a WaveformExtractor. - -# Parameters -# ---------- -# waveform_extractor: WaveformExtractor -# The waveform extractor object -# sparsity: ChannelSparsity or None, default: None -# If waveforms are not sparse, sparsity is required if the number of channels is greater than -# `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. -# max_dense_channels: int, default: 16 -# Maximum number of channels to allow running without sparsity. To compute amplitude scaling using -# dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. -# ms_before : float or None, default: None -# The cut out to apply before the spike peak to extract local waveforms. -# If None, the WaveformExtractor ms_before is used. -# ms_after : float or None, default: None -# The cut out to apply after the spike peak to extract local waveforms. -# If None, the WaveformExtractor ms_after is used. -# handle_collisions: bool, default: True -# Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes -# (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a -# multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. -# delta_collision_ms: float, default: 2 -# The maximum time difference in ms before and after a spike to gather colliding spikes. -# load_if_exists : bool, default: False -# Whether to load precomputed spike amplitudes, if they already exist. -# outputs: "concatenated" | "by_unit", default: "concatenated" -# How the output should be returned -# {} - -# Returns -# ------- -# amplitude_scalings: np.array or list of dict -# The amplitude scalings. -# - If "concatenated" all amplitudes for all spikes and all units are concatenated -# - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) -# """ -# if load_if_exists and waveform_extractor.is_extension(AmplitudeScalingsCalculator.extension_name): -# sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) -# else: -# sac = AmplitudeScalingsCalculator(waveform_extractor) -# sac.set_params( -# sparsity=sparsity, -# max_dense_channels=max_dense_channels, -# ms_before=ms_before, -# ms_after=ms_after, -# handle_collisions=handle_collisions, -# delta_collision_ms=delta_collision_ms, -# ) -# sac.run(**job_kwargs) - -# amps = sac.get_data(outputs=outputs) -# return amps - - -# compute_amplitude_scalings.__doc__.format(_shared_job_kwargs_doc) - class AmplitudeScalingNode(PipelineNode): def __init__( @@ -618,8 +548,8 @@ def fit_collision( # Parameters # ---------- -# we : WaveformExtractor -# The WaveformExtractor object. +# we : SortingResult +# The SortingResult object. # sparsity : ChannelSparsity, default=None # The ChannelSparsity. If None, only main channels are plotted. # num_collisions : int, default=None diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 260a2693e8..5be43ac05a 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -167,50 +167,6 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ return _compute_crosscorr_numba(spike_times1.astype(np.int64), spike_times2.astype(np.int64), window_size, bin_size) -# def compute_correlograms( -# waveform_or_sorting_extractor, -# load_if_exists=False, -# window_ms: float = 50.0, -# bin_ms: float = 1.0, -# method: str = "auto", -# ): -# """Compute auto and cross correlograms. - -# Parameters -# ---------- -# waveform_or_sorting_extractor : WaveformExtractor or BaseSorting -# If WaveformExtractor, the correlograms are saved as WaveformExtensions -# load_if_exists : bool, default: False -# Whether to load precomputed crosscorrelograms, if they already exist -# window_ms : float, default: 100.0 -# The window in ms -# bin_ms : float, default: 5 -# The bin size in ms -# method : "auto" | "numpy" | "numba", default: "auto" -# If "auto" and numba is installed, numba is used, otherwise numpy is used - -# Returns -# ------- -# ccgs : np.array -# Correlograms with shape (num_units, num_units, num_bins) -# The diagonal of ccgs is the auto correlogram. -# ccgs[A, B, :] is the symetrie of ccgs[B, A, :] -# ccgs[A, B, :] have to be read as the histogram of spiketimesA - spiketimesB -# bins : np.array -# The bin edges in ms -# """ -# if isinstance(waveform_or_sorting_extractor, WaveformExtractor): -# if load_if_exists and waveform_or_sorting_extractor.is_extension(CorrelogramsCalculator.extension_name): -# ccc = waveform_or_sorting_extractor.load_extension(CorrelogramsCalculator.extension_name) -# else: -# ccc = CorrelogramsCalculator(waveform_or_sorting_extractor) -# ccc.set_params(window_ms=window_ms, bin_ms=bin_ms, method=method) -# ccc.run() -# ccgs, bins = ccc.get_data() -# return ccgs, bins -# else: -# return compute_correlograms_on_sorting(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) - def compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): """ diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index dbcdfc268b..fffe793655 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -69,47 +69,6 @@ def _get_data(self): compute_isi_histograms = ComputeISIHistograms.function_factory() -# def compute_isi_histograms( -# waveform_or_sorting_extractor, -# load_if_exists=False, -# window_ms: float = 50.0, -# bin_ms: float = 1.0, -# method: str = "auto", -# ): -# """Compute ISI histograms. - -# Parameters -# ---------- -# waveform_or_sorting_extractor : WaveformExtractor or BaseSorting -# If WaveformExtractor, the ISI histograms are saved as WaveformExtensions -# load_if_exists : bool, default: False -# Whether to load precomputed crosscorrelograms, if they already exist -# window_ms : float, default: 50 -# The window in ms -# bin_ms : float, default: 1 -# The bin size in ms -# method : "auto" | "numpy" | "numba", default: "auto" -# . If "auto" and numba is installed, numba is used, otherwise numpy is used - -# Returns -# ------- -# isi_histograms : np.array -# IDI_histograms with shape (num_units, num_bins) -# bins : np.array -# The bin edges in ms -# """ -# if isinstance(waveform_or_sorting_extractor, WaveformExtractor): -# if load_if_exists and waveform_or_sorting_extractor.is_extension(ISIHistogramsCalculator.extension_name): -# isic = waveform_or_sorting_extractor.load_extension(ISIHistogramsCalculator.extension_name) -# else: -# isic = ISIHistogramsCalculator(waveform_or_sorting_extractor) -# isic.set_params(window_ms=window_ms, bin_ms=bin_ms, method=method) -# isic.run() -# isi_histograms, bins = isic.get_data() -# return isi_histograms, bins -# else: -# return _compute_isi_histograms(waveform_or_sorting_extractor, window_ms=window_ms, bin_ms=bin_ms, method=method) - def _compute_isi_histograms(sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): """ diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index cb58b732ed..1e8278c31d 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -19,7 +19,7 @@ class ComputePrincipalComponents(ResultExtension): """ Compute PC scores from waveform extractor. The PCA projections are pre-computed only - on the sampled waveforms available from the WaveformExtractor. + on the sampled waveforms available from the extensions "waveforms". Parameters ---------- diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 2c4334dc15..9bd28d5080 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -82,66 +82,6 @@ def compute_template_similarity_by_pair(sorting_result_1, sorting_result_2, meth return similmarity -# def _compute_template_similarity( -# waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None -# ): -# import sklearn.metrics.pairwise - -# templates = waveform_extractor.get_all_templates() -# s = templates.shape -# if method == "cosine_similarity": -# templates_flat = templates.reshape(s[0], -1) -# if waveform_extractor_other is not None: -# templates_other = waveform_extractor_other.get_all_templates() -# s_other = templates_other.shape -# templates_other_flat = templates_other.reshape(s_other[0], -1) -# assert len(templates_flat[0]) == len(templates_other_flat[0]), ( -# "Templates from second WaveformExtractor " "don't have the correct shape!" -# ) -# else: -# templates_other_flat = None -# similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, templates_other_flat) -# # elif method == '': -# else: -# raise ValueError(f"compute_template_similarity(method {method}) not exists") - -# return similarity - - -# def compute_template_similarity( -# waveform_extractor, load_if_exists=False, method="cosine_similarity", waveform_extractor_other=None -# ): -# """Compute similarity between templates with several methods. - -# Parameters -# ---------- -# waveform_extractor: WaveformExtractor -# A waveform extractor object -# load_if_exists : bool, default: False -# Whether to load precomputed similarity, if is already exists. -# method: str, default: "cosine_similarity" -# The method to compute the similarity -# waveform_extractor_other: WaveformExtractor, default: None -# A second waveform extractor object - -# Returns -# ------- -# similarity: np.array -# The similarity matrix -# """ -# if waveform_extractor_other is None: -# if load_if_exists and waveform_extractor.is_extension(TemplateSimilarityCalculator.extension_name): -# tmc = waveform_extractor.load_extension(TemplateSimilarityCalculator.extension_name) -# else: -# tmc = TemplateSimilarityCalculator(waveform_extractor) -# tmc.set_params(method=method) -# tmc.run() -# similarity = tmc.get_data() -# return similarity -# else: -# return _compute_template_similarity(waveform_extractor, waveform_extractor_other, method) - - def check_equal_template_with_distribution_overlap( waveforms0, waveforms1, template0=None, template1=None, num_shift=2, quantile_limit=0.8, return_shift=False ): From 9032acec78e283cfb515f699ac111a6676f14938 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 14 Feb 2024 12:08:03 +0100 Subject: [PATCH 082/192] Revert patch --- src/spikeinterface/core/job_tools.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 7525805994..6ce098d8f9 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -62,10 +62,9 @@ def fix_job_kwargs(runtime_job_kwargs): ) # remove None runtime_job_kwargs_exclude_none = runtime_job_kwargs.copy() - # Whe should remove these lines, otherwise, we can not reset values for total_memory/chunk_size/... on the fly - # for job_key, job_value in runtime_job_kwargs.items(): - # if job_value is None: - # del runtime_job_kwargs_exclude_none[job_key] + for job_key, job_value in runtime_job_kwargs.items(): + if job_value is None: + del runtime_job_kwargs_exclude_none[job_key] job_kwargs.update(runtime_job_kwargs_exclude_none) # if n_jobs is -1, set to os.cpu_count() (n_jobs is always in global job_kwargs) From d9934bbacd68bd67812417f7c4e8e5c92101f187 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:10:40 +0000 Subject: [PATCH 083/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/job_tools.py | 4 ++-- src/spikeinterface/core/sortingresult.py | 3 +-- .../core/tests/test_sortingresult.py | 3 +-- ...veforms_extractor_backwards_compatibility.py | 1 - ...veforms_extractor_backwards_compatibility.py | 13 ++++++------- src/spikeinterface/exporters/tests/common.py | 8 ++++++-- .../exporters/tests/test_export_to_phy.py | 5 +---- .../exporters/tests/test_report.py | 6 +----- src/spikeinterface/exporters/to_phy.py | 2 +- .../postprocessing/principal_component.py | 17 +++++++---------- .../tests/common_extension_tests.py | 1 + .../tests/test_principal_component.py | 13 ++++++------- .../qualitymetrics/pca_metrics.py | 2 -- .../qualitymetrics/tests/test_pca_metrics.py | 1 - .../tests/test_quality_metric_calculator.py | 1 - .../widgets/all_amplitudes_distributions.py | 2 +- src/spikeinterface/widgets/spike_locations.py | 4 ++-- .../widgets/template_similarity.py | 4 ++-- .../widgets/tests/test_widgets.py | 2 +- src/spikeinterface/widgets/unit_depths.py | 2 +- src/spikeinterface/widgets/unit_locations.py | 2 +- src/spikeinterface/widgets/unit_probe_map.py | 2 +- src/spikeinterface/widgets/unit_summary.py | 2 +- src/spikeinterface/widgets/unit_waveforms.py | 9 ++++++--- .../widgets/unit_waveforms_density_map.py | 2 +- 25 files changed, 50 insertions(+), 61 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 6ce098d8f9..917aaae48c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -63,8 +63,8 @@ def fix_job_kwargs(runtime_job_kwargs): # remove None runtime_job_kwargs_exclude_none = runtime_job_kwargs.copy() for job_key, job_value in runtime_job_kwargs.items(): - if job_value is None: - del runtime_job_kwargs_exclude_none[job_key] + if job_value is None: + del runtime_job_kwargs_exclude_none[job_key] job_kwargs.update(runtime_job_kwargs_exclude_none) # if n_jobs is -1, set to os.cpu_count() (n_jobs is always in global job_kwargs) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index a2a1553181..d331f7e12b 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -584,11 +584,10 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> else: raise ValueError("SortingResult.save: wrong format") - # propagate random_spikes_indices is already done if self.random_spikes_indices is not None: if unit_ids is None: - new_sortres.random_spikes_indices = self.random_spikes_indices.copy() + new_sortres.random_spikes_indices = self.random_spikes_indices.copy() else: # more tricky spikes = self.sorting.to_spike_vector() diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index 0fa5427dd0..a3c204364d 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -137,13 +137,12 @@ def _check_sorting_results(sortres, original_sorting): sortres.compute("dummy") keep_unit_ids = original_sorting.unit_ids[::2] sortres2 = sortres.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) - + # check that random_spikes_indices are remmaped assert sortres2.random_spikes_indices is not None some_spikes = sortres2.sorting.to_spike_vector()[sortres2.random_spikes_indices] assert np.array_equal(np.unique(some_spikes["unit_index"]), np.arange(keep_unit_ids.size)) - # check propagation of result data and correct sligin assert np.array_equal(keep_unit_ids, sortres2.unit_ids) data = sortres2.get_extension("dummy").data diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index 646789423f..538a96343b 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -104,7 +104,6 @@ def test_read_old_waveforms_extractor_binary(): print(data.shape) - if __name__ == "__main__": # test_extract_waveforms() test_read_old_waveforms_extractor_binary() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index abf9edc86a..412028c94f 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -461,24 +461,21 @@ def _read_old_waveforms_extractor_binary(folder): if len(templates) > 0: ext = ComputeTemplates(sorting_result) ext.params = dict( - nbefore=nbefore, - nafter=nafter, - return_scaled=params["return_scaled"], - operators=list(templates.keys()) + nbefore=nbefore, nafter=nafter, return_scaled=params["return_scaled"], operators=list(templates.keys()) ) for mode, arr in templates.items(): ext.data[mode] = arr sorting_result.extensions["templates"] = ext - # old extensions with same names and equvalent data except similarity>template_similarity + # old extensions with same names and equvalent data except similarity>template_similarity old_extension_to_new_class = { "spike_amplitudes": "spike_amplitudes", "spike_locations": "spike_locations", "amplitude_scalings": "amplitude_scalings", - "template_metrics" : "template_metrics", + "template_metrics": "template_metrics", "similarity": "template_similarity", "unit_locations": "unit_locations", - "correlograms" : "correlograms", + "correlograms": "correlograms", "isi_histograms": "isi_histograms", "noise_levels": "noise_levels", "quality_metrics": "quality_metrics", @@ -505,6 +502,7 @@ def _read_old_waveforms_extractor_binary(folder): ext.data["amplitude_scalings"] = np.load(ext_folder / "amplitude_scalings.npy") elif new_name == "template_metrics": import pandas as pd + ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) elif new_name == "template_similarity": ext.data["similarity"] = np.load(ext_folder / "similarity.npy") @@ -520,6 +518,7 @@ def _read_old_waveforms_extractor_binary(folder): ext.data["noise_levels"] = np.load(ext_folder / "noise_levels.npy") elif new_name == "quality_metrics": import pandas as pd + ext.data["metrics"] = pd.read_csv(ext_folder / "metrics.csv", index_col=0) # elif new_name == "principal_components": # # TODO: alessio this is for you diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index a3a431384c..981fc1c465 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -33,10 +33,14 @@ def make_sorting_result(sparse=True, with_group=False): recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1]) sorting.set_property("group", [0, 0, 1, 1]) - sorting_result_unused = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=None) + sorting_result_unused = start_sorting_result( + sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=None + ) sparsity_group = compute_sparsity(sorting_result_unused, method="by_property", by_property="group") - sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=sparsity_group) + sorting_result = start_sorting_result( + sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=sparsity_group + ) else: sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index c0d5c0587d..3394a9a6d1 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -19,8 +19,6 @@ ) - - def test_export_to_phy_dense(sorting_result_dense_for_export): output_folder1 = cache_folder / "phy_output_dense" for f in (output_folder1,): @@ -75,7 +73,7 @@ def test_export_to_phy_sparse(sorting_result_sparse_for_export): def test_export_to_phy_by_property(sorting_result_with_group_for_export): output_folder = cache_folder / "phy_output_property" - for f in (output_folder, ): + for f in (output_folder,): if f.is_dir(): shutil.rmtree(f) @@ -104,4 +102,3 @@ def test_export_to_phy_by_property(sorting_result_with_group_for_export): test_export_to_phy_dense(sorting_result_dense) test_export_to_phy_sparse(sorting_result_sparse) test_export_to_phy_by_property(sorting_result_group) - diff --git a/src/spikeinterface/exporters/tests/test_report.py b/src/spikeinterface/exporters/tests/test_report.py index cc3a0b2a64..8ed6feb1df 100644 --- a/src/spikeinterface/exporters/tests/test_report.py +++ b/src/spikeinterface/exporters/tests/test_report.py @@ -5,11 +5,7 @@ from spikeinterface.exporters import export_report -from spikeinterface.exporters.tests.common import ( - cache_folder, - make_sorting_result, - sorting_result_sparse_for_export -) +from spikeinterface.exporters.tests.common import cache_folder, make_sorting_result, sorting_result_sparse_for_export def test_export_report(sorting_result_sparse_for_export): diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 4e179e2b4e..4d66c2769a 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -224,7 +224,7 @@ def export_to_phy( if compute_pc_features: if not sorting_result.has_extension("principal_components"): sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - + pca_extension = sorting_result.get_extension("principal_components") pca_extension.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index cb58b732ed..dc647b5b49 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -101,7 +101,6 @@ def _select_extension_data(self, unit_ids): new_data[k] = v return new_data - def get_pca_model(self): """ Returns the scikit-learn PCA model objects. @@ -143,7 +142,7 @@ def get_projections_one_unit(self, unit_id, sparse=False): """ sparsity = self.sorting_result.sparsity sorting = self.sorting_result.sorting - + if sparse: assert self.params["mode"] != "concatenated", "mode concatenated cannot retrieve sparse projection" assert sparsity is not None, "sparse projection need SortingResult to be sparse" @@ -159,12 +158,14 @@ def get_projections_one_unit(self, unit_id, sparse=False): return projections else: channel_indices = sparsity.unit_id_to_channel_indices[unit_id] - projections = projections[:, :, :channel_indices.size] + projections = projections[:, :, : channel_indices.size] if sparse: return projections, channel_indices else: num_chans = self.sorting_result.get_num_channels() - projections_ = np.zeros((projections.shape[0], projections.shape[1], num_chans), dtype=projections.dtype) + projections_ = np.zeros( + (projections.shape[0], projections.shape[1], num_chans), dtype=projections.dtype + ) projections_[:, :, channel_indices] = projections return projections_ @@ -191,7 +192,7 @@ def get_some_projections(self, channel_ids=None, unit_ids=None): sorting = self.sorting_result.sorting if unit_ids is None: unit_ids = sorting.unit_ids - + if channel_ids is None: channel_ids = self.sorting_result.channel_ids @@ -235,9 +236,8 @@ def get_some_projections(self, channel_ids=None, unit_ids=None): channel_mask = np.isin(channel_indices, local_chan_inds) proj[:, :, channel_mask] = sparse_projection some_projections[spike_mask, :, :] = proj - - return some_projections, spike_unit_indices + return some_projections, spike_unit_indices def project_new(self, new_spikes, new_waveforms, progress_bar=True): """ @@ -350,10 +350,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) - waveforms_ext = self.sorting_result.get_extension("waveforms") - # and run func = _all_pc_extractor_chunk init_func = _init_work_all_pc_extractor @@ -630,7 +628,6 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte compute_principal_components = ComputePrincipalComponents.function_factory() - def _partial_fit_one_channel(args): chan_ind, pca_model, wf_chan = args pca_model.partial_fit(wf_chan) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 6736a06ccf..214d0c3f16 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -19,6 +19,7 @@ cache_folder.mkdir(exist_ok=True, parents=True) + def get_dataset(): recording, sorting = generate_ground_truth_recording( durations=[15.0, 5.0], diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 857cc340bc..4205358420 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -8,7 +8,6 @@ from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite, cache_folder - DEBUG = False @@ -35,11 +34,11 @@ def test_mode_concatenated(self): pca = ext.data["pca_projection"] assert pca.ndim == 2 assert pca.shape[1] == n_components - + def test_get_projections(self): for sparse in (False, True): - + sorting_result = self._prepare_sorting_result(format="memory", sparse=sparse) num_chans = sorting_result.get_num_channels() n_components = 2 @@ -61,10 +60,9 @@ def test_get_projections(self): assert one_proj.shape[1] == n_components assert one_proj.shape[2] < num_chans assert one_proj.shape[2] == chan_inds.size - + some_unit_ids = sorting_result.unit_ids[::2] some_channel_ids = sorting_result.channel_ids[::2] - # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) @@ -82,7 +80,9 @@ def test_get_projections(self): assert 1 not in spike_unit_index # this should be some spikes some channels - some_projections, spike_unit_index = ext.get_some_projections(channel_ids=some_channel_ids, unit_ids=some_unit_ids) + some_projections, spike_unit_index = ext.get_some_projections( + channel_ids=some_channel_ids, unit_ids=some_unit_ids + ) assert some_projections.shape[0] == spike_unit_index.shape[0] assert spike_unit_index.shape[0] < sorting_result.random_spikes_indices.size assert some_projections.shape[1] == n_components @@ -140,7 +140,6 @@ def test_project_new(self): # test.test_get_projections() test.test_compute_for_all_spikes() # test.test_project_new() - # ext = test.sorting_results["sparseTrue_memory"].get_extension("principal_components") # pca = ext.data["pca_projection"] diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index d4d45cfd20..f6ac46d24c 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -121,8 +121,6 @@ def calculate_pc_metrics( run_in_parallel = n_jobs > 1 - - if run_in_parallel: parallel_functions = [] diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index c1cc5524f8..29b334e97f 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -68,7 +68,6 @@ def test_calculate_pc_metrics(sorting_result_simple): assert np.array_equal(res1[k].values[mask], res2[k].values[mask]) - def test_nearest_neighbors_isolation(sorting_result_simple): sorting_result = sorting_result_simple this_unit_id = sorting_result.unit_ids[0] diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 51d768e7b5..79bfc4ee50 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -18,7 +18,6 @@ ) - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "qualitymetrics" else: diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 4ba9661d66..f865542018 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -24,7 +24,7 @@ class AllAmplitudesDistributionsWidget(BaseWidget): """ def __init__(self, sorting_result: SortingResult, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs): - + sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "spike_amplitudes") diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 9427c62f46..b1791f0912 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -52,10 +52,10 @@ def __init__( hide_axis=False, backend=None, **backend_kwargs, - ): + ): sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "spike_locations") - + spike_locations_by_units = sorting_result.get_extension("spike_locations").get_data(outputs="by_unit") sorting = sorting_result.sorting diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 6800a55b51..39883094cf 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -37,10 +37,10 @@ def __init__( show_colorbar=True, backend=None, **backend_kwargs, - ): + ): sorting_result = self.ensure_sorting_result(sorting_result) self.check_extensions(sorting_result, "template_similarity") - + tsc = sorting_result.get_extension("template_similarity") similarity = tsc.get_data().copy() diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 877383cd4b..1cd1ba477f 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -96,7 +96,7 @@ def setUpClass(cls): cls.sorting_result_sparse.compute(extensions_to_compute, **job_kwargs) cls.skip_backends = ["ipywidgets", "ephyviewer"] - # cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] + # cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] if ON_GITHUB and not KACHERY_CLOUD_SET: cls.skip_backends.append("sortingview") diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index cfc141f396..b99bb2a274 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -29,7 +29,7 @@ class UnitDepthsWidget(BaseWidget): def __init__(self, sorting_result, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs): sorting_result = self.ensure_sorting_result(sorting_result) - + unit_ids = sorting_result.sorting.unit_ids if unit_colors is None: diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 7fc635e419..ec5660fdcc 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -49,7 +49,7 @@ def __init__( **backend_kwargs, ): sorting_result = self.ensure_sorting_result(sorting_result) - + self.check_extensions(sorting_result, "unit_locations") ulc = sorting_result.get_extension("unit_locations") unit_locations = ulc.get_data(outputs="by_unit") diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 14da93079d..895ef6709c 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -43,7 +43,7 @@ def __init__( **backend_kwargs, ): sorting_result = self.ensure_sorting_result(sorting_result) - + if unit_ids is None: unit_ids = sorting_result.unit_ids self.unit_ids = unit_ids diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 27026f645e..09d0dfa2c9 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -46,7 +46,7 @@ def __init__( ): sorting_result = self.ensure_sorting_result(sorting_result) - + if unit_colors is None: unit_colors = get_unit_colors(sorting_result.sorting) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index ca98888428..65ce40edf0 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -454,19 +454,22 @@ def _update_plot(self, change): templates_ext = self.sorting_result.get_extension("templates") templates = templates_ext.get_templates(unit_ids=unit_ids, operator="average") - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids data_plot["templates"] = templates - templates_shadings = self._get_template_shadings(self.sorting_result, unit_ids, data_plot["templates_percentile_shading"]) + templates_shadings = self._get_template_shadings( + self.sorting_result, unit_ids, data_plot["templates_percentile_shading"] + ) data_plot["templates_shading"] = templates_shadings data_plot["same_axis"] = same_axis data_plot["plot_templates"] = plot_templates data_plot["do_shading"] = do_shading data_plot["scale"] = self.scaler.value if data_plot["plot_waveforms"]: - data_plot["wfs_by_ids"] = {unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids} + data_plot["wfs_by_ids"] = { + unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids + } # TODO option for plot_legend diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 659ba33f80..ce0053e9af 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -49,7 +49,7 @@ def __init__( **backend_kwargs, ): sorting_result = self.ensure_sorting_result(sorting_result) - + if channel_ids is None: channel_ids = sorting_result.channel_ids From 629fe330c0460eb52bf7d9bf3a23223c83c4d9cb Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 14 Feb 2024 13:14:19 +0100 Subject: [PATCH 084/192] DO NOT MERGE --- src/spikeinterface/core/template.py | 47 ++++++------------- .../sorters/internal/spyking_circus2.py | 6 ++- .../sortingcomponents/clustering/circus.py | 17 ++++--- .../clustering/random_projections.py | 19 ++++---- .../sortingcomponents/matching/tdc.py | 19 ++++---- src/spikeinterface/sortingcomponents/tools.py | 21 +++++++++ 6 files changed, 69 insertions(+), 60 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 4b789b184b..c4302bf0b1 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -108,39 +108,25 @@ def __post_init__(self): if not self._are_passed_templates_sparse(): raise ValueError("Sparsity mask passed but the templates are not sparse") - def to_sparse(self, sparsity, remove_empty=True): + def to_sparse(self, sparsity): # Turn a dense representation of templates into a sparse one, given some sparsity. # Templates that are empty after sparsification can be removed via the remove_empty flag assert isinstance(sparsity, ChannelSparsity), "sparsity should be of type ChannelSparsity" assert self.sparsity_mask is None, "Templates should be dense" - if not remove_empty: - return Templates( - templates_array=sparsity.sparsify_templates(self.templates_array), - sampling_frequency=self.sampling_frequency, - nbefore=self.nbefore, - sparsity_mask=sparsity.mask, - channel_ids=self.channel_ids, - unit_ids=self.unit_ids, - probe=self.probe, - check_for_consistent_sparsity=self.check_for_consistent_sparsity, - ) - - else: - templates_array = sparsity.sparsify_templates(self.templates_array) - norms = np.linalg.norm(templates_array, axis=(1, 2)) - not_empty = norms > 0 - new_sparsity = ChannelSparsity(sparsity.mask[not_empty], sparsity.unit_ids[not_empty], sparsity.channel_ids) - return Templates( - templates_array=new_sparsity.sparsify_templates(self.templates_array[not_empty]), - sampling_frequency=self.sampling_frequency, - nbefore=self.nbefore, - sparsity_mask=new_sparsity.mask, - channel_ids=self.channel_ids, - unit_ids=self.unit_ids[not_empty], - probe=self.probe, - check_for_consistent_sparsity=self.check_for_consistent_sparsity, - ) + # if np.any(sparsity.mask.sum(axis=1) == 0): + # print('Warning: some templates are defined on 0 channels. Consider removing them') + + return Templates( + templates_array=sparsity.sparsify_templates(self.templates_array), + sampling_frequency=self.sampling_frequency, + nbefore=self.nbefore, + sparsity_mask=sparsity.mask, + channel_ids=self.channel_ids, + unit_ids=self.unit_ids, + probe=self.probe, + check_for_consistent_sparsity=self.check_for_consistent_sparsity, + ) def get_one_template_dense(self, unit_index): if self.sparsity is None: @@ -381,8 +367,3 @@ def get_channel_locations(self): assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" channel_locations = self.probe.contact_positions return channel_locations - - -def get_norms_from_templates(templates): - assert isinstance(templates, Templates) - return np.linalg.norm(templates.get_dense_templates(), axis=(1, 2)) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 15c7d12975..52ed56b52e 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -14,6 +14,7 @@ from spikeinterface.sortingcomponents.tools import cache_preprocessing from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity +from spikeinterface.sortingcomponents.tools import remove_empty_templates try: import hdbscan @@ -28,7 +29,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 1}, + "sparsity": {"method": "ptp", "threshold": 5}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { @@ -216,7 +217,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"]) - templates = templates.to_sparse(sparsity, remove_empty=True) + templates = templates.to_sparse(sparsity) + templates = remove_empty_templates(templates) if params["debug"]: templates.to_zarr(folder_path=clustering_folder / "templates") diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 9bdb3b0585..65323dfb86 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -28,6 +28,7 @@ from sklearn.decomposition import TruncatedSVD from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity +from spikeinterface.sortingcomponents.tools import remove_empty_templates import pickle, json from spikeinterface.core.node_pipeline import ( run_node_pipeline, @@ -173,13 +174,11 @@ def main_function(cls, recording, peaks, params): labels = np.unique(peak_labels) labels = labels[labels >= 0] - spikes = np.zeros(len(peaks), dtype=minimum_spike_dtype) - spikes["sample_index"] = peaks["sample_index"] - spikes["segment_index"] = peaks["segment_index"] - spikes["unit_index"] = peak_labels - - if verbose: - print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + spikes = np.zeros(np.sum(peak_labels > -1), dtype=minimum_spike_dtype) + mask = peak_labels > -1 + spikes["sample_index"] = peaks[mask]["sample_index"] + spikes["segment_index"] = peaks[mask]["segment_index"] + spikes["unit_index"] = peak_labels[mask] unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) @@ -197,6 +196,10 @@ def main_function(cls, recording, peaks, params): params["noise_levels"] = get_noise_levels(recording) sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) + templates = remove_empty_templates(templates) + + if verbose: + print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 6b08cf82f7..b864b7dbf0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -27,6 +27,7 @@ from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity +from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.core.node_pipeline import ( run_node_pipeline, ExtractDenseWaveforms, @@ -139,13 +140,11 @@ def main_function(cls, recording, peaks, params): labels = np.unique(peak_labels) labels = labels[labels >= 0] - spikes = np.zeros(len(peaks), dtype=minimum_spike_dtype) - spikes["sample_index"] = peaks["sample_index"] - spikes["segment_index"] = peaks["segment_index"] - spikes["unit_index"] = peak_labels - - if verbose: - print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + spikes = np.zeros(np.sum(peak_labels > -1), dtype=minimum_spike_dtype) + mask = peak_labels > -1 + spikes["sample_index"] = peaks[mask]["sample_index"] + spikes["segment_index"] = peaks[mask]["segment_index"] + spikes["unit_index"] = peak_labels[mask] unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) @@ -162,7 +161,11 @@ def main_function(cls, recording, peaks, params): if params["noise_levels"] is None: params["noise_levels"] = get_noise_levels(recording) sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) - templates = templates.to_sparse(sparsity, remove_empty=True) + templates = templates.to_sparse(sparsity) + templates = remove_empty_templates(templates) + + if verbose: + print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) cleaning_matching_params = params["job_kwargs"].copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 82fa5e3224..e00fa16ccd 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -107,16 +107,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): channel_distance = get_channel_distances(recording) d["neighbours_mask"] = channel_distance < d["radius_um"] - # sparsity = compute_sparsity( - # templates, method="best_channels" - # ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) - # template_sparsity_inds = sparsity.unit_id_to_channel_indices - # template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - # for unit_index, unit_id in enumerate(unit_ids): - # chan_inds = template_sparsity_inds[unit_id] - # template_sparsity[unit_index, chan_inds] = True - - template_sparsity = templates.sparsity.mask + sparsity = compute_sparsity( + templates, method="best_channels" + ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) + template_sparsity_inds = sparsity.unit_id_to_channel_indices + template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") + for unit_index, unit_id in enumerate(unit_ids): + chan_inds = template_sparsity_inds[unit_id] + template_sparsity[unit_index, chan_inds] = True + d["template_sparsity"] = template_sparsity extremum_channel = get_template_extremum_channel(templates, peak_sign=d["peak_sign"], outputs="index") diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 794df36bf4..adf201685c 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -100,3 +100,24 @@ def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache recording = recording.save_to_zarr(**extra_kwargs) return recording + + +def remove_empty_templates(templates): + + from spikeinterface.core.sparsity import ChannelSparsity + from spikeinterface.core.template import Templates + + templates_array = templates.get_dense_templates() + not_empty = templates.sparsity_mask.sum(axis=1) > 0 + sparse = np.zeros((len(not_empty), templates.)) + new_sparsity = ChannelSparsity(, templates.unit_ids[not_empty], templates.channel_ids) + return Templates( + templates_array=new_sparsity.sparsify_templates(templates_array[not_empty]), + sampling_frequency=templates.sampling_frequency, + nbefore=templates.nbefore, + sparsity_mask=new_sparsity.mask, + channel_ids=templates.channel_ids, + unit_ids=templates.unit_ids[not_empty], + probe=templates.probe, + check_for_consistent_sparsity=templates.check_for_consistent_sparsity, + ) From 49f7c0bc93ebba068699c22f53678e5db3f22846 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Feb 2024 12:17:11 +0000 Subject: [PATCH 085/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/tdc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index e00fa16ccd..cb8dd226d6 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -108,13 +108,13 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["neighbours_mask"] = channel_distance < d["radius_um"] sparsity = compute_sparsity( - templates, method="best_channels" + templates, method="best_channels" ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) template_sparsity_inds = sparsity.unit_id_to_channel_indices template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") for unit_index, unit_id in enumerate(unit_ids): - chan_inds = template_sparsity_inds[unit_id] - template_sparsity[unit_index, chan_inds] = True + chan_inds = template_sparsity_inds[unit_id] + template_sparsity[unit_index, chan_inds] = True d["template_sparsity"] = template_sparsity From 302eabe1a4951818006d0b74e9d87900c5d4d4c5 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 14:10:10 +0100 Subject: [PATCH 086/192] MOre clean --- src/spikeinterface/core/sortingresult.py | 3 +++ src/spikeinterface/widgets/base.py | 27 +++++++++++++++------ src/spikeinterface/widgets/rasters.py | 2 ++ src/spikeinterface/widgets/unit_presence.py | 4 ++- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 75a7d00dad..da793c4707 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -153,6 +153,9 @@ class SortingResult: This internally save a copy of the Sorting and extract main recording attributes (without traces) so the SortingResult object can be reload even if references to the original sorting and/or to the original recording are lost. + + SortingResult() should not never be used directly for creating: use instead start_sorting_result(sorting, resording, ...) + or eventually SortingResult.create(...) """ def __init__( diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 26cfd0fa23..eb715b1c3a 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -5,7 +5,7 @@ global default_backend_ default_backend_ = "matplotlib" -from ..core import SortingResult +from ..core import SortingResult, BaseSorting from ..core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor @@ -106,13 +106,26 @@ def do_plot(self): func(self.data_plot, **self.backend_kwargs) @classmethod - def ensure_sorting_result(cls, sorting_result_or_waveform_extractor): - if isinstance(sorting_result_or_waveform_extractor, SortingResult): - return sorting_result_or_waveform_extractor - elif isinstance(sorting_result_or_waveform_extractor, MockWaveformExtractor): - return sorting_result_or_waveform_extractor.sorting_result + def ensure_sorting_result(cls, input): + # internal help to accept both SortingResult or MockWaveformExtractor for a ploter + if isinstance(input, SortingResult): + return input + elif isinstance(input, MockWaveformExtractor): + return input.sorting_result else: - return sorting_result_or_waveform_extractor + return input + + @classmethod + def ensure_sorting(cls, input): + # internal help to accept both Sorting or SortingResult or MockWaveformExtractor for a ploter + if isinstance(input, BaseSorting): + return input + elif isinstance(input, SortingResult): + return input.sorting + elif isinstance(input, MockWaveformExtractor): + return input.sorting_result.sorting + else: + return input @staticmethod def check_extensions(sorting_result, extensions): diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 0e8b902e03..a460a8e179 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -27,6 +27,8 @@ class RasterWidget(BaseWidget): def __init__( self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs ): + sorting = self.ensure_sorting(sorting) + if segment_index is None: if sorting.get_num_segments() != 1: raise ValueError("You must provide segment_index=...") diff --git a/src/spikeinterface/widgets/unit_presence.py b/src/spikeinterface/widgets/unit_presence.py index fa6f3c69f7..69f673b0db 100644 --- a/src/spikeinterface/widgets/unit_presence.py +++ b/src/spikeinterface/widgets/unit_presence.py @@ -32,7 +32,9 @@ def __init__( smooth_sigma=4.5, backend=None, **backend_kwargs, - ): + ): + sorting = self.ensure_sorting(sorting) + if segment_index is None: nseg = sorting.get_num_segments() if nseg != 1: From 5de4e28dc8cde96601b1f6327ed57f1cedebfd97 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 14:16:56 +0100 Subject: [PATCH 087/192] remove_empty_templates() --- src/spikeinterface/sortingcomponents/tools.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index adf201685c..3d7e40da14 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -9,6 +9,9 @@ except: HAVE_PSUTIL = False +from spikeinterface.core.sparsity import ChannelSparsity +from spikeinterface.core.template import Templates + from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer from spikeinterface.core.job_tools import split_job_kwargs @@ -103,21 +106,18 @@ def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache def remove_empty_templates(templates): - - from spikeinterface.core.sparsity import ChannelSparsity - from spikeinterface.core.template import Templates - - templates_array = templates.get_dense_templates() + """ + Clean A Template with sparse representtaion by removing units that have no channel + on the sparsity mask + """ + assert templates.sparsity_mask is not None, "Need sparse Templates object" not_empty = templates.sparsity_mask.sum(axis=1) > 0 - sparse = np.zeros((len(not_empty), templates.)) - new_sparsity = ChannelSparsity(, templates.unit_ids[not_empty], templates.channel_ids) return Templates( - templates_array=new_sparsity.sparsify_templates(templates_array[not_empty]), + templates_array=templates.templates_array[not_empty, :, :], sampling_frequency=templates.sampling_frequency, nbefore=templates.nbefore, - sparsity_mask=new_sparsity.mask, + sparsity_mask=templates.sparsity_mask[not_empty, :], channel_ids=templates.channel_ids, unit_ids=templates.unit_ids[not_empty], probe=templates.probe, - check_for_consistent_sparsity=templates.check_for_consistent_sparsity, ) From 521a0973dbc093c4f4f0ed59eb9f9c908f8b53d8 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 14 Feb 2024 14:26:17 +0100 Subject: [PATCH 088/192] Update src/spikeinterface/core/sparsity.py --- src/spikeinterface/core/sparsity.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 9de9856aa1..550aa55ebc 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -222,8 +222,7 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo def sparsify_templates(self, templates_array: np.ndarray) -> np.ndarray: max_num_active_channels = self.max_num_active_channels - num_samples = templates_array.shape[1] - sparsified_shape = (self.num_units, num_samples, max_num_active_channels) + sparsified_shape = (self.num_units, self.num_samples, max_num_active_channels) sparse_templates = np.zeros(shape=sparsified_shape, dtype=templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): template = templates_array[unit_index, ...] From 46967c1560208d4625a8183c2b6bed8b00788595 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 17:59:36 +0100 Subject: [PATCH 089/192] Fix tests for template matching --- src/spikeinterface/core/sparsity.py | 5 +- src/spikeinterface/core/template.py | 2 +- .../sortingcomponents/matching/method_list.py | 2 +- .../tests/test_template_matching.py | 132 +++++++++--------- .../test_waveform_thresholder.py | 28 ++-- 5 files changed, 84 insertions(+), 85 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 550aa55ebc..fd80fbf181 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -221,8 +221,11 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo return int(excess_zeros) == 0 def sparsify_templates(self, templates_array: np.ndarray) -> np.ndarray: + assert templates_array.shape[0] == self.num_units + assert templates_array.shape[2] == self.num_channels + max_num_active_channels = self.max_num_active_channels - sparsified_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparsified_shape = (self.num_units, templates_array.shape[1], max_num_active_channels) sparse_templates = np.zeros(shape=sparsified_shape, dtype=templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): template = templates_array[unit_index, ...] diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index c4302bf0b1..d85faa7513 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -110,7 +110,7 @@ def __post_init__(self): def to_sparse(self, sparsity): # Turn a dense representation of templates into a sparse one, given some sparsity. - # Templates that are empty after sparsification can be removed via the remove_empty flag + # Note that nothing prevent Templates tobe empty after sparsification if the sparse mask have no channels for some units assert isinstance(sparsity, ChannelSparsity), "sparsity should be of type ChannelSparsity" assert self.sparsity_mask is None, "Templates should be dense" diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index bd8dfd21bc..ca6c0db924 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -7,7 +7,7 @@ matching_methods = { "naive": NaiveMatching, - "tridesclous": TridesclousPeeler, + "tdc-peeler": TridesclousPeeler, "circus": CircusPeeler, "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 35c7617c47..e9d1017be7 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -2,96 +2,92 @@ import numpy as np from pathlib import Path -from spikeinterface import NumpySorting -from spikeinterface import extract_waveforms -from spikeinterface.core import get_noise_levels +from spikeinterface import NumpySorting, start_sorting_result, get_noise_levels, compute_sparsity from spikeinterface.sortingcomponents.matching import find_spikes_from_templates, matching_methods from spikeinterface.sortingcomponents.tests.common import make_dataset -DEBUG = False -def make_waveform_extractor(): - recording, sorting = make_dataset() - waveform_extractor = extract_waveforms( - recording=recording, - sorting=sorting, - folder=None, - mode="memory", - ms_before=1, - ms_after=2.0, - max_spikes_per_unit=500, - return_scaled=False, - n_jobs=1, - chunk_size=30000, - ) - return waveform_extractor +job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) + +def get_sorting_result(): + recording, sorting = make_dataset() + sorting_result = start_sorting_result(sorting, recording, sparse=False) + sorting_result.select_random_spikes() + sorting_result.compute("fast_templates", **job_kwargs) + sorting_result.compute("noise_levels") + return sorting_result -@pytest.fixture(name="waveform_extractor", scope="module") -def waveform_extractor_fixture(): - return make_waveform_extractor() +@pytest.fixture(name="sorting_result", scope="module") +def sorting_result_fixture(): + return get_sorting_result() @pytest.mark.parametrize("method", matching_methods.keys()) -def test_find_spikes_from_templates(method, waveform_extractor): - recording = waveform_extractor._recording - waveform = waveform_extractor.get_waveforms(waveform_extractor.unit_ids[0]) - num_waveforms, _, _ = waveform.shape - assert num_waveforms != 0 - method_kwargs_all = {"waveform_extractor": waveform_extractor, "noise_levels": get_noise_levels(recording)} +def test_find_spikes_from_templates(method, sorting_result): + recording = sorting_result.recording + # waveform = waveform_extractor.get_waveforms(waveform_extractor.unit_ids[0]) + # num_waveforms, _, _ = waveform.shape + # assert num_waveforms != 0 + + templates = sorting_result.get_extension("fast_templates").get_data(outputs="Templates") + sparsity = compute_sparsity(sorting_result, method="snr", threshold=2) + templates = templates.to_sparse(sparsity) + + noise_levels = sorting_result.get_extension("noise_levels").get_data() + + # sorting_result + method_kwargs_all = {"templates": templates, "noise_levels": noise_levels} method_kwargs = {} - method_kwargs["wobble"] = { - "templates": waveform_extractor.get_all_templates(), - "nbefore": waveform_extractor.nbefore, - "nafter": waveform_extractor.nafter, - } + # method_kwargs["wobble"] = { + # "templates": waveform_extractor.get_all_templates(), + # "nbefore": waveform_extractor.nbefore, + # "nafter": waveform_extractor.nafter, + # } sampling_frequency = recording.get_sampling_frequency() - result = {} - method_kwargs_ = method_kwargs.get(method, {}) method_kwargs_.update(method_kwargs_all) spikes = find_spikes_from_templates( - recording, method=method, method_kwargs=method_kwargs_, n_jobs=2, chunk_size=1000, progress_bar=True + recording, method=method, method_kwargs=method_kwargs_, **job_kwargs ) - result[method] = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) - - # debug - if DEBUG: - import matplotlib.pyplot as plt - import spikeinterface.full as si - - plt.ion() - - metrics = si.compute_quality_metrics( - waveform_extractor, - metric_names=["snr"], - load_if_exists=True, - ) - - comparisons = {} - for method in matching_methods.keys(): - comp = si.compare_sorter_to_ground_truth(gt_sorting, result[method]) - comparisons[method] = comp - si.plot_agreement_matrix(comp) - plt.title(method) - si.plot_sorting_performance( - comp, - metrics, - performance_name="accuracy", - metric_name="snr", - ) - plt.title(method) - plt.show() + + + # DEBUG = False + + # if DEBUG: + # import matplotlib.pyplot as plt + # import spikeinterface.full as si + + # sorting_result.compute("waveforms") + # sorting_result.compute("templates") + + + # gt_sorting = sorting_result.sorting + + # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) + + # metrics = si.compute_quality_metrics(sorting_result, metric_names=["snr"]) + + # fig, ax = plt.subplots() + # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) + # si.plot_agreement_matrix(comp, ax=ax) + # ax.set_title(method) + # plt.show() if __name__ == "__main__": - waveform_extractor = make_waveform_extractor() - method = "naive" - test_find_spikes_from_templates(method, waveform_extractor) + sorting_result = get_sorting_result() + # method = "naive" + # method = "tdc-peeler" + # method = "circus" + method = "circus-omp-svd" + # method = "wobble" + test_find_spikes_from_templates(method, sorting_result) + diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py index 572e6c36c1..fdbabc7584 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py @@ -9,7 +9,7 @@ @pytest.fixture(scope="module") -def extract_waveforms(generated_recording): +def extract_dense_waveforms_node(generated_recording): # Parameters ms_before = 1.0 ms_after = 1.0 @@ -20,16 +20,16 @@ def extract_waveforms(generated_recording): ) -def test_waveform_thresholder_ptp(extract_waveforms, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_ptp(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): recording = generated_recording peaks = detected_peaks tresholded_waveforms_ptp = WaveformThresholder( - recording=recording, parents=[extract_waveforms], feature="ptp", threshold=3, return_output=True + recording=recording, parents=[extract_dense_waveforms_node], feature="ptp", threshold=3, return_output=True ) noise_levels = tresholded_waveforms_ptp.noise_levels - pipeline_nodes = [extract_waveforms, tresholded_waveforms_ptp] + pipeline_nodes = [extract_dense_waveforms_node, tresholded_waveforms_ptp] # Extract projected waveforms and compare waveforms, tresholded_waveforms = run_peak_pipeline( recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs @@ -39,15 +39,15 @@ def test_waveform_thresholder_ptp(extract_waveforms, generated_recording, detect assert np.all(data[data != 0] > 3) -def test_waveform_thresholder_mean(extract_waveforms, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_mean(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): recording = generated_recording peaks = detected_peaks tresholded_waveforms_mean = WaveformThresholder( - recording=recording, parents=[extract_waveforms], feature="mean", threshold=0, return_output=True + recording=recording, parents=[extract_dense_waveforms_node], feature="mean", threshold=0, return_output=True ) - pipeline_nodes = [extract_waveforms, tresholded_waveforms_mean] + pipeline_nodes = [extract_dense_waveforms_node, tresholded_waveforms_mean] # Extract projected waveforms and compare waveforms, tresholded_waveforms = run_peak_pipeline( recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs @@ -56,16 +56,16 @@ def test_waveform_thresholder_mean(extract_waveforms, generated_recording, detec assert np.all(tresholded_waveforms.mean(axis=1) >= 0) -def test_waveform_thresholder_energy(extract_waveforms, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_energy(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): recording = generated_recording peaks = detected_peaks tresholded_waveforms_energy = WaveformThresholder( - recording=recording, parents=[extract_waveforms], feature="energy", threshold=3, return_output=True + recording=recording, parents=[extract_dense_waveforms_node], feature="energy", threshold=3, return_output=True ) noise_levels = tresholded_waveforms_energy.noise_levels - pipeline_nodes = [extract_waveforms, tresholded_waveforms_energy] + pipeline_nodes = [extract_dense_waveforms_node, tresholded_waveforms_energy] # Extract projected waveforms and compare waveforms, tresholded_waveforms = run_peak_pipeline( recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs @@ -75,7 +75,7 @@ def test_waveform_thresholder_energy(extract_waveforms, generated_recording, det assert np.all(data[data != 0] > 3) -def test_waveform_thresholder_operator(extract_waveforms, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_operator(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): recording = generated_recording peaks = detected_peaks @@ -83,7 +83,7 @@ def test_waveform_thresholder_operator(extract_waveforms, generated_recording, d tresholded_waveforms_peak = WaveformThresholder( recording=recording, - parents=[extract_waveforms], + parents=[extract_dense_waveforms_node], feature="peak_voltage", threshold=5, operator=operator.ge, @@ -91,11 +91,11 @@ def test_waveform_thresholder_operator(extract_waveforms, generated_recording, d ) noise_levels = tresholded_waveforms_peak.noise_levels - pipeline_nodes = [extract_waveforms, tresholded_waveforms_peak] + pipeline_nodes = [extract_dense_waveforms_node, tresholded_waveforms_peak] # Extract projected waveforms and compare waveforms, tresholded_waveforms = run_peak_pipeline( recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs ) - data = tresholded_waveforms[:, extract_waveforms.nbefore, :] / noise_levels + data = tresholded_waveforms[:, extract_dense_waveforms_node.nbefore, :] / noise_levels assert np.all(data[data != 0] <= 5) From 9e2ccdd9b9063c05375bb110fb1df48c4fb44624 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 21:09:30 +0100 Subject: [PATCH 090/192] Remoce extract_waveforms from tridesclous2 --- .../sorters/internal/tridesclous2.py | 42 +++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 782758178e..15809bfe54 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -6,9 +6,11 @@ from spikeinterface.core import ( get_noise_levels, - extract_waveforms, NumpySorting, get_channel_distances, + estimate_templates_average, + Templates, + compute_sparsity, ) from spikeinterface.core.job_tools import fix_job_kwargs @@ -277,33 +279,47 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): new_peaks["sample_index"] -= peak_shifts # clean very small cluster before peeler + post_clean_label = post_merge_label.copy() + minimum_cluster_size = 25 - labels_set, count = np.unique(post_merge_label, return_counts=True) + labels_set, count = np.unique(post_clean_label, return_counts=True) to_remove = labels_set[count < minimum_cluster_size] - - mask = np.isin(post_merge_label, to_remove) - post_merge_label[mask] = -1 + mask = np.isin(post_clean_label, to_remove) + post_clean_label[mask] = -1 # final label sets - labels_set = np.unique(post_merge_label) + labels_set = np.unique(post_clean_label) labels_set = labels_set[labels_set >= 0] - mask = post_merge_label >= 0 - sorting_temp = NumpySorting.from_times_labels( + mask = post_clean_label >= 0 + sorting_pre_peeler = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], post_merge_label[mask], sampling_frequency, unit_ids=labels_set, ) - sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp") + # sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler") + + + nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.) + nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.) + templates_array = estimate_templates_average(recording, sorting_pre_peeler.to_spike_vector(), sorting_pre_peeler.unit_ids, + nbefore, nafter, return_scaled=False, **job_kwargs) + templates_dense = Templates( + templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + probe=recording.get_probe() + ) + # TODO : try other methods for sparsity + sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.) + templates = templates_dense.to_sparse(sparsity) - we = extract_waveforms(recording, sorting_temp, sorter_output_folder / "waveforms_temp", **params["templates"]) # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum") # print(snrs) # matching_params = params["matching"].copy() - # matching_params["waveform_extractor"] = we # matching_params["noise_levels"] = noise_levels # matching_params["peak_sign"] = params["detection"]["peak_sign"] # matching_params["detect_threshold"] = params["detection"]["detect_threshold"] @@ -316,7 +332,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_method = params["matching"]["method"] matching_params = params["matching"]["method_kwargs"].copy() - matching_params["waveform_extractor"] = we + matching_params["templates"] = templates matching_params["noise_levels"] = noise_levels # matching_params["peak_sign"] = params["detection"]["peak_sign"] # matching_params["detect_threshold"] = params["detection"]["detect_threshold"] @@ -339,6 +355,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) if params["save_array"]: + sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler") + np.save(sorter_output_folder / "noise_levels.npy", noise_levels) np.save(sorter_output_folder / "all_peaks.npy", all_peaks) np.save(sorter_output_folder / "post_split_label.npy", post_split_label) From e59b735f90ef811648f817465f095ae914a11acb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 14 Feb 2024 21:22:43 +0100 Subject: [PATCH 091/192] tdc2 sparsity --- src/spikeinterface/sorters/internal/tridesclous2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 15809bfe54..de4e2d44ec 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -312,7 +312,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe=recording.get_probe() ) # TODO : try other methods for sparsity - sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.) + # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.) + sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=1.) templates = templates_dense.to_sparse(sparsity) From bb5520c268bf52deb7606c255dee3e595ef3b467 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 15 Feb 2024 13:31:46 +0100 Subject: [PATCH 092/192] Rmove more WaveformExtractor --- src/spikeinterface/core/result_core.py | 4 ++- src/spikeinterface/core/sortingresult.py | 6 ++-- src/spikeinterface/core/sparsity.py | 15 +++++----- .../tests/test_quality_metric_calculator.py | 2 +- .../sortingcomponents/clustering/main.py | 4 +-- .../clustering/position_and_features.py | 28 ++++++++++--------- .../sortingcomponents/matching/main.py | 4 --- .../sortingcomponents/matching/naive.py | 1 - .../sortingcomponents/matching/tdc.py | 1 - 9 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/result_core.py index 60dab77994..c6be619319 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/result_core.py @@ -4,7 +4,9 @@ * ComputeTemplates Theses two classes replace the WaveformExtractor -It also implement ComputeFastTemplates which is equivalent but without extacting waveforms. +It also implement: + * ComputeFastTemplates which is equivalent but without extacting waveforms. + * ComputeNoiseLevels which is very convinient to have """ import numpy as np diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index bd36408ba1..9d3b1a0aaa 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -641,14 +641,14 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult Parameters ---------- unit_ids : list or array - The unit ids to keep in the new WaveformExtractor object + The unit ids to keep in the new SortingResult object folder : Path or None The new folder where selected waveforms are copied format: a Returns ------- - we : WaveformExtractor + we : SortingResult The newly create waveform extractor with the selected units """ # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! @@ -1354,7 +1354,7 @@ def load_data(self): self.data[ext_data_name] = ext_data def copy(self, new_sorting_result, unit_ids=None): - # alessio : please note that this also replace the old BaseWaveformExtractorExtension.select_units!!! + # alessio : please note that this also replace the old select_units!!! new_extension = self.__class__(new_sorting_result) new_extension.params = self.params.copy() if unit_ids is None: diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index fd80fbf181..cbbedfec6c 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -71,26 +71,26 @@ class ChannelSparsity: Examples -------- - The class can also be used to construct/estimate the sparsity from a Waveformextractor + The class can also be used to construct/estimate the sparsity from a SortingResult or a Templates with several methods: Using the N best channels (largest template amplitude): - >>> sparsity = ChannelSparsity.from_best_channels(we, num_channels, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_best_channels(sorting_result, num_channels, peak_sign="neg") Using a neighborhood by radius: - >>> sparsity = ChannelSparsity.from_radius(we, radius_um, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_radius(sorting_result, radius_um, peak_sign="neg") Using a SNR threshold: - >>> sparsity = ChannelSparsity.from_snr(we, threshold, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_snr(sorting_result, threshold, peak_sign="neg") Using a template energy threshold: - >>> sparsity = ChannelSparsity.from_energy(we, threshold) + >>> sparsity = ChannelSparsity.from_energy(sorting_result, threshold) Using a recording/sorting property (e.g. "group"): - >>> sparsity = ChannelSparsity.from_property(we, by_property="group") + >>> sparsity = ChannelSparsity.from_property(sorting_result, by_property="group") """ @@ -481,7 +481,6 @@ def compute_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - from .waveform_extractor import WaveformExtractor from .waveforms_extractor_backwards_compatibility import MockWaveformExtractor from .sortingresult import SortingResult @@ -549,7 +548,7 @@ def estimate_sparsity( **job_kwargs, ): """ - Estimate the sparsity without needing a WaveformExtractor. + Estimate the sparsity without needing a SortingResult or Templates object This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it traverses the recording to compute the average templates for each unit. diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 79bfc4ee50..4d94371de5 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -164,7 +164,7 @@ def test_empty_units(sorting_result_simple): assert np.all(np.isnan(metrics_empty.loc[empty_unit_id])) -# @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() +# TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() # def test_amplitude_cutoff(self): # we = self.we_short diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 151f4be270..4cb0db7db6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -14,8 +14,8 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, ---------- recording: RecordingExtractor The recording extractor object - peaks: WaveformExtractor - The waveform extractor + peaks: numpy.array + The peak vector method: str Which method to use ("stupid" | "XXXX") method_kwargs: dict, default: dict() diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index 871c0aab31..f317706838 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -14,12 +14,11 @@ HAVE_HDBSCAN = False import random, string, os -from spikeinterface.core import get_global_tmp_folder, get_noise_levels, get_channel_distances -from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler +from spikeinterface.core import get_global_tmp_folder, get_noise_levels from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting -from spikeinterface.core import extract_waveforms +from spikeinterface.core import estimate_templates_average, Templates from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks @@ -169,18 +168,21 @@ def main_function(cls, recording, peaks, params): tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - we = extract_waveforms( - recording, - sorting, - tmp_folder, - overwrite=True, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - return_scaled=False, + + nbefore = int(params["ms_before"] * fs / 1000.) + nafter = int(params["ms_after"] * fs / 1000.) + templates_array = estimate_templates_average(recording, sorting.to_spike_vector(), sorting.unit_ids, + nbefore, nafter, return_scaled=False, **params["job_kwargs"]) + templates = Templates( + templates_array=templates_array, + sampling_frequency=fs, + nbefore=nbefore, + probe=recording.get_probe() ) + + labels, peak_labels = remove_duplicates_via_matching( - we, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"] + templates, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"] ) shutil.rmtree(tmp_folder) diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 37eb4d2ec4..1c5c947b02 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -30,10 +30,6 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr method_kwargs: Optionaly returns for debug purpose. - Notes - ----- - For all methods except "wobble", templates are represented as a WaveformExtractor in method_kwargs - so statistics can be extracted. For "wobble" templates are represented as a numpy.ndarray. """ from .method_list import matching_methods diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 951c61b5cb..c172e90fd8 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,7 +4,6 @@ import numpy as np -from spikeinterface.core import WaveformExtractor, get_template_extremum_channel from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.core.template import Templates diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index cb8dd226d6..44a7aa00ee 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -3,7 +3,6 @@ import numpy as np import scipy from spikeinterface.core import ( - WaveformExtractor, get_noise_levels, get_channel_distances, compute_sparsity, From 1a4bb30613f7448a7e2da5ce69699fca035b074c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 15 Feb 2024 14:32:02 +0100 Subject: [PATCH 093/192] remove extract_waveforms() from temporal_pca --- .../waveforms/temporal_pca.py | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 32b79aa7fb..4e640ea044 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -14,7 +14,7 @@ from spikeinterface.postprocessing import compute_principal_components from spikeinterface.core import BaseRecording from spikeinterface.core.sparsity import ChannelSparsity -from spikeinterface import extract_waveforms, NumpySorting +from spikeinterface import NumpySorting, start_sorting_result from spikeinterface.core.job_tools import _shared_job_kwargs_doc from .waveform_utils import to_temporal_representation, from_temporal_representation @@ -138,25 +138,14 @@ def fit( # Creates a numpy sorting object where the spike times are the peak times and the unit ids are the peak channel sorting = NumpySorting.from_peaks(peaks, recording.sampling_frequency, recording.channel_ids) - # Create a waveform extractor - we = extract_waveforms( - recording, - sorting, - ms_before=ms_before, - ms_after=ms_after, - folder=None, - mode="memory", - max_spikes_per_unit=None, - **job_kwargs, - ) - # compute PCA by_channel_global (with sparsity) - sparsity = ChannelSparsity.from_radius(we, radius_um=radius_um) if radius_um else None - pc = compute_principal_components( - we, n_components=n_components, mode="by_channel_global", sparsity=sparsity, whiten=whiten - ) + # TODO alessio, herberto : the fitting is done with a SortingResult which is a postprocessing object, I think we should not do this for a component + sorting_result = start_sorting_result(sorting, recording, sparse=True) + sorting_result.select_random_spikes() + sorting_result.compute("waveforms", ms_before=ms_before, ms_after=ms_after) + sorting_result.compute("principal_components", n_components=n_components, mode="by_channel_global", whiten=whiten) + pca_model = sorting_result.get_extension("principal_components").get_pca_model() - pca_model = pc.get_pca_model() params = { "ms_before": ms_before, "ms_after": ms_after, From 1531c6b82687e09cf9fd8555228dc2fbeae7dc0c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 15 Feb 2024 14:48:46 +0100 Subject: [PATCH 094/192] Clean some TODOs --- .../postprocessing/amplitude_scalings.py | 10 +---- .../postprocessing/spike_amplitudes.py | 1 - .../postprocessing/spike_locations.py | 4 -- .../postprocessing/template_metrics.py | 6 --- .../tests/test_principal_component.py | 10 ++--- .../tests/test_template_similarity.py | 40 ++++++++++++------- 6 files changed, 31 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 44a8b8d241..18526613de 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -71,10 +71,6 @@ class ComputeAmplitudeScalings(ResultExtension): def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) - # extremum_channel_inds = get_template_extremum_channel(self.sorting_result, outputs="index") - # self.spikes = self.sorting_result.sorting.to_spike_vector( - # extremum_channel_inds=extremum_channel_inds, use_cache=False - # ) self.collisions = None def _set_params( @@ -150,16 +146,14 @@ def _get_pipeline_nodes(self): if self.sorting_result.is_sparse() and self.params["sparsity"] is None: sparsity = self.sorting_result.sparsity elif self.sorting_result.is_sparse() and self.params["sparsity"] is not None: - raise NotImplementedError sparsity = self.params["sparsity"] # assert provided sparsity is sparser than the one in the waveform extractor - waveform_sparsity = we.sparsity + waveform_sparsity = self.sorting_result.sparsity assert np.all( np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0 ), "The provided sparsity needs to be sparser than the one in the waveform extractor!" elif not self.sorting_result.is_sparse() and self.params["sparsity"] is not None: - raise NotImplementedError - # sparsity = self.params["sparsity"] + sparsity = self.params["sparsity"] else: if self.params["max_dense_channels"] is not None: assert recording.get_num_channels() <= self.params["max_dense_channels"], "" diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index baa278643d..b9477f0a22 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -116,7 +116,6 @@ def _get_pipeline_nodes(self): return nodes def _run(self, **job_kwargs): - # TODO later gather to disk when format="binary_folder" job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() amps = run_node_pipeline( diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 32bd5f1455..bc646676b8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -11,9 +11,6 @@ from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline -# TODO job_kwargs - - class ComputeSpikeLocations(ResultExtension): """ Localize spikes in 2D or 3D with several methods given the template. @@ -126,7 +123,6 @@ def _get_pipeline_nodes(self): return nodes def _run(self, **job_kwargs): - # TODO later gather to disk when format="binary_folder" job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() spike_locations = run_node_pipeline( diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 5e9591996f..3f9817934e 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -18,8 +18,6 @@ # DEBUG = False -# TODO handle external sparsity - def get_single_channel_template_metric_names(): return deepcopy(list(_single_channel_metric_name_to_func.keys())) @@ -117,10 +115,6 @@ def _set_params( include_multi_channel_metrics=False, ): - if sparsity is not None: - # TODO handle extra sparsity - raise NotImplementedError - # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() if include_multi_channel_metrics or ( metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 4205358420..b686e078ee 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -19,8 +19,6 @@ class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest. # mode concatenated cannot be tested here because it do not work with sparse=True ] - # TODO : put back theses tests - def test_mode_concatenated(self): # this is tested outside "extension_function_params_list" because it do not support sparsity! @@ -135,11 +133,11 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUpClass() - # test.test_extension() - # test.test_mode_concatenated() - # test.test_get_projections() + test.test_extension() + test.test_mode_concatenated() + test.test_get_projections() test.test_compute_for_all_spikes() - # test.test_project_new() + test.test_project_new() # ext = test.sorting_results["sparseTrue_memory"].get_extension("principal_components") # pca = ext.data["pca_projection"] diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index e1e8f1231a..7f48ccb525 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,6 +1,6 @@ import unittest -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite, get_sorting_result, get_dataset from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity @@ -11,21 +11,31 @@ class SimilarityExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase) dict(method="cosine_similarity"), ] - # # extend common test - # def test_check_equal_template_with_distribution_overlap(self): - # we = self.we1 - # for unit_id0 in we.unit_ids: - # waveforms0 = we.get_waveforms(unit_id0) - # for unit_id1 in we.unit_ids: - # if unit_id0 == unit_id1: - # continue - # waveforms1 = we.get_waveforms(unit_id1) - # check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + +def test_check_equal_template_with_distribution_overlap(): + + recording, sorting = get_dataset() + + sorting_result = get_sorting_result(recording, sorting, sparsity=None) + sorting_result.select_random_spikes() + sorting_result.compute("waveforms") + sorting_result.compute("templates") + + wf_ext = sorting_result.get_extension("waveforms") + + for unit_id0 in sorting_result.unit_ids: + waveforms0 = wf_ext.get_waveforms_one_unit(unit_id0) + for unit_id1 in sorting_result.unit_ids: + if unit_id0 == unit_id1: + continue + waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) + check_equal_template_with_distribution_overlap(waveforms0, waveforms1) -# TODO check_equal_template_with_distribution_overlap if __name__ == "__main__": - test = SimilarityExtensionTest() - test.setUpClass() - test.test_extension() + # test = SimilarityExtensionTest() + # test.setUpClass() + # test.test_extension() + + test_check_equal_template_with_distribution_overlap() From 23b52be5758e8d91ccd0808567e72aa44a3eecb1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 15 Feb 2024 17:01:11 +0100 Subject: [PATCH 095/192] RIP: WaveformExtractor!! --- src/spikeinterface/core/__init__.py | 10 +- .../core/tests/test_waveform_extractor.py | 620 ----- src/spikeinterface/core/waveform_extractor.py | 2164 ----------------- 3 files changed, 5 insertions(+), 2789 deletions(-) delete mode 100644 src/spikeinterface/core/tests/test_waveform_extractor.py delete mode 100644 src/spikeinterface/core/waveform_extractor.py diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 2a24e42f50..4f386c645c 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -109,13 +109,13 @@ # waveform extractor # Important not for compatibility!! # This wil be commented after 0.100 relase but the module will not be removed. -from .waveform_extractor import ( - WaveformExtractor, - BaseWaveformExtractorExtension, +# from .waveform_extractor import ( +# WaveformExtractor, +# BaseWaveformExtractorExtension, # extract_waveforms, # load_waveforms, - precompute_sparsity, -) +# precompute_sparsity, +# ) # retrieve datasets from .datasets import download_dataset diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py deleted file mode 100644 index 787c94dee8..0000000000 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ /dev/null @@ -1,620 +0,0 @@ -import pytest -from pathlib import Path -import shutil -import numpy as np -import platform -import zarr - - -from spikeinterface.core import ( - generate_recording, - generate_sorting, - NumpySorting, - ChannelSparsity, - generate_ground_truth_recording, -) -from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms -from spikeinterface.core.waveform_extractor import precompute_sparsity - - -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "core" -else: - cache_folder = Path("cache_folder") / "core" - - -def test_WaveformExtractor(): - durations = [30, 40] - sampling_frequency = 30000.0 - - # 2 segments - num_channels = 4 - recording = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency - ) - recording.annotate(is_filtered=True) - # folder_rec = cache_folder / "wf_rec1" - # recording = recording.save(folder=folder_rec) - num_units = 15 - sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) - - # test with dump !!!! - recording = recording.save() - sorting = sorting.save() - - mask = np.zeros((num_units, num_channels), dtype=bool) - mask[:, ::2] = True - num_sparse_channels = 2 - sparsity_ext = ChannelSparsity(mask, sorting.unit_ids, recording.channel_ids) - - for mode in ["folder", "memory"]: - for sparsity in [None, sparsity_ext]: - folder = cache_folder / "test_waveform_extractor" - if folder.is_dir(): - shutil.rmtree(folder) - - print(mode, sparsity) - - if mode == "memory": - wf_folder = None - else: - wf_folder = folder - - sparse = sparsity is not None - we = extract_waveforms( - recording, - sorting, - wf_folder, - mode=mode, - sparsity=sparsity, - sparse=sparse, - ms_before=1.0, - ms_after=1.6, - max_spikes_per_unit=500, - n_jobs=4, - chunk_size=30000, - progress_bar=True, - ) - num_samples = int(sampling_frequency * (1 + 1.6) / 1000.0) - wfs = we.get_waveforms(0) - print(wfs.shape, num_samples) - assert wfs.shape[0] <= 500 - if sparsity is None: - assert wfs.shape[1:] == (num_samples, num_channels) - else: - assert wfs.shape[1:] == (num_samples, num_sparse_channels) - - wfs, sampled_index = we.get_waveforms(0, with_index=True) - - if mode == "folder": - # load back - we = WaveformExtractor.load(folder) - - if sparsity is not None: - assert we.is_sparse() - - wfs = we.get_waveforms(0) - if mode == "folder": - assert isinstance(wfs, np.memmap) - wfs_array = we.get_waveforms(0, lazy=False) - assert isinstance(wfs_array, np.ndarray) - - ## Test force dense mode - wfs = we.get_waveforms(0, force_dense=True) - assert wfs.shape[2] == num_channels - - template = we.get_template(0) - if sparsity is None: - assert template.shape == (num_samples, num_channels) - else: - assert template.shape == (num_samples, num_sparse_channels) - templates = we.get_all_templates() - assert templates.shape == (num_units, num_samples, num_channels) - - template = we.get_template(0, force_dense=True) - assert template.shape == (num_samples, num_channels) - - if sparsity is not None: - assert np.all(templates[:, :, 1] == 0) - assert np.all(templates[:, :, 3] == 0) - - template_std = we.get_template(0, mode="std") - if sparsity is None: - assert template_std.shape == (num_samples, num_channels) - else: - assert template_std.shape == (num_samples, num_sparse_channels) - template_std = we.get_all_templates(mode="std") - assert template_std.shape == (num_units, num_samples, num_channels) - - if sparsity is not None: - assert np.all(template_std[:, :, 1] == 0) - assert np.all(template_std[:, :, 3] == 0) - - template_segment = we.get_template_segment(unit_id=0, segment_index=0) - if sparsity is None: - assert template_segment.shape == (num_samples, num_channels) - else: - assert template_segment.shape == (num_samples, num_sparse_channels) - - # test filter units - keep_units = sorting.get_unit_ids()[::2] - if (cache_folder / "we_filt").is_dir(): - shutil.rmtree(cache_folder / "we_filt") - wf_filt = we.select_units(keep_units, cache_folder / "we_filt") - for unit in wf_filt.sorting.get_unit_ids(): - assert unit in keep_units - filtered_templates = wf_filt.get_all_templates() - assert filtered_templates.shape == (len(keep_units), num_samples, num_channels) - if sparsity is not None: - wf_filt.is_sparse() - - # test save - if (cache_folder / f"we_saved_{mode}").is_dir(): - shutil.rmtree(cache_folder / f"we_saved_{mode}") - we_saved = we.save(cache_folder / f"we_saved_{mode}") - for unit_id in we_saved.unit_ids: - assert np.array_equal(we.get_waveforms(unit_id), we_saved.get_waveforms(unit_id)) - assert np.array_equal(we.get_sampled_indices(unit_id), we_saved.get_sampled_indices(unit_id)) - assert np.array_equal(we.get_all_templates(), we_saved.get_all_templates()) - wfs = we_saved.get_waveforms(0) - assert isinstance(wfs, np.memmap) - wfs_array = we_saved.get_waveforms(0, lazy=False) - assert isinstance(wfs_array, np.ndarray) - - if (cache_folder / f"we_saved_{mode}.zarr").is_dir(): - shutil.rmtree(cache_folder / f"we_saved_{mode}.zarr") - we_saved_zarr = we.save(cache_folder / f"we_saved_{mode}", format="zarr") - for unit_id in we_saved_zarr.unit_ids: - assert np.array_equal(we.get_waveforms(unit_id), we_saved_zarr.get_waveforms(unit_id)) - assert np.array_equal(we.get_sampled_indices(unit_id), we_saved_zarr.get_sampled_indices(unit_id)) - assert np.array_equal(we.get_all_templates(), we_saved_zarr.get_all_templates()) - wfs = we_saved_zarr.get_waveforms(0) - assert isinstance(wfs, zarr.Array) - wfs_array = we_saved_zarr.get_waveforms(0, lazy=False) - assert isinstance(wfs_array, np.ndarray) - - # test delete_waveforms - assert we.has_waveforms() - assert we_saved.has_waveforms() - assert we_saved_zarr.has_waveforms() - - we.delete_waveforms() - we_saved.delete_waveforms() - we_saved_zarr.delete_waveforms() - assert not we.has_waveforms() - assert not we_saved.has_waveforms() - assert not we_saved_zarr.has_waveforms() - - # after reloading, get_waveforms/sampled_indices should result in an AssertionError - we_loaded = load_waveforms(cache_folder / f"we_saved_{mode}") - we_loaded_zarr = load_waveforms(cache_folder / f"we_saved_{mode}.zarr") - assert not we_loaded.has_waveforms() - assert not we_loaded_zarr.has_waveforms() - with pytest.raises(AssertionError): - we_loaded.get_waveforms(we_loaded.unit_ids[0]) - with pytest.raises(AssertionError): - we_loaded_zarr.get_waveforms(we_loaded.unit_ids[0]) - with pytest.raises(AssertionError): - we_loaded.get_sampled_indices(we_loaded.unit_ids[0]) - with pytest.raises(AssertionError): - we_loaded_zarr.get_sampled_indices(we_loaded.unit_ids[0]) - - -def test_extract_waveforms(): - # 2 segments - - durations = [30, 40] - sampling_frequency = 30000.0 - - recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency) - recording.annotate(is_filtered=True) - folder_rec = cache_folder / "wf_rec2" - - sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations) - folder_sort = cache_folder / "wf_sort2" - - if folder_rec.is_dir(): - shutil.rmtree(folder_rec) - if folder_sort.is_dir(): - shutil.rmtree(folder_sort) - recording = recording.save(folder=folder_rec) - # we force "npz_folder" because we want to force the to_multiprocessing to be a SharedMemorySorting - sorting = sorting.save(folder=folder_sort, format="npz_folder") - - # 1 job - folder1 = cache_folder / "test_extract_waveforms_1job" - if folder1.is_dir(): - shutil.rmtree(folder1) - we1 = extract_waveforms(recording, sorting, folder1, max_spikes_per_unit=None, return_scaled=False) - - # 2 job - folder2 = cache_folder / "test_extract_waveforms_2job" - if folder2.is_dir(): - shutil.rmtree(folder2) - we2 = extract_waveforms( - recording, sorting, folder2, n_jobs=2, total_memory="10M", max_spikes_per_unit=None, return_scaled=False - ) - wf1 = we1.get_waveforms(0) - wf2 = we2.get_waveforms(0) - assert np.array_equal(wf1, wf2) - - # return scaled with set scaling values to recording - folder3 = cache_folder / "test_extract_waveforms_returnscaled" - if folder3.is_dir(): - shutil.rmtree(folder3) - gain = 0.1 - recording.set_channel_gains(gain) - recording.set_channel_offsets(0) - we3 = extract_waveforms( - recording, sorting, folder3, n_jobs=2, total_memory="10M", max_spikes_per_unit=None, return_scaled=True - ) - wf3 = we3.get_waveforms(0) - assert np.array_equal((wf1).astype("float32") * gain, wf3) - - # test in memory - we_mem = extract_waveforms( - recording, - sorting, - folder=None, - mode="memory", - n_jobs=2, - total_memory="10M", - max_spikes_per_unit=None, - return_scaled=True, - ) - wf_mem = we_mem.get_waveforms(0) - assert np.array_equal(wf_mem, wf3) - - # Test unfiltered recording - recording.annotate(is_filtered=False) - folder_crash = cache_folder / "test_extract_waveforms_crash" - with pytest.raises(Exception): - we1 = extract_waveforms(recording, sorting, folder_crash, max_spikes_per_unit=None, return_scaled=False) - - folder_unfiltered = cache_folder / "test_extract_waveforms_unfiltered" - if folder_unfiltered.is_dir(): - shutil.rmtree(folder_unfiltered) - we1 = extract_waveforms( - recording, sorting, folder_unfiltered, allow_unfiltered=True, max_spikes_per_unit=None, return_scaled=False - ) - recording.annotate(is_filtered=True) - - # test with sparsity estimation - folder4 = cache_folder / "test_extract_waveforms_compute_sparsity" - if folder4.is_dir(): - shutil.rmtree(folder4) - we4 = extract_waveforms( - recording, - sorting, - folder4, - max_spikes_per_unit=100, - return_scaled=True, - sparse=True, - method="radius", - radius_um=50.0, - n_jobs=2, - chunk_duration="500ms", - ) - assert we4.sparsity is not None - - # test with sparsity estimation - folder5 = cache_folder / "test_extract_waveforms_compute_sparsity_tmp_folder" - sparsity_temp_folder = cache_folder / "tmp_sparsity" - if folder5.is_dir(): - shutil.rmtree(folder5) - - we5 = extract_waveforms( - recording, - sorting, - folder5, - max_spikes_per_unit=100, - return_scaled=True, - sparse=True, - sparsity_temp_folder=sparsity_temp_folder, - method="radius", - radius_um=50.0, - n_jobs=2, - chunk_duration="500ms", - ) - assert we5.sparsity is not None - # tmp folder is cleaned up - assert not sparsity_temp_folder.is_dir() - - # should raise an error if sparsity_temp_folder is not empty - with pytest.raises(AssertionError): - if folder5.is_dir(): - shutil.rmtree(folder5) - sparsity_temp_folder.mkdir() - we5 = extract_waveforms( - recording, - sorting, - folder5, - max_spikes_per_unit=100, - return_scaled=True, - sparse=True, - sparsity_temp_folder=sparsity_temp_folder, - method="radius", - radius_um=50.0, - n_jobs=2, - chunk_duration="500ms", - ) - - -def test_recordingless(): - durations = [30, 40] - sampling_frequency = 30000.0 - - # 2 segments - num_channels = 2 - recording = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency - ) - recording.annotate(is_filtered=True) - num_units = 15 - sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) - - # now save and delete saved file - recording = recording.save(folder=cache_folder / "recording1") - sorting = sorting.save(folder=cache_folder / "sorting1") - - # recording and sorting are not serializable - wf_folder = cache_folder / "wf_recordingless" - - # save with relative paths - we = extract_waveforms(recording, sorting, wf_folder, use_relative_path=True, return_scaled=False) - we_loaded = WaveformExtractor.load(wf_folder, with_recording=False) - - assert isinstance(we.recording, BaseRecording) - assert not we_loaded.has_recording() - with pytest.raises(ValueError): - # reccording cannot be accessible - rec = we_loaded.recording - assert we.sampling_frequency == we_loaded.sampling_frequency - assert np.array_equal(we.recording.channel_ids, np.array(we_loaded.channel_ids)) - assert np.array_equal(we.recording.get_channel_locations(), np.array(we_loaded.get_channel_locations())) - assert we.get_num_channels() == we_loaded.get_num_channels() - assert all( - we.recording.get_num_samples(seg) == we_loaded.get_num_samples(seg) - for seg in range(we_loaded.get_num_segments()) - ) - assert we.recording.get_total_duration() == we_loaded.get_total_duration() - - for key in we.recording.get_property_keys(): - if key != "contact_vector": # contact vector is saved as probe - np.testing.assert_array_equal(we.recording.get_property(key), we_loaded.get_recording_property(key)) - - probe = we_loaded.get_probe() - probegroup = we_loaded.get_probegroup() - - # delete original recording and rely on rec_attributes - if platform.system() != "Windows": - # this avoid reference on the folder - del we, recording - shutil.rmtree(cache_folder / "recording1") - we_loaded = WaveformExtractor.load(wf_folder, with_recording=False) - assert not we_loaded.has_recording() - - -def test_unfiltered_extraction(): - durations = [30, 40] - sampling_frequency = 30000.0 - - # 2 segments - num_channels = 2 - recording = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency - ) - recording.annotate(is_filtered=False) - folder_rec = cache_folder / "wf_unfiltered" - recording = recording.save(folder=folder_rec) - num_units = 15 - sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) - - # test with dump !!!! - recording = recording.save() - sorting = sorting.save() - - folder = cache_folder / "test_waveform_extractor_unfiltered" - if folder.is_dir(): - shutil.rmtree(folder) - - for mode in ["folder", "memory"]: - if mode == "memory": - wf_folder = None - else: - wf_folder = folder - - with pytest.raises(Exception): - we = WaveformExtractor.create(recording, sorting, wf_folder, mode=mode, allow_unfiltered=False) - if wf_folder is not None: - shutil.rmtree(wf_folder) - we = WaveformExtractor.create(recording, sorting, wf_folder, mode=mode, allow_unfiltered=True) - - ms_before = 2.0 - ms_after = 3.0 - max_spikes_per_unit = 500 - num_samples = int((ms_before + ms_after) * sampling_frequency / 1000.0) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=1, chunk_size=30000) - we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True) - - wfs = we.get_waveforms(0) - assert wfs.shape[0] <= max_spikes_per_unit - assert wfs.shape[1:] == (num_samples, num_channels) - - wfs, sampled_index = we.get_waveforms(0, with_index=True) - - if mode == "folder": - # load back - we = WaveformExtractor.load_from_folder(folder) - - wfs = we.get_waveforms(0) - - template = we.get_template(0) - assert template.shape == (num_samples, 2) - templates = we.get_all_templates() - assert templates.shape == (num_units, num_samples, num_channels) - - wf_std = we.get_template(0, mode="std") - assert wf_std.shape == (num_samples, num_channels) - wfs_std = we.get_all_templates(mode="std") - assert wfs_std.shape == (num_units, num_samples, num_channels) - - wf_prct = we.get_template(0, mode="percentile", percentile=10) - assert wf_prct.shape == (num_samples, num_channels) - wfs_prct = we.get_all_templates(mode="percentile", percentile=10) - assert wfs_prct.shape == (num_units, num_samples, num_channels) - - # percentile mode should fail if percentile is None or not in [0, 100] - with pytest.raises(AssertionError): - wf_prct = we.get_template(0, mode="percentile") - with pytest.raises(AssertionError): - wfs_prct = we.get_all_templates(mode="percentile") - with pytest.raises(AssertionError): - wfs_prct = we.get_all_templates(mode="percentile", percentile=101) - - wf_segment = we.get_template_segment(unit_id=0, segment_index=0) - assert wf_segment.shape == (num_samples, num_channels) - assert wf_segment.shape == (num_samples, num_channels) - - -def test_portability(): - durations = [30, 40] - sampling_frequency = 30000.0 - - folder_to_move = cache_folder / "original_folder" - if folder_to_move.is_dir(): - shutil.rmtree(folder_to_move) - folder_to_move.mkdir() - folder_moved = cache_folder / "moved_folder" - if folder_moved.is_dir(): - shutil.rmtree(folder_moved) - # folder_moved.mkdir() - - # 2 segments - num_channels = 2 - recording = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency - ) - recording.annotate(is_filtered=True) - folder_rec = folder_to_move / "rec" - recording = recording.save(folder=folder_rec) - num_units = 15 - sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) - folder_sort = folder_to_move / "sort" - sorting = sorting.save(folder=folder_sort) - - wf_folder = folder_to_move / "waveform_extractor" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - - # save with relative paths - we = extract_waveforms(recording, sorting, wf_folder, use_relative_path=True) - - # move all to a separate folder - shutil.copytree(folder_to_move, folder_moved) - wf_folder_moved = folder_moved / "waveform_extractor" - we_loaded = load_waveforms(folder=wf_folder_moved, with_recording=True, sorting=sorting) - - assert we_loaded.recording is not None - assert we_loaded.sorting is not None - - assert np.allclose(we.channel_ids, we_loaded.recording.channel_ids) - assert np.allclose(we.unit_ids, we_loaded.unit_ids) - - for unit in we.unit_ids: - wf = we.get_waveforms(unit_id=unit) - wf_loaded = we_loaded.get_waveforms(unit_id=unit) - - assert np.allclose(wf, wf_loaded) - - -def test_empty_sorting(): - sf = 30000 - num_channels = 2 - - recording = generate_recording(num_channels=num_channels, sampling_frequency=sf, durations=[15.32]) - sorting = NumpySorting.from_unit_dict({}, sf) - - folder = cache_folder / "empty_sorting" - wvf_extractor = extract_waveforms(recording, sorting, folder, allow_unfiltered=True) - - assert len(wvf_extractor.unit_ids) == 0 - assert wvf_extractor.get_all_templates().shape == (0, wvf_extractor.nsamples, num_channels) - - -def test_compute_sparsity(): - durations = [30, 40] - sampling_frequency = 30000.0 - - num_channels = 4 - recording = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency - ) - recording.annotate(is_filtered=True) - - num_units = 15 - sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) - - # test with dump - recording = recording.save() - sorting = sorting.save() - - job_kwargs = dict(n_jobs=4, chunk_size=30000, progress_bar=False) - - for kwargs in [dict(method="radius", radius_um=50.0), dict(method="best_channels", num_channels=2)]: - sparsity = precompute_sparsity( - recording, - sorting, - num_spikes_for_sparsity=100, - unit_batch_size=2, - ms_before=1.0, - ms_after=1.5, - **kwargs, - **job_kwargs, - ) - print(sparsity) - - -def test_non_json_object(): - recording, sorting = generate_ground_truth_recording( - durations=[30, 40], - sampling_frequency=30000.0, - num_channels=32, - num_units=5, - ) - - # recording is not save to keep it in memory - sorting = sorting.save() - - wf_folder = cache_folder / "test_waveform_extractor" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - - we = extract_waveforms( - recording, - sorting, - wf_folder, - mode="folder", - sparsity=None, - sparse=False, - ms_before=1.0, - ms_after=1.6, - max_spikes_per_unit=50, - n_jobs=4, - chunk_size=30000, - progress_bar=True, - ) - - # This used to fail because of json - we = load_waveforms(wf_folder) - - -if __name__ == "__main__": - # test_WaveformExtractor() - # test_extract_waveforms() - # test_portability() - test_recordingless() - # test_compute_sparsity() - # test_non_json_object() - test_empty_sorting() diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py deleted file mode 100644 index 8c3f53b64c..0000000000 --- a/src/spikeinterface/core/waveform_extractor.py +++ /dev/null @@ -1,2164 +0,0 @@ -from __future__ import annotations - -import math -import pickle -from pathlib import Path -import shutil -from typing import Literal, Optional -import json -import os -import weakref - -import numpy as np -from copy import deepcopy -from warnings import warn - -import probeinterface - -from .base import load_extractor -from .baserecording import BaseRecording -from .basesorting import BaseSorting -from .core_tools import check_json -from .job_tools import _shared_job_kwargs_doc, split_job_kwargs, fix_job_kwargs -from .numpyextractors import NumpySorting -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes -from .sparsity import ChannelSparsity, compute_sparsity, _sparsity_doc -from .waveform_tools import extract_waveforms_to_buffers, has_exceeding_spikes - -_possible_template_modes = ("average", "std", "median", "percentile") - - -class WaveformExtractor: - """ - Class to extract waveform on paired Recording-Sorting objects. - Waveforms are persistent on disk and cached in memory. - - Parameters - ---------- - recording: Recording | None - The recording object - sorting: Sorting - The sorting object - folder: Path - The folder where waveforms are cached - rec_attributes: None or dict - When recording is None then a minimal dict with some attributes - is needed. - allow_unfiltered: bool, default: False - If true, will accept unfiltered recording. - Returns - ------- - we: WaveformExtractor - The WaveformExtractor object - - Examples - -------- - - >>> # Instantiate - >>> we = WaveformExtractor.create(recording, sorting, folder) - - >>> # Compute - >>> we = we.set_params(...) - >>> we = we.run_extract_waveforms(...) - - >>> # Retrieve - >>> waveforms = we.get_waveforms(unit_id) - >>> template = we.get_template(unit_id, mode="median") - - >>> # Load from folder (in another session) - >>> we = WaveformExtractor.load(folder) - - """ - - extensions = [] - - def __init__( - self, - recording: Optional[BaseRecording], - sorting: BaseSorting, - folder=None, - rec_attributes=None, - allow_unfiltered: bool = False, - sparsity=None, - ) -> None: - self.sorting = sorting - self._rec_attributes = None - self.set_recording(recording, rec_attributes, allow_unfiltered) - - # cache in memory - self._waveforms = {} - self._template_cache = {} - self._params = {} - self._loaded_extensions = dict() - self._is_read_only = False - self.sparsity = sparsity - - self.folder = folder - if self.folder is not None: - self.folder = Path(self.folder) - if self.folder.suffix == ".zarr": - import zarr - - self.format = "zarr" - self._waveforms_root = zarr.open(self.folder, mode="r") - self._params = self._waveforms_root.attrs["params"] - else: - self.format = "binary" - if (self.folder / "params.json").is_file(): - with open(str(self.folder / "params.json"), "r") as f: - self._params = json.load(f) - if not os.access(self.folder, os.W_OK): - self._is_read_only = True - else: - # this is in case of in-memory - self.format = "memory" - self._memory_objects = None - - def __repr__(self) -> str: - clsname = self.__class__.__name__ - nseg = self.get_num_segments() - nchan = self.get_num_channels() - nunits = self.sorting.get_num_units() - txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments" - if len(self._params) > 0: - max_spikes_per_unit = self._params["max_spikes_per_unit"] - txt = txt + f"\n before:{self.nbefore} after:{self.nafter} n_per_units:{max_spikes_per_unit}" - if self.is_sparse(): - txt += " - sparse" - return txt - - @classmethod - def load(cls, folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None) -> "WaveformExtractor": - folder = Path(folder) - assert folder.is_dir(), "Waveform folder does not exists" - if folder.suffix == ".zarr": - return WaveformExtractor.load_from_zarr(folder, with_recording=with_recording, sorting=sorting) - else: - return WaveformExtractor.load_from_folder(folder, with_recording=with_recording, sorting=sorting) - - @classmethod - def load_from_folder( - cls, folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None - ) -> "WaveformExtractor": - folder = Path(folder) - assert folder.is_dir(), f"This waveform folder does not exists {folder}" - - if not with_recording: - # load - recording = None - rec_attributes_file = folder / "recording_info" / "recording_attributes.json" - if not rec_attributes_file.exists(): - raise ValueError( - "This WaveformExtractor folder was created with an older version of spikeinterface" - "\nYou cannot use the mode with_recording=False" - ) - with open(rec_attributes_file, "r") as f: - rec_attributes = json.load(f) - # the probe is handle ouside the main json - probegroup_file = folder / "recording_info" / "probegroup.json" - if probegroup_file.is_file(): - rec_attributes["probegroup"] = probeinterface.read_probeinterface(probegroup_file) - else: - rec_attributes["probegroup"] = None - else: - recording = None - if (folder / "recording.json").exists(): - try: - recording = load_extractor(folder / "recording.json", base_folder=folder) - except: - pass - elif (folder / "recording.pickle").exists(): - try: - recording = load_extractor(folder / "recording.pickle", base_folder=folder) - except: - pass - if recording is None: - raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument") - rec_attributes = None - - if sorting is None: - if (folder / "sorting.json").exists(): - sorting = load_extractor(folder / "sorting.json", base_folder=folder) - elif (folder / "sorting.pickle").exists(): - sorting = load_extractor(folder / "sorting.pickle", base_folder=folder) - else: - raise FileNotFoundError("load_waveforms() impossible to find the sorting object (json or pickle)") - - # the sparsity is the sparsity of the saved/cached waveforms arrays - sparsity_file = folder / "sparsity.json" - if sparsity_file.is_file(): - with open(sparsity_file, mode="r") as f: - sparsity = ChannelSparsity.from_dict(json.load(f)) - else: - sparsity = None - - we = cls( - recording, sorting, folder=folder, rec_attributes=rec_attributes, allow_unfiltered=True, sparsity=sparsity - ) - - for mode in _possible_template_modes: - # load cached templates - template_file = folder / f"templates_{mode}.npy" - if template_file.is_file(): - we._template_cache[mode] = np.load(template_file) - - return we - - @classmethod - def load_from_zarr( - cls, folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None - ) -> "WaveformExtractor": - import zarr - - folder = Path(folder) - assert folder.is_dir(), f"This waveform folder does not exists {folder}" - assert folder.suffix == ".zarr" - - waveforms_root = zarr.open(folder, mode="r+") - - if not with_recording: - # load - recording = None - rec_attributes = waveforms_root.require_group("recording_info").attrs["recording_attributes"] - # the probe is handle ouside the main json - if "probegroup" in waveforms_root.require_group("recording_info").attrs: - probegroup_dict = waveforms_root.require_group("recording_info").attrs["probegroup"] - rec_attributes["probegroup"] = probeinterface.Probe.from_dict(probegroup_dict) - else: - rec_attributes["probegroup"] = None - else: - try: - recording_dict = waveforms_root.attrs["recording"] - recording = load_extractor(recording_dict, base_folder=folder) - rec_attributes = None - except: - raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument") - - if sorting is None: - sorting_dict = waveforms_root.attrs["sorting"] - sorting = load_extractor(sorting_dict, base_folder=folder) - - if "sparsity" in waveforms_root.attrs: - sparsity = waveforms_root.attrs["sparsity"] - else: - sparsity = None - - we = cls( - recording, sorting, folder=folder, rec_attributes=rec_attributes, allow_unfiltered=True, sparsity=sparsity - ) - - for mode in _possible_template_modes: - # load cached templates - if f"templates_{mode}" in waveforms_root.keys(): - we._template_cache[mode] = waveforms_root[f"templates_{mode}"] - return we - - @classmethod - def create( - cls, - recording: BaseRecording, - sorting: BaseSorting, - folder, - mode: Literal["folder", "memory"] = "folder", - remove_if_exists: bool = False, - use_relative_path: bool = False, - allow_unfiltered: bool = False, - sparsity=None, - ) -> "WaveformExtractor": - assert mode in ("folder", "memory") - # create rec_attributes - if has_exceeding_spikes(recording, sorting): - raise ValueError( - "The sorting object has spikes exceeding the recording duration. You have to remove those spikes " - "with the `spikeinterface.curation.remove_excess_spikes()` function" - ) - rec_attributes = get_rec_attributes(recording) - if mode == "folder": - folder = Path(folder) - if folder.is_dir(): - if remove_if_exists: - shutil.rmtree(folder) - else: - raise FileExistsError(f"Folder {folder} already exists") - folder.mkdir(parents=True) - - if use_relative_path: - relative_to = folder - else: - relative_to = None - - if recording.check_serializability("json"): - recording.dump(folder / "recording.json", relative_to=relative_to) - elif recording.check_serializability("pickle"): - recording.dump(folder / "recording.pickle", relative_to=relative_to) - - if sorting.check_serializability("json"): - sorting.dump(folder / "sorting.json", relative_to=relative_to) - elif sorting.check_serializability("pickle"): - sorting.dump(folder / "sorting.pickle", relative_to=relative_to) - else: - warn( - "Sorting object is not serializable to file, which might result in downstream errors for " - "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." - ) - - # dump some attributes of the recording for the mode with_recording=False at next load - rec_attributes_file = folder / "recording_info" / "recording_attributes.json" - rec_attributes_file.parent.mkdir() - rec_attributes_file.write_text(json.dumps(check_json(rec_attributes), indent=4), encoding="utf8") - if recording.get_probegroup() is not None: - probegroup_file = folder / "recording_info" / "probegroup.json" - probeinterface.write_probeinterface(probegroup_file, recording.get_probegroup()) - - with open(rec_attributes_file, "r") as f: - rec_attributes = json.load(f) - - if sparsity is not None: - with open(folder / "sparsity.json", mode="w") as f: - json.dump(check_json(sparsity.to_dict()), f) - - return cls( - recording, - sorting, - folder, - allow_unfiltered=allow_unfiltered, - sparsity=sparsity, - rec_attributes=rec_attributes, - ) - - def is_sparse(self) -> bool: - return self.sparsity is not None - - def has_waveforms(self) -> bool: - if self.folder is not None: - if self.format == "binary": - return (self.folder / "waveforms").is_dir() - elif self.format == "zarr": - import zarr - - root = zarr.open(self.folder) - return "waveforms" in root.keys() - else: - return self._memory_objects is not None - - def delete_waveforms(self) -> None: - """ - Deletes waveforms folder. - """ - assert self.has_waveforms(), "WaveformExtractor object doesn't have waveforms already!" - if self.folder is not None: - if self.format == "binary": - shutil.rmtree(self.folder / "waveforms") - elif self.format == "zarr": - import zarr - - root = zarr.open(self.folder) - del root["waveforms"] - else: - self._memory_objects = None - - @classmethod - def register_extension(cls, extension_class) -> None: - """ - This maintains a list of possible extensions that are available. - It depends on the imported submodules (e.g. for postprocessing module). - - For instance: - import spikeinterface as si - si.WaveformExtractor.extensions == [] - - from spikeinterface.postprocessing import WaveformPrincipalComponent - si.WaveformExtractor.extensions == [WaveformPrincipalComponent, ...] - - """ - assert issubclass(extension_class, BaseWaveformExtractorExtension) - assert extension_class.extension_name is not None, "extension_name must not be None" - assert all( - extension_class.extension_name != ext.extension_name for ext in cls.extensions - ), "Extension name already exists" - cls.extensions.append(extension_class) - - # map some method from recording and sorting - @property - def recording(self) -> BaseRecording: - if not self.has_recording(): - raise ValueError( - 'WaveformExtractor is used in mode "with_recording=False" ' "this operation needs the recording" - ) - return self._recording - - @property - def channel_ids(self) -> np.ndarray: - if self.has_recording(): - return self.recording.channel_ids - else: - return np.array(self._rec_attributes["channel_ids"]) - - @property - def sampling_frequency(self) -> float: - return self.sorting.get_sampling_frequency() - - @property - def unit_ids(self) -> np.ndarray: - return self.sorting.unit_ids - - @property - def nbefore(self) -> int: - nbefore = int(self._params["ms_before"] * self.sampling_frequency / 1000.0) - return nbefore - - @property - def nafter(self) -> int: - nafter = int(self._params["ms_after"] * self.sampling_frequency / 1000.0) - return nafter - - @property - def nsamples(self) -> int: - return self.nbefore + self.nafter - - @property - def return_scaled(self) -> bool: - return self._params["return_scaled"] - - @property - def dtype(self): - return self._params["dtype"] - - def is_read_only(self) -> bool: - return self._is_read_only - - def has_recording(self) -> bool: - return self._recording is not None - - def get_num_samples(self, segment_index: Optional[int] = None) -> int: - if self.has_recording(): - return self.recording.get_num_samples(segment_index) - else: - assert "num_samples" in self._rec_attributes, "'num_samples' is not available" - # we use self.sorting to check segment_index - segment_index = self.sorting._check_segment_index(segment_index) - return self._rec_attributes["num_samples"][segment_index] - - def get_total_samples(self) -> int: - s = 0 - for segment_index in range(self.get_num_segments()): - s += self.get_num_samples(segment_index) - return s - - def get_total_duration(self) -> float: - duration = self.get_total_samples() / self.sampling_frequency - return duration - - def get_num_channels(self) -> int: - if self.has_recording(): - return self.recording.get_num_channels() - else: - return self._rec_attributes["num_channels"] - - def get_num_segments(self) -> int: - return self.sorting.get_num_segments() - - def get_probegroup(self): - if self.has_recording(): - return self.recording.get_probegroup() - else: - return self._rec_attributes["probegroup"] - - def is_filtered(self) -> bool: - if self.has_recording(): - return self.recording.is_filtered() - else: - return self._rec_attributes["is_filtered"] - - def get_probe(self): - probegroup = self.get_probegroup() - assert len(probegroup.probes) == 1, "There are several probes. Use `get_probegroup()`" - return probegroup.probes[0] - - def get_channel_locations(self) -> np.ndarray: - # important note : contrary to recording - # this give all channel locations, so no kwargs like channel_ids and axes - if self.has_recording(): - return self.recording.get_channel_locations() - else: - if self.get_probegroup() is not None: - all_probes = self.get_probegroup().probes - # check that multiple probes are non-overlapping - check_probe_do_not_overlap(all_probes) - all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - return all_positions - else: - raise Exception("There are no channel locations") - - def channel_ids_to_indices(self, channel_ids) -> np.ndarray: - if self.has_recording(): - return self.recording.ids_to_indices(channel_ids) - else: - all_channel_ids = self._rec_attributes["channel_ids"] - indices = np.array([all_channel_ids.index(id) for id in channel_ids], dtype=int) - return indices - - def get_recording_property(self, key) -> np.ndarray: - if self.has_recording(): - return self.recording.get_property(key) - else: - assert "properties" in self._rec_attributes, "'properties' are not available" - values = np.array(self._rec_attributes["properties"].get(key, None)) - return values - - def get_sorting_property(self, key) -> np.ndarray: - return self.sorting.get_property(key) - - def get_extension_class(self, extension_name: str): - """ - Get extension class from name and check if registered. - - Parameters - ---------- - extension_name: str - The extension name. - - Returns - ------- - ext_class: - The class of the extension. - """ - extensions_dict = {ext.extension_name: ext for ext in self.extensions} - assert extension_name in extensions_dict, "Extension is not registered, please import related module before" - ext_class = extensions_dict[extension_name] - return ext_class - - def has_extension(self, extension_name: str) -> bool: - """ - Check if the extension exists in memory or in the folder. - - Parameters - ---------- - extension_name: str - The extension name. - - Returns - ------- - exists: bool - Whether the extension exists or not - """ - if self.folder is None: - return extension_name in self._loaded_extensions - - if extension_name in self._loaded_extensions: - # extension already loaded in memory - return True - else: - if self.format == "binary": - return (self.folder / extension_name).is_dir() and ( - self.folder / extension_name / "params.json" - ).is_file() - elif self.format == "zarr": - return ( - extension_name in self._waveforms_root.keys() - and "params" in self._waveforms_root[extension_name].attrs.keys() - ) - - def is_extension(self, extension_name) -> bool: - warn( - "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.has_extension(extension_name) - - def load_extension(self, extension_name: str): - """ - Load an extension from its name. - The module of the extension must be loaded and registered. - - Parameters - ---------- - extension_name: str - The extension name. - - Returns - ------- - ext_instanace: - The loaded instance of the extension - """ - if self.folder is not None and extension_name not in self._loaded_extensions: - if self.has_extension(extension_name): - ext_class = self.get_extension_class(extension_name) - ext = ext_class.load(self.folder, self) - if extension_name not in self._loaded_extensions: - raise Exception(f"Extension {extension_name} not available") - return self._loaded_extensions[extension_name] - - def delete_extension(self, extension_name) -> None: - """ - Deletes an existing extension. - - Parameters - ---------- - extension_name: str - The extension name. - """ - assert self.has_extension(extension_name), f"The extension {extension_name} is not available" - del self._loaded_extensions[extension_name] - if self.folder is not None and (self.folder / extension_name).is_dir(): - shutil.rmtree(self.folder / extension_name) - - def get_available_extension_names(self): - """ - Return a list of loaded or available extension names either in memory or - in persistent extension folders. - Then instances can be loaded with we.load_extension(extension_name) - - Importante note: extension modules need to be loaded (and so registered) - before this call, otherwise extensions will be ignored even if the folder - exists. - - Returns - ------- - extension_names_in_folder: list - A list of names of computed extension in this folder - """ - extension_names_in_folder = [] - for extension_class in self.extensions: - if self.has_extension(extension_class.extension_name): - extension_names_in_folder.append(extension_class.extension_name) - return extension_names_in_folder - - def _reset(self) -> None: - self._waveforms = {} - self._template_cache = {} - self._params = {} - - if self.folder is not None: - waveform_folder = self.folder / "waveforms" - if waveform_folder.is_dir(): - shutil.rmtree(waveform_folder) - for mode in _possible_template_modes: - template_file = self.folder / f"templates_{mode}.npy" - if template_file.is_file(): - template_file.unlink() - - waveform_folder.mkdir() - else: - # remove shared objects - self._memory_objects = None - - def set_recording( - self, recording: Optional[BaseRecording], rec_attributes: Optional[dict] = None, allow_unfiltered: bool = False - ) -> None: - """ - Sets the recording object and attributes for the WaveformExtractor. - - Parameters - ---------- - recording: Recording | None - The recording object - rec_attributes: None or dict - When recording is None then a minimal dict with some attributes - is needed. - allow_unfiltered: bool, default: False - If true, will accept unfiltered recording. - """ - - if recording is None: # Recordless mode. - if rec_attributes is None: - raise ValueError("WaveformExtractor: if recording is None, then rec_attributes must be provided.") - for k in ( - "channel_ids", - "sampling_frequency", - "num_channels", - ): # Some check on minimal attributes (probegroup is not mandatory) - if k not in rec_attributes: - raise ValueError(f"WaveformExtractor: Missing key '{k}' in rec_attributes") - for k in ("num_samples", "properties", "is_filtered"): - if k not in rec_attributes: - warn( - f"Missing optional key in rec_attributes {k}: " - f"some recordingless functions might not be available" - ) - else: - if rec_attributes is None: - rec_attributes = get_rec_attributes(recording) - - if recording.get_num_segments() != self.get_num_segments(): - raise ValueError( - f"Couldn't set the WaveformExtractor recording: num_segments do not match!\n{self.get_num_segments()} != {recording.get_num_segments()}" - ) - if not math.isclose(recording.sampling_frequency, self.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5): - raise ValueError( - f"Couldn't set the WaveformExtractor recording: sampling frequency doesn't match!\n{self.sampling_frequency} != {recording.sampling_frequency}" - ) - if self._rec_attributes is not None: - reference_channel_ids = self._rec_attributes["channel_ids"] - else: - reference_channel_ids = rec_attributes["channel_ids"] - if not np.array_equal(reference_channel_ids, recording.channel_ids): - raise ValueError( - f"Couldn't set the WaveformExtractor recording: channel_ids do not match!\n{reference_channel_ids}" - ) - - if not recording.is_filtered() and not allow_unfiltered: - raise Exception( - "The recording is not filtered, you must filter it using `bandpass_filter()`." - "If the recording is already filtered, you can also do " - "`recording.annotate(is_filtered=True).\n" - "If you trully want to extract unfiltered waveforms, use `allow_unfiltered=True`." - ) - - self._recording = recording - self._rec_attributes = rec_attributes - - def set_params( - self, - ms_before: float = 1.0, - ms_after: float = 2.0, - max_spikes_per_unit: int = 500, - return_scaled: bool = False, - dtype=None, - ) -> None: - """ - Set parameters for waveform extraction - - Parameters - ---------- - ms_before: float - Cut out in ms before spike time - ms_after: float - Cut out in ms after spike time - max_spikes_per_unit: int - Maximum number of spikes to extract per unit - return_scaled: bool - If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. - dtype: np.dtype - The dtype of the computed waveforms - """ - self._reset() - - if dtype is None: - dtype = self.recording.get_dtype() - - if return_scaled: - # check if has scaled values: - if not self.recording.has_scaled(): - print("Setting 'return_scaled' to False") - return_scaled = False - - if np.issubdtype(dtype, np.integer) and return_scaled: - dtype = "float32" - - dtype = np.dtype(dtype) - - if max_spikes_per_unit is not None: - max_spikes_per_unit = int(max_spikes_per_unit) - - self._params = dict( - ms_before=float(ms_before), - ms_after=float(ms_after), - max_spikes_per_unit=max_spikes_per_unit, - return_scaled=return_scaled, - dtype=dtype.str, - ) - - if self.folder is not None: - (self.folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") - - def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = False) -> "WaveformExtractor": - """ - Filters units by creating a new waveform extractor object in a new folder. - - Extensions are also updated to filter the selected unit ids. - - Parameters - ---------- - unit_ids : list or array - The unit ids to keep in the new WaveformExtractor object - new_folder : Path or None - The new folder where selected waveforms are copied - - Returns - ------- - we : WaveformExtractor - The newly create waveform extractor with the selected units - """ - sorting = self.sorting.select_units(unit_ids) - unit_indices = self.sorting.ids_to_indices(unit_ids) - - if self.folder is not None and new_folder is not None: - if self.format == "binary": - new_folder = Path(new_folder) - assert not new_folder.is_dir(), f"{new_folder} already exists!" - new_folder.mkdir(parents=True) - - # create new waveform extractor folder - shutil.copyfile(self.folder / "params.json", new_folder / "params.json") - - if use_relative_path: - relative_to = new_folder - else: - relative_to = None - - if self.has_recording(): - self.recording.dump(new_folder / "recording.json", relative_to=relative_to) - - shutil.copytree(self.folder / "recording_info", new_folder / "recording_info") - - sorting.dump(new_folder / "sorting.json", relative_to=relative_to) - - # create and populate waveforms folder - new_waveforms_folder = new_folder / "waveforms" - new_waveforms_folder.mkdir() - - waveforms_files = [f for f in (self.folder / "waveforms").iterdir() if f.suffix == ".npy"] - for unit in sorting.get_unit_ids(): - for wf_file in waveforms_files: - if f"waveforms_{unit}.npy" in wf_file.name or f"sampled_index_{unit}.npy" in wf_file.name: - shutil.copyfile(wf_file, new_waveforms_folder / wf_file.name) - - template_files = [f for f in self.folder.iterdir() if "template" in f.name and f.suffix == ".npy"] - for tmp_file in template_files: - templates_data_sliced = np.load(tmp_file)[unit_indices] - np.save(new_waveforms_folder / tmp_file.name, templates_data_sliced) - - # slice masks - if self.is_sparse(): - mask = self.sparsity.mask[unit_indices] - new_sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids) - with (new_folder / "sparsity.json").open("w") as f: - json.dump(check_json(new_sparsity.to_dict()), f) - - we = WaveformExtractor.load(new_folder, with_recording=self.has_recording()) - - elif self.format == "zarr": - raise NotImplementedError( - "For zarr format, `select_units()` to a folder is not supported yet. " - "You can select units in two steps:\n" - "1. `we_new = select_units(unit_ids, new_folder=None)`\n" - "2. `we_new.save(folder='new_folder', format='zarr')`" - ) - else: - sorting = self.sorting.select_units(unit_ids) - if self.is_sparse(): - mask = self.sparsity.mask[unit_indices] - sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids) - else: - sparsity = None - if self.has_recording(): - we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) - else: - we = WaveformExtractor( - recording=None, - sorting=sorting, - folder=None, - sparsity=sparsity, - rec_attributes=self._rec_attributes, - allow_unfiltered=True, - ) - we._params = self._params - # copy memory objects - if self.has_waveforms(): - we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}} - for unit_id in unit_ids: - if self.format == "memory": - we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] - we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][ - unit_id - ] - else: - we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id) - we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id) - - # finally select extensions data - for ext_name in self.get_available_extension_names(): - ext = self.load_extension(ext_name) - ext.select_units(unit_ids, new_waveform_extractor=we) - - return we - - def save( - self, folder, format="binary", use_relative_path: bool = False, overwrite: bool = False, sparsity=None, **kwargs - ) -> "WaveformExtractor": - """ - Save WaveformExtractor object to disk. - - Parameters - ---------- - folder : str or Path - The output waveform folder - format : "binary" | "zarr", default: "binary" - The backend to use for saving the waveforms - overwrite : bool - If True and folder exists, it is deleted, default: False - use_relative_path : bool, default: False - If True, the recording and sorting paths are relative to the waveforms folder. - This allows portability of the waveform folder provided that the relative paths are the same, - but forces all the data files to be in the same drive - sparsity : ChannelSparsity, default: None - If given and WaveformExtractor is not sparse, it makes the returned WaveformExtractor sparse - """ - folder = Path(folder) - if use_relative_path: - relative_to = folder - else: - relative_to = None - - probegroup = None - if self.has_recording(): - rec_attributes = dict( - channel_ids=self.recording.channel_ids, - sampling_frequency=self.recording.get_sampling_frequency(), - num_channels=self.recording.get_num_channels(), - ) - if self.recording.get_probegroup() is not None: - probegroup = self.recording.get_probegroup() - else: - rec_attributes = deepcopy(self._rec_attributes) - probegroup = rec_attributes["probegroup"] - - if self.is_sparse(): - assert sparsity is None, "WaveformExtractor is already sparse!" - - if format == "binary": - if folder.is_dir() and overwrite: - shutil.rmtree(folder) - assert not folder.is_dir(), "Folder already exists. Use 'overwrite=True'" - folder.mkdir(parents=True) - # write metadata - (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") - - if self.has_recording(): - if self.recording.check_serializability("json"): - self.recording.dump(folder / "recording.json", relative_to=relative_to) - elif self.recording.check_serializability("pickle"): - self.recording.dump(folder / "recording.pickle", relative_to=relative_to) - - if self.sorting.check_serializability("json"): - self.sorting.dump(folder / "sorting.json", relative_to=relative_to) - elif self.sorting.check_serializability("pickle"): - self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to) - else: - warn( - "Sorting object is not serializable to file, which might result in downstream errors for " - "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." - ) - - # dump some attributes of the recording for the mode with_recording=False at next load - rec_attributes_file = folder / "recording_info" / "recording_attributes.json" - rec_attributes_file.parent.mkdir() - rec_attributes_file.write_text(json.dumps(check_json(rec_attributes), indent=4), encoding="utf8") - if probegroup is not None: - probegroup_file = folder / "recording_info" / "probegroup.json" - probeinterface.write_probeinterface(probegroup_file, probegroup) - with open(rec_attributes_file, "r") as f: - rec_attributes = json.load(f) - for mode, templates in self._template_cache.items(): - templates_save = templates.copy() - if sparsity is not None: - expanded_mask = np.tile(sparsity.mask[:, np.newaxis, :], (1, templates_save.shape[1], 1)) - templates_save[~expanded_mask] = 0 - template_file = folder / f"templates_{mode}.npy" - np.save(template_file, templates_save) - if sparsity is not None: - with (folder / "sparsity.json").open("w") as f: - json.dump(check_json(sparsity.to_dict()), f) - # now waveforms and templates - if self.has_waveforms(): - waveform_folder = folder / "waveforms" - waveform_folder.mkdir() - for unit_ind, unit_id in enumerate(self.unit_ids): - waveforms, sampled_indices = self.get_waveforms(unit_id, with_index=True) - if sparsity is not None: - waveforms = waveforms[:, :, sparsity.mask[unit_ind]] - np.save(waveform_folder / f"waveforms_{unit_id}.npy", waveforms) - np.save(waveform_folder / f"sampled_index_{unit_id}.npy", sampled_indices) - elif format == "zarr": - import zarr - from .zarrextractors import get_default_zarr_compressor - - if folder.suffix != ".zarr": - folder = folder.parent / f"{folder.stem}.zarr" - if folder.is_dir() and overwrite: - shutil.rmtree(folder) - assert not folder.is_dir(), "Folder already exists. Use 'overwrite=True'" - zarr_root = zarr.open(str(folder), mode="w") - # write metadata - zarr_root.attrs["params"] = check_json(self._params) - if self.has_recording(): - if self.recording.check_serializability("json"): - rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True) - zarr_root.attrs["recording"] = check_json(rec_dict) - if self.sorting.check_serializability("json"): - sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True) - zarr_root.attrs["sorting"] = check_json(sort_dict) - else: - warn( - "Sorting object is not json serializable, which might result in downstream errors for " - "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function." - ) - recording_info = zarr_root.create_group("recording_info") - recording_info.attrs["recording_attributes"] = check_json(rec_attributes) - if probegroup is not None: - recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) - # save waveforms and templates - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() - print( - f"Using default zarr compressor: {compressor}. To use a different compressor, use the " - f"'compressor' argument" - ) - for mode, templates in self._template_cache.items(): - templates_save = templates.copy() - if sparsity is not None: - expanded_mask = np.tile(sparsity.mask[:, np.newaxis, :], (1, templates_save.shape[1], 1)) - templates_save[~expanded_mask] = 0 - zarr_root.create_dataset(name=f"templates_{mode}", data=templates_save, compressor=compressor) - if sparsity is not None: - zarr_root.attrs["sparsity"] = check_json(sparsity.to_dict()) - if self.has_waveforms(): - waveform_group = zarr_root.create_group("waveforms") - for unit_ind, unit_id in enumerate(self.unit_ids): - waveforms, sampled_indices = self.get_waveforms(unit_id, with_index=True) - if sparsity is not None: - waveforms = waveforms[:, :, sparsity.mask[unit_ind]] - waveform_group.create_dataset(name=f"waveforms_{unit_id}", data=waveforms, compressor=compressor) - waveform_group.create_dataset( - name=f"sampled_index_{unit_id}", data=sampled_indices, compressor=compressor - ) - - new_we = WaveformExtractor.load(folder) - - # save waveform extensions - for ext_name in self.get_available_extension_names(): - ext = self.load_extension(ext_name) - if sparsity is None: - ext.copy(new_we) - else: - if ext.handle_sparsity: - print( - f"WaveformExtractor.save() : {ext.extension_name} cannot be propagated with sparsity" - f"It is recommended to recompute {ext.extension_name} to properly handle sparsity" - ) - else: - ext.copy(new_we) - - return new_we - - def get_waveforms( - self, - unit_id, - with_index: bool = False, - cache: bool = False, - lazy: bool = True, - sparsity=None, - force_dense: bool = False, - ): - """ - Return waveforms for the specified unit id. - - Parameters - ---------- - unit_id: int or str - Unit id to retrieve waveforms for - with_index: bool, default: False - If True, spike indices of extracted waveforms are returned - cache: bool, default: False - If True, waveforms are cached to the self._waveforms dictionary - lazy: bool, default: True - If True, waveforms are loaded as memmap objects (when format="binary") or Zarr datasets - (when format="zarr"). - If False, waveforms are loaded as np.array objects - sparsity: ChannelSparsity, default: None - Sparsity to apply to the waveforms (if WaveformExtractor is not sparse) - force_dense: bool, default: False - Return dense waveforms even if the waveform extractor is sparse - - Returns - ------- - wfs: np.array - The returned waveform (num_spikes, num_samples, num_channels) - indices: np.array - If "with_index" is True, the spike indices corresponding to the waveforms extracted - """ - assert unit_id in self.sorting.unit_ids, "'unit_id' is invalid" - assert self.has_waveforms(), "Waveforms have been deleted!" - - wfs = self._waveforms.get(unit_id, None) - if wfs is None: - if self.folder is not None: - if self.format == "binary": - waveform_file = self.folder / "waveforms" / f"waveforms_{unit_id}.npy" - if not waveform_file.is_file(): - raise Exception( - "Waveforms not extracted yet: " "please do WaveformExtractor.run_extract_waveforms() first" - ) - if lazy: - wfs = np.load(str(waveform_file), mmap_mode="r") - else: - wfs = np.load(waveform_file) - elif self.format == "zarr": - waveforms_group = self._waveforms_root["waveforms"] - if f"waveforms_{unit_id}" not in waveforms_group.keys(): - raise Exception( - "Waveforms not extracted yet: " "please do WaveformExtractor.run_extract_waveforms() first" - ) - if lazy: - wfs = waveforms_group[f"waveforms_{unit_id}"] - else: - wfs = waveforms_group[f"waveforms_{unit_id}"][:] - if cache: - self._waveforms[unit_id] = wfs - else: - wfs = self._memory_objects["wfs_arrays"][unit_id] - - if sparsity is not None: - assert not self.is_sparse(), "Waveforms are alreayd sparse! Cannot apply an additional sparsity." - wfs = wfs[:, :, sparsity.mask[self.sorting.id_to_index(unit_id)]] - - if force_dense: - num_channels = self.get_num_channels() - dense_wfs = np.zeros((wfs.shape[0], wfs.shape[1], num_channels), dtype=np.float32) - unit_ind = self.sorting.id_to_index(unit_id) - if sparsity is not None: - unit_sparsity = sparsity.mask[unit_ind] - dense_wfs[:, :, unit_sparsity] = wfs - wfs = dense_wfs - elif self.is_sparse(): - unit_sparsity = self.sparsity.mask[unit_ind] - dense_wfs[:, :, unit_sparsity] = wfs - wfs = dense_wfs - - if with_index: - sampled_index = self.get_sampled_indices(unit_id) - return wfs, sampled_index - else: - return wfs - - def get_sampled_indices(self, unit_id): - """ - Return sampled spike indices of extracted waveforms - - Parameters - ---------- - unit_id: int or str - Unit id to retrieve indices for - - Returns - ------- - sampled_indices: np.array - The sampled indices - """ - assert self.has_waveforms(), "Sample indices and waveforms have been deleted!" - if self.folder is not None: - if self.format == "binary": - sampled_index_file = self.folder / "waveforms" / f"sampled_index_{unit_id}.npy" - sampled_index = np.load(sampled_index_file) - elif self.format == "zarr": - waveforms_group = self._waveforms_root["waveforms"] - if f"sampled_index_{unit_id}" not in waveforms_group.keys(): - raise Exception( - "Waveforms not extracted yet: " "please do WaveformExtractor.run_extract_waveforms() first" - ) - sampled_index = waveforms_group[f"sampled_index_{unit_id}"][:] - else: - sampled_index = self._memory_objects["sampled_indices"][unit_id] - return sampled_index - - def get_waveforms_segment(self, segment_index: int, unit_id, sparsity): - """ - Return waveforms from a specified segment and unit_id. - - Parameters - ---------- - segment_index: int - The segment index to retrieve waveforms from - unit_id: int or str - Unit id to retrieve waveforms for - sparsity: ChannelSparsity, default: None - Sparsity to apply to the waveforms (if WaveformExtractor is not sparse) - - Returns - ------- - wfs: np.array - The returned waveform (num_spikes, num_samples, num_channels) - """ - wfs, index_ar = self.get_waveforms(unit_id, with_index=True, sparsity=sparsity) - mask = index_ar["segment_index"] == segment_index - return wfs[mask, :, :] - - def precompute_templates(self, modes=("average", "std", "median", "percentile"), percentile=None) -> None: - """ - Precompute all templates for different "modes": - * average - * std - * median - * percentile - - Parameters - ---------- - modes: list - The modes to compute the templates - percentile: float, default: None - Percentile to use for mode="percentile" - - The results is cached in memory as a 3d ndarray (nunits, nsamples, nchans) - and also saved as an npy file in the folder to avoid recomputation each time. - """ - # TODO : run this in parallel - - unit_ids = self.unit_ids - num_chans = self.get_num_channels() - - mode_names = {} - for mode in modes: - mode_name = mode if mode != "percentile" else f"{mode}_{percentile}" - mode_names[mode] = mode_name - dtype = self._params["dtype"] if mode == "median" else np.float32 - templates = np.zeros((len(unit_ids), self.nsamples, num_chans), dtype=dtype) - self._template_cache[mode_names[mode]] = templates - - for unit_ind, unit_id in enumerate(unit_ids): - wfs = self.get_waveforms(unit_id, cache=False) - if self.sparsity is not None: - mask = self.sparsity.mask[unit_ind] - else: - mask = slice(None) - for mode in modes: - if len(wfs) == 0: - arr = np.zeros(wfs.shape[1:], dtype=wfs.dtype) - elif mode == "median": - arr = np.median(wfs, axis=0) - elif mode == "average": - arr = np.average(wfs, axis=0) - elif mode == "std": - arr = np.std(wfs, axis=0) - elif mode == "percentile": - assert percentile is not None, "percentile must be specified for mode='percentile'" - assert 0 <= percentile <= 100, "percentile must be between 0 and 100 inclusive" - arr = np.percentile(wfs, percentile, axis=0) - else: - raise ValueError(f"'mode' must be in {_possible_template_modes}") - self._template_cache[mode_names[mode]][unit_ind][:, mask] = arr - - for mode in modes: - templates = self._template_cache[mode_names[mode]] - if self.folder is not None and not self.is_read_only(): - template_file = self.folder / f"templates_{mode_names[mode]}.npy" - np.save(template_file, templates) - - def get_all_templates( - self, unit_ids: list | np.array | tuple | None = None, mode="average", percentile: float | None = None - ): - """ - Return templates (average waveforms) for multiple units. - - Parameters - ---------- - unit_ids: list or None - Unit ids to retrieve waveforms for - mode: "average" | "median" | "std" | "percentile", default: "average" - The mode to compute the templates - percentile: float, default: None - Percentile to use for mode="percentile" - - Returns - ------- - templates: np.array - The returned templates (num_units, num_samples, num_channels) - """ - if mode not in self._template_cache: - self.precompute_templates(modes=[mode], percentile=percentile) - mode_name = mode if mode != "percentile" else f"{mode}_{percentile}" - templates = self._template_cache[mode_name] - - if unit_ids is not None: - unit_indices = self.sorting.ids_to_indices(unit_ids) - templates = templates[unit_indices, :, :] - - return np.array(templates) - - def get_template( - self, unit_id, mode="average", sparsity=None, force_dense: bool = False, percentile: float | None = None - ): - """ - Return template (average waveform). - - Parameters - ---------- - unit_id: int or str - Unit id to retrieve waveforms for - mode: "average" | "median" | "std" | "percentile", default: "average" - The mode to compute the template - sparsity: ChannelSparsity, default: None - Sparsity to apply to the waveforms (if WaveformExtractor is not sparse) - force_dense: bool, default: False - Return a dense template even if the waveform extractor is sparse - percentile: float, default: None - Percentile to use for mode="percentile". - Values must be between 0 and 100 inclusive - - Returns - ------- - template: np.array - The returned template (num_samples, num_channels) - """ - assert mode in _possible_template_modes - assert unit_id in self.sorting.unit_ids - - if sparsity is not None: - assert not self.is_sparse(), "Waveforms are already sparse! Cannot apply an additional sparsity." - - unit_ind = self.sorting.id_to_index(unit_id) - - if mode in self._template_cache: - # already in the global cache - templates = self._template_cache[mode] - template = templates[unit_ind, :, :] - if sparsity is not None: - unit_sparsity = sparsity.mask[unit_ind] - elif self.sparsity is not None: - unit_sparsity = self.sparsity.mask[unit_ind] - else: - unit_sparsity = slice(None) - if not force_dense: - template = template[:, unit_sparsity] - return template - - # compute from waveforms - wfs = self.get_waveforms(unit_id, force_dense=force_dense) - if sparsity is not None and not force_dense: - wfs = wfs[:, :, sparsity.mask[unit_ind]] - - if mode == "median": - template = np.median(wfs, axis=0) - elif mode == "average": - template = np.average(wfs, axis=0) - elif mode == "std": - template = np.std(wfs, axis=0) - elif mode == "percentile": - assert percentile is not None, "percentile must be specified for mode='percentile'" - assert 0 <= percentile <= 100, "percentile must be between 0 and 100 inclusive" - template = np.percentile(wfs, percentile, axis=0) - - return np.array(template) - - def get_template_segment(self, unit_id, segment_index, mode="average", sparsity=None): - """ - Return template for the specified unit id computed from waveforms of a specific segment. - - Parameters - ---------- - unit_id: int or str - Unit id to retrieve waveforms for - segment_index: int - The segment index to retrieve template from - mode: "average" | "median" | "std", default: "average" - The mode to compute the template - sparsity: ChannelSparsity, default: None - Sparsity to apply to the waveforms (if WaveformExtractor is not sparse). - - Returns - ------- - template: np.array - The returned template (num_samples, num_channels) - - """ - assert mode in ( - "median", - "average", - "std", - ) - assert unit_id in self.sorting.unit_ids - waveforms_segment = self.get_waveforms_segment(segment_index, unit_id, sparsity=sparsity) - if mode == "median": - return np.median(waveforms_segment, axis=0) - elif mode == "average": - return np.mean(waveforms_segment, axis=0) - elif mode == "std": - return np.std(waveforms_segment, axis=0) - - def sample_spikes(self, seed=None): - nbefore = self.nbefore - nafter = self.nafter - - selected_spikes = select_random_spikes_uniformly( - self.recording, self.sorting, self._params["max_spikes_per_unit"], nbefore, nafter, seed - ) - - # store in a 2 columns (spike_index, segment_index) in a npy file - for unit_id in self.sorting.unit_ids: - n = np.sum([e.size for e in selected_spikes[unit_id]]) - sampled_index = np.zeros(n, dtype=[("spike_index", "int64"), ("segment_index", "int64")]) - pos = 0 - for segment_index in range(self.sorting.get_num_segments()): - inds = selected_spikes[unit_id][segment_index] - sampled_index[pos : pos + inds.size]["spike_index"] = inds - sampled_index[pos : pos + inds.size]["segment_index"] = segment_index - pos += inds.size - - if self.folder is not None: - sampled_index_file = self.folder / "waveforms" / f"sampled_index_{unit_id}.npy" - np.save(sampled_index_file, sampled_index) - else: - self._memory_objects["sampled_indices"][unit_id] = sampled_index - - return selected_spikes - - def run_extract_waveforms(self, seed=None, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - p = self._params - nbefore = self.nbefore - nafter = self.nafter - return_scaled = self.return_scaled - unit_ids = self.sorting.unit_ids - - if self.folder is None: - self._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}} - - selected_spikes = self.sample_spikes(seed=seed) - - selected_spike_times = [] - for segment_index in range(self.sorting.get_num_segments()): - selected_spike_times.append({}) - - for unit_id in self.sorting.unit_ids: - spike_times = self.sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - sel = selected_spikes[unit_id][segment_index] - selected_spike_times[segment_index][unit_id] = spike_times[sel] - - spikes = NumpySorting.from_unit_dict(selected_spike_times, self.sampling_frequency).to_spike_vector() - - if self.folder is not None: - wf_folder = self.folder / "waveforms" - mode = "memmap" - copy = False - else: - wf_folder = None - mode = "shared_memory" - copy = True - - if self.sparsity is None: - sparsity_mask = None - else: - sparsity_mask = self.sparsity.mask - - wfs_arrays = extract_waveforms_to_buffers( - self.recording, - spikes, - unit_ids, - nbefore, - nafter, - mode=mode, - return_scaled=return_scaled, - folder=wf_folder, - dtype=p["dtype"], - sparsity_mask=sparsity_mask, - copy=copy, - **job_kwargs, - ) - if self.folder is None: - self._memory_objects["wfs_arrays"] = wfs_arrays - - -def select_random_spikes_uniformly(recording, sorting, max_spikes_per_unit, nbefore=None, nafter=None, seed=None): - """ - Uniform random selection of spike across segment per units. - - This function does not select spikes near border if nbefore/nafter are not None. - """ - unit_ids = sorting.unit_ids - num_seg = sorting.get_num_segments() - - if seed is not None: - np.random.seed(int(seed)) - - selected_spikes = {} - for unit_id in unit_ids: - # spike per segment - n_per_segment = [sorting.get_unit_spike_train(unit_id, segment_index=i).size for i in range(num_seg)] - cum_sum = [0] + np.cumsum(n_per_segment).tolist() - total = np.sum(n_per_segment) - if max_spikes_per_unit is not None: - if total > max_spikes_per_unit: - global_indices = np.random.choice(total, size=max_spikes_per_unit, replace=False) - global_indices = np.sort(global_indices) - else: - global_indices = np.arange(total) - else: - global_indices = np.arange(total) - sel_spikes = [] - for segment_index in range(num_seg): - in_segment = (global_indices >= cum_sum[segment_index]) & (global_indices < cum_sum[segment_index + 1]) - indices = global_indices[in_segment] - cum_sum[segment_index] - - if max_spikes_per_unit is not None: - # clean border when sub selection - assert nafter is not None - spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - sampled_spike_times = spike_times[indices] - num_samples = recording.get_num_samples(segment_index=segment_index) - mask = (sampled_spike_times >= nbefore) & (sampled_spike_times < (num_samples - nafter)) - indices = indices[mask] - - sel_spikes.append(indices) - selected_spikes[unit_id] = sel_spikes - return selected_spikes - - -def extract_waveforms( - recording, - sorting, - folder=None, - mode="folder", - precompute_template=("average",), - ms_before=1.0, - ms_after=2.0, - max_spikes_per_unit=500, - overwrite=False, - return_scaled=True, - dtype=None, - sparse=True, - sparsity=None, - sparsity_temp_folder=None, - num_spikes_for_sparsity=100, - unit_batch_size=200, - allow_unfiltered=False, - use_relative_path=False, - seed=None, - load_if_exists=None, - **kwargs, -): - """ - Extracts waveform on paired Recording-Sorting objects. - Waveforms can be persistent on disk (`mode`="folder") or in-memory (`mode`="memory"). - By default, waveforms are extracted on a subset of the spikes (`max_spikes_per_unit`) and on all channels (dense). - If the `sparse` parameter is set to True, a sparsity is estimated using a small number of spikes - (`num_spikes_for_sparsity`) and waveforms are extracted and saved in sparse mode. - - - Parameters - ---------- - recording: Recording - The recording object - sorting: Sorting - The sorting object - folder: str or Path or None, default: None - The folder where waveforms are cached - mode: "folder" | "memory, default: "folder" - The mode to store waveforms. If "folder", waveforms are stored on disk in the specified folder. - The "folder" argument must be specified in case of mode "folder". - If "memory" is used, the waveforms are stored in RAM. Use this option carefully! - precompute_template: None or list, default: ["average"] - Precompute average/std/median for template. If None, no templates are precomputed - ms_before: float, default: 1.0 - Time in ms to cut before spike peak - ms_after: float, default: 2.0 - Time in ms to cut after spike peak - max_spikes_per_unit: int or None, default: 500 - Number of spikes per unit to extract waveforms from - Use None to extract waveforms for all spikes - overwrite: bool, default: False - If True and "folder" exists, the folder is removed and waveforms are recomputed - Otherwise an error is raised. - return_scaled: bool, default: True - If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV - dtype: dtype or None, default: None - Dtype of the output waveforms. If None, the recording dtype is maintained - sparse: bool, default: True - If True, before extracting all waveforms the `precompute_sparsity()` function is run using - a few spikes to get an estimate of dense templates to create a ChannelSparsity object. - Then, the waveforms will be sparse at extraction time, which saves a lot of memory. - When True, you must some provide kwargs handle `precompute_sparsity()` to control the kind of - sparsity you want to apply (by radius, by best channels, ...). - sparsity: ChannelSparsity or None, default: None - The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Default None. - sparsity_temp_folder: str or Path or None, default: None - If sparse is True, this is the temporary folder where the dense waveforms are temporarily saved. - If None, dense waveforms are extracted in memory in batches (which can be controlled by the `unit_batch_size` - parameter. With a large number of units (e.g., > 400), it is advisable to use a temporary folder. - num_spikes_for_sparsity: int, default: 100 - The number of spikes to use to estimate sparsity (if sparse=True). - unit_batch_size: int, default: 200 - The number of units to process at once when extracting dense waveforms (if sparse=True and sparsity_temp_folder - is None). - allow_unfiltered: bool - If true, will accept an allow_unfiltered recording. - use_relative_path: bool, default: False - If True, the recording and sorting paths are relative to the waveforms folder. - This allows portability of the waveform folder provided that the relative paths are the same, - but forces all the data files to be in the same drive. - seed: int or None, default: None - Random seed for spike selection - - sparsity kwargs: - {} - - - job kwargs: - {} - - - Returns - ------- - we: WaveformExtractor - The WaveformExtractor object - - Examples - -------- - >>> import spikeinterface as si - - >>> # Extract dense waveforms and save to disk - >>> we = si.extract_waveforms(recording, sorting, folder="waveforms") - - >>> # Extract dense waveforms with parallel processing and save to disk - >>> job_kwargs = dict(n_jobs=8, chunk_duration="1s", progress_bar=True) - >>> we = si.extract_waveforms(recording, sorting, folder="waveforms", **job_kwargs) - - >>> # Extract dense waveforms on all spikes - >>> we = si.extract_waveforms(recording, sorting, folder="waveforms-all", max_spikes_per_unit=None) - - >>> # Extract dense waveforms in memory - >>> we = si.extract_waveforms(recording, sorting, folder=None, mode="memory") - - >>> # Extract sparse waveforms (with radius-based sparsity of 50um) and save to disk - >>> we = si.extract_waveforms(recording, sorting, folder="waveforms-sparse", mode="folder", - >>> sparse=True, num_spikes_for_sparsity=100, method="radius", radius_um=50) - """ - if load_if_exists is None: - load_if_exists = False - else: - warn("load_if_exists=True/false is deprcated. Use load_waveforms() instead.", DeprecationWarning, stacklevel=2) - - estimate_kwargs, job_kwargs = split_job_kwargs(kwargs) - - assert ( - recording.has_channel_location() - ), "Recording must have a probe or channel location to extract waveforms. Use the `set_probe()` or `set_dummy_probe_from_locations()` methods." - - if mode == "folder": - assert folder is not None - folder = Path(folder) - assert not (overwrite and load_if_exists), "Use either 'overwrite=True' or 'load_if_exists=True'" - if overwrite and folder.is_dir(): - shutil.rmtree(folder) - if load_if_exists and folder.is_dir(): - we = WaveformExtractor.load_from_folder(folder) - return we - - if sparsity is not None: - assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object" - unit_id_to_channel_ids = sparsity.unit_id_to_channel_ids - assert all(u in sorting.unit_ids for u in unit_id_to_channel_ids), "Invalid unit ids in sparsity" - for channels in unit_id_to_channel_ids.values(): - assert all(ch in recording.channel_ids for ch in channels), "Invalid channel ids in sparsity" - elif sparse: - sparsity = precompute_sparsity( - recording, - sorting, - ms_before=ms_before, - ms_after=ms_after, - num_spikes_for_sparsity=num_spikes_for_sparsity, - unit_batch_size=unit_batch_size, - temp_folder=sparsity_temp_folder, - allow_unfiltered=allow_unfiltered, - **estimate_kwargs, - **job_kwargs, - ) - else: - sparsity = None - - we = WaveformExtractor.create( - recording, - sorting, - folder, - mode=mode, - use_relative_path=use_relative_path, - allow_unfiltered=allow_unfiltered, - sparsity=sparsity, - ) - we.set_params( - ms_before=ms_before, - ms_after=ms_after, - max_spikes_per_unit=max_spikes_per_unit, - dtype=dtype, - return_scaled=return_scaled, - ) - we.run_extract_waveforms(seed=seed, **job_kwargs) - - if precompute_template is not None: - we.precompute_templates(modes=precompute_template) - - return we - - -extract_waveforms.__doc__ = extract_waveforms.__doc__.format(_sparsity_doc, _shared_job_kwargs_doc) - - -def load_waveforms(folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None) -> WaveformExtractor: - """ - Load a waveform extractor object from disk. - - Parameters - ---------- - folder : str or Path - The folder / zarr folder where the waveform extractor is stored - with_recording : bool, default: True - If True, the recording is loaded. - If False, the WaveformExtractor object in recordingless mode. - sorting : BaseSorting, default: None - If passed, the sorting object associated to the waveform extractor - - Returns - ------- - we: WaveformExtractor - The loaded waveform extractor - """ - return WaveformExtractor.load(folder, with_recording, sorting) - - -def precompute_sparsity( - recording, - sorting, - num_spikes_for_sparsity=100, - unit_batch_size=200, - ms_before=2.0, - ms_after=3.0, - temp_folder=None, - allow_unfiltered=False, - **kwargs, -): - """ - Pre-estimate sparsity with few spikes and by unit batch. - This equivalent to compute a dense waveform extractor (with all units at once) and so - can be less memory agressive. - - Parameters - ---------- - recording: Recording - The recording object - sorting: Sorting - The sorting object - num_spikes_for_sparsity: int, default: 100 - How many spikes per unit - unit_batch_size: int or None, default: 200 - How many units are extracted at once to estimate sparsity. - If None then they are extracted all at one (but uses a lot of memory) - ms_before: float, default: 2.0 - Time in ms to cut before spike peak - ms_after: float, default: 3.0 - Time in ms to cut after spike peak - temp_folder: str or Path or None, default: None - If provided, dense waveforms are saved to this temporary folder - allow_unfiltered: bool, default: False - If true, will accept an allow_unfiltered recording. - - kwargs for sparsity strategy: - {} - - - job kwargs: - {} - - Returns - ------- - sparsity : ChannelSparsity - The estimated sparsity. - """ - - sparse_kwargs, job_kwargs = split_job_kwargs(kwargs) - - unit_ids = sorting.unit_ids - channel_ids = recording.channel_ids - - if unit_batch_size is None: - unit_batch_size = len(unit_ids) - - if temp_folder is None: - mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") - nloop = int(np.ceil((unit_ids.size / unit_batch_size))) - for i in range(nloop): - sl = slice(i * unit_batch_size, (i + 1) * unit_batch_size) - local_ids = unit_ids[sl] - local_sorting = sorting.select_units(local_ids) - local_we = extract_waveforms( - recording, - local_sorting, - folder=None, - mode="memory", - precompute_template=("average",), - ms_before=ms_before, - ms_after=ms_after, - max_spikes_per_unit=num_spikes_for_sparsity, - return_scaled=False, - allow_unfiltered=allow_unfiltered, - sparse=False, - **job_kwargs, - ) - local_sparsity = compute_sparsity(local_we, **sparse_kwargs) - mask[sl, :] = local_sparsity.mask - else: - temp_folder = Path(temp_folder) - assert ( - not temp_folder.is_dir() - ), "Temporary folder for pre-computing sparsity already exists. Provide a non-existing folder" - dense_we = extract_waveforms( - recording, - sorting, - folder=temp_folder, - precompute_template=("average",), - ms_before=ms_before, - ms_after=ms_after, - max_spikes_per_unit=num_spikes_for_sparsity, - return_scaled=False, - allow_unfiltered=allow_unfiltered, - sparse=False, - **job_kwargs, - ) - sparsity = compute_sparsity(dense_we, **sparse_kwargs) - mask = sparsity.mask - shutil.rmtree(temp_folder) - - sparsity = ChannelSparsity(mask, unit_ids, channel_ids) - return sparsity - - -precompute_sparsity.__doc__ = precompute_sparsity.__doc__.format(_sparsity_doc, _shared_job_kwargs_doc) - - -class BaseWaveformExtractorExtension: - """ - This the base class to extend the waveform extractor. - It handles persistency to disk any computations related - to a waveform extractor. - - For instance: - * principal components - * spike amplitudes - * quality metrics - - The design is done via a `WaveformExtractor.register_extension(my_extension_class)`, - so that only imported modules can be used as *extension*. - - It also enables any custum computation on top on waveform extractor to be implemented by the user. - - An extension needs to inherit from this class and implement some abstract methods: - * _reset - * _set_params - * _run - - The subclass must also save to the `self.extension_folder` any file that needs - to be reloaded when calling `_load_extension_data` - - The subclass must also set an `extension_name` attribute which is not None by default. - """ - - # must be set in inherited in subclass - extension_name = None - handle_sparsity = False - - def __init__(self, waveform_extractor): - self._waveform_extractor = weakref.ref(waveform_extractor) - - if self.waveform_extractor.folder is not None: - self.folder = self.waveform_extractor.folder - self.format = self.waveform_extractor.format - if self.format == "binary": - self.extension_folder = self.folder / self.extension_name - if not self.extension_folder.is_dir(): - if self.waveform_extractor.is_read_only(): - warn( - "WaveformExtractor: cannot save extension in read-only mode. " - "Extension will be saved in memory." - ) - self.format = "memory" - self.extension_folder = None - self.folder = None - else: - self.extension_folder.mkdir() - - else: - import zarr - - mode = "r+" if not self.waveform_extractor.is_read_only() else "r" - zarr_root = zarr.open(self.folder, mode=mode) - if self.extension_name not in zarr_root.keys(): - if self.waveform_extractor.is_read_only(): - warn( - "WaveformExtractor: cannot save extension in read-only mode. " - "Extension will be saved in memory." - ) - self.format = "memory" - self.extension_folder = None - self.folder = None - else: - self.extension_group = zarr_root.create_group(self.extension_name) - else: - self.extension_group = zarr_root[self.extension_name] - else: - self.format = "memory" - self.extension_folder = None - self.folder = None - self._extension_data = dict() - self._params = None - - # register - self.waveform_extractor._loaded_extensions[self.extension_name] = self - - @property - def waveform_extractor(self): - # Important : to avoid the WaveformExtractor referencing a BaseWaveformExtractorExtension - # and BaseWaveformExtractorExtension referencing a WaveformExtractor - # we need a weakref. Otherwise the garbage collector is not working properly - # and so the WaveformExtractor + its recording are still alive even after deleting explicitly - # the WaveformExtractor which makes it impossible to delete the folder! - we = self._waveform_extractor() - if we is None: - raise ValueError(f"The extension {self.extension_name} has lost its WaveformExtractor") - return we - - @classmethod - def load(cls, folder, waveform_extractor): - folder = Path(folder) - assert folder.is_dir(), "Waveform folder does not exists" - if folder.suffix == ".zarr": - params = cls.load_params_from_zarr(folder) - else: - params = cls.load_params_from_folder(folder) - - if "sparsity" in params and params["sparsity"] is not None: - params["sparsity"] = ChannelSparsity.from_dict(params["sparsity"]) - - # if waveform_extractor is None: - # waveform_extractor = WaveformExtractor.load(folder) - - # make instance with params - ext = cls(waveform_extractor) - ext._params = params - ext._load_extension_data() - - return ext - - @classmethod - def load_params_from_zarr(cls, folder): - """ - Load extension params from Zarr folder. - 'folder' is the waveform extractor zarr folder. - """ - import zarr - - zarr_root = zarr.open(folder, mode="r+") - assert cls.extension_name in zarr_root.keys(), ( - f"WaveformExtractor: extension {cls.extension_name} " f"is not in folder {folder}" - ) - extension_group = zarr_root[cls.extension_name] - assert "params" in extension_group.attrs, f"No params file in extension {cls.extension_name} folder" - params = extension_group.attrs["params"] - - return params - - @classmethod - def load_params_from_folder(cls, folder): - """ - Load extension params from folder. - 'folder' is the waveform extractor folder. - """ - ext_folder = Path(folder) / cls.extension_name - assert ext_folder.is_dir(), f"WaveformExtractor: extension {cls.extension_name} is not in folder {folder}" - - params_file = ext_folder / "params.json" - assert params_file.is_file(), f"No params file in extension {cls.extension_name} folder" - - with open(str(params_file), "r") as f: - params = json.load(f) - - return params - - # use load instead - def _load_extension_data(self): - if self.format == "binary": - for ext_data_file in self.extension_folder.iterdir(): - if ext_data_file.name == "params.json": - continue - ext_data_name = ext_data_file.stem - if ext_data_file.suffix == ".json": - ext_data = json.load(ext_data_file.open("r")) - elif ext_data_file.suffix == ".npy": - # The lazy loading of an extension is complicated because if we compute again - # and have a link to the old buffer on windows then it fails - # ext_data = np.load(ext_data_file, mmap_mode="r") - # so we go back to full loading - ext_data = np.load(ext_data_file) - elif ext_data_file.suffix == ".csv": - import pandas as pd - - ext_data = pd.read_csv(ext_data_file, index_col=0) - elif ext_data_file.suffix == ".pkl": - ext_data = pickle.load(ext_data_file.open("rb")) - else: - continue - self._extension_data[ext_data_name] = ext_data - elif self.format == "zarr": - for ext_data_name in self.extension_group.keys(): - ext_data_ = self.extension_group[ext_data_name] - if "dict" in ext_data_.attrs: - ext_data = ext_data_[0] - elif "dataframe" in ext_data_.attrs: - import xarray - - ext_data = xarray.open_zarr( - ext_data_.store, group=f"{self.extension_group.name}/{ext_data_name}" - ).to_pandas() - ext_data.index.rename("", inplace=True) - else: - ext_data = ext_data_ - self._extension_data[ext_data_name] = ext_data - - def run(self, **kwargs): - self._run(**kwargs) - self._save(**kwargs) - - def _run(self, **kwargs): - # must be implemented in subclass - # must populate the self._extension_data dictionary - raise NotImplementedError - - def save(self, **kwargs): - self._save(**kwargs) - - def _save(self, **kwargs): - # Only save if not read only - if self.waveform_extractor.is_read_only(): - return - - # delete already saved - self._reset_folder() - self._save_params() - - if self.format == "binary": - import pandas as pd - - for ext_data_name, ext_data in self._extension_data.items(): - if isinstance(ext_data, dict): - with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: - json.dump(ext_data, f) - elif isinstance(ext_data, np.ndarray): - np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) - else: - try: - with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: - pickle.dump(ext_data, f) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - elif self.format == "zarr": - from .zarrextractors import get_default_zarr_compressor - import pandas as pd - import numcodecs - - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() - for ext_data_name, ext_data in self._extension_data.items(): - if ext_data_name in self.extension_group: - del self.extension_group[ext_data_name] - if isinstance(ext_data, dict): - self.extension_group.create_dataset( - name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() - ) - self.extension_group[ext_data_name].attrs["dict"] = True - elif isinstance(ext_data, np.ndarray): - self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=self.extension_group.store, - group=f"{self.extension_group.name}/{ext_data_name}", - mode="a", - ) - self.extension_group[ext_data_name].attrs["dataframe"] = True - else: - try: - self.extension_group.create_dataset( - name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() - ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - - def _reset_folder(self): - """ - Delete the extension in folder (binary or zarr) and create an empty one. - """ - if self.format == "binary" and self.extension_folder is not None: - if self.extension_folder.is_dir(): - shutil.rmtree(self.extension_folder) - self.extension_folder.mkdir() - elif self.format == "zarr": - import zarr - - zarr_root = zarr.open(self.folder, mode="r+") - self.extension_group = zarr_root.create_group(self.extension_name, overwrite=True) - - def reset(self): - """ - Reset the waveform extension. - Delete the sub folder and create a new empty one. - """ - self._reset_folder() - - self._params = None - self._extension_data = dict() - - def select_units(self, unit_ids, new_waveform_extractor): - new_extension = self.__class__(new_waveform_extractor) - new_extension.set_params(**self._params) - new_extension_data = self._select_extension_data(unit_ids=unit_ids) - new_extension._extension_data = new_extension_data - new_extension._save() - - def copy(self, new_waveform_extractor): - new_extension = self.__class__(new_waveform_extractor) - new_extension.set_params(**self._params) - new_extension._extension_data = self._extension_data - new_extension._save() - - def _select_extension_data(self, unit_ids): - # must be implemented in subclass - raise NotImplementedError - - def set_params(self, **params): - """ - Set parameters for the extension and - make it persistent in json. - """ - params = self._set_params(**params) - self._params = params - - if self.waveform_extractor.is_read_only(): - return - - self._save_params() - - def _save_params(self): - params_to_save = self._params.copy() - if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: - assert isinstance( - params_to_save["sparsity"], ChannelSparsity - ), "'sparsity' parameter must be a ChannelSparsity object!" - params_to_save["sparsity"] = params_to_save["sparsity"].to_dict() - if self.format == "binary": - if self.extension_folder is not None: - param_file = self.extension_folder / "params.json" - param_file.write_text(json.dumps(check_json(params_to_save), indent=4), encoding="utf8") - elif self.format == "zarr": - self.extension_group.attrs["params"] = check_json(params_to_save) - - def _set_params(self, **params): - # must be implemented in subclass - # must return a cleaned version of params dict - raise NotImplementedError - - @staticmethod - def get_extension_function(): - # must be implemented in subclass - # must return extension function - raise NotImplementedError From cc7ed9d5181daa26acc7271899f25791c4829a69 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 15 Feb 2024 17:45:28 +0100 Subject: [PATCH 096/192] minor fixes --- src/spikeinterface/core/sparsity.py | 3 --- .../sortingcomponents/tests/test_template_matching.py | 10 +++++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index cbbedfec6c..06e900a3c5 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -159,9 +159,6 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nd or a single sparsified waveform (template) with shape (num_samples, num_active_channels). """ - if self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): - return waveforms - non_zero_indices = self.unit_id_to_channel_indices[unit_id] sparsified_waveforms = waveforms[..., non_zero_indices] diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index e9d1017be7..6bb59c6c4e 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -35,7 +35,7 @@ def test_find_spikes_from_templates(method, sorting_result): # assert num_waveforms != 0 templates = sorting_result.get_extension("fast_templates").get_data(outputs="Templates") - sparsity = compute_sparsity(sorting_result, method="snr", threshold=2) + sparsity = compute_sparsity(sorting_result, method="snr", threshold=0.5) templates = templates.to_sparse(sparsity) noise_levels = sorting_result.get_extension("noise_levels").get_data() @@ -59,7 +59,7 @@ def test_find_spikes_from_templates(method, sorting_result): - # DEBUG = False + # DEBUG = True # if DEBUG: # import matplotlib.pyplot as plt @@ -85,9 +85,9 @@ def test_find_spikes_from_templates(method, sorting_result): if __name__ == "__main__": sorting_result = get_sorting_result() # method = "naive" - # method = "tdc-peeler" + # method = "tdc-peeler" # method = "circus" - method = "circus-omp-svd" - # method = "wobble" + # method = "circus-omp-svd" + method = "wobble" test_find_spikes_from_templates(method, sorting_result) From a42fa70aa86eee83a1495f78eec9e01096302ef5 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 16 Feb 2024 13:44:22 +0100 Subject: [PATCH 097/192] Some clean --- src/spikeinterface/core/sortingresult.py | 86 +++++++++---------- .../core/tests/test_sortingresult.py | 80 ++++++++--------- .../tests/common_extension_tests.py | 4 +- .../postprocessing/unit_localization.py | 2 +- 4 files changed, 86 insertions(+), 86 deletions(-) diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortingresult.py index 9d3b1a0aaa..4de643a0db 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortingresult.py @@ -76,21 +76,21 @@ def start_sorting_result( >>> import spikeinterface as si >>> # Extract dense waveforms and save to disk with binary_folder format. - >>> sortres = si.start_sorting_result(sorting, recording, format="binary_folder", folder="/path/to_my/result") + >>> sorting_result = si.start_sorting_result(sorting, recording, format="binary_folder", folder="/path/to_my/result") >>> # Can be reload - >>> sortres = si.load_sorting_result(folder="/path/to_my/result") + >>> sorting_result = si.load_sorting_result(folder="/path/to_my/result") >>> # Can run extension - >>> sortres = si.compute("unit_locations", ...) + >>> sorting_result = si.compute("unit_locations", ...) >>> # Can be copy to another format (extensions are propagated) - >>> sortres2 = sortres.save_as(format="memory") - >>> sortres3 = sortres.save_as(format="zarr", folder="/path/to_my/result.zarr") + >>> sorting_result2 = sorting_result.save_as(format="memory") + >>> sorting_result3 = sorting_result.save_as(format="zarr", folder="/path/to_my/result.zarr") >>> # Can make a copy with a subset of units (extensions are propagated for the unit subset) - >>> sortres4 = sortres.select_units(unit_ids=sorting.units_ids[:5], format="memory") - >>> sortres5 = sortres.select_units(unit_ids=sorting.units_ids[:5], format="binary_folder", folder="/result_5units") + >>> sorting_result4 = sorting_result.select_units(unit_ids=sorting.units_ids[:5], format="memory") + >>> sorting_result5 = sorting_result.select_units(unit_ids=sorting.units_ids[:5], format="binary_folder", folder="/result_5units") """ # handle sparsity @@ -209,19 +209,19 @@ def create( check_probe_do_not_overlap(all_probes) if format == "memory": - sortres = cls.create_memory(sorting, recording, sparsity, rec_attributes=None) + sorting_result = cls.create_memory(sorting, recording, sparsity, rec_attributes=None) elif format == "binary_folder": cls.create_binary_folder(folder, sorting, recording, sparsity, rec_attributes=None) - sortres = cls.load_from_binary_folder(folder, recording=recording) - sortres.folder = folder + sorting_result = cls.load_from_binary_folder(folder, recording=recording) + sorting_result.folder = folder elif format == "zarr": cls.create_zarr(folder, sorting, recording, sparsity, rec_attributes=None) - sortres = cls.load_from_zarr(folder, recording=recording) - sortres.folder = folder + sorting_result = cls.load_from_zarr(folder, recording=recording) + sorting_result.folder = folder else: raise ValueError("SortingResult.create: wrong format") - return sortres + return sorting_result @classmethod def load(cls, folder, recording=None, load_extensions=True, format="auto"): @@ -240,16 +240,16 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): format = "binary_folder" if format == "binary_folder": - sortres = SortingResult.load_from_binary_folder(folder, recording=recording) + sorting_result = SortingResult.load_from_binary_folder(folder, recording=recording) elif format == "zarr": - sortres = SortingResult.load_from_zarr(folder, recording=recording) + sorting_result = SortingResult.load_from_zarr(folder, recording=recording) - sortres.folder = folder + sorting_result.folder = folder if load_extensions: - sortres.load_all_saved_extension() + sorting_result.load_all_saved_extension() - return sortres + return sorting_result @classmethod def create_memory(cls, sorting, recording, sparsity, rec_attributes): @@ -265,10 +265,10 @@ def create_memory(cls, sorting, recording, sparsity, rec_attributes): # a copy of sorting is created directly in shared memory format to avoid further duplication of spikes. sorting_copy = SharedMemorySorting.from_sorting(sorting, with_metadata=True) - sortres = SortingResult( + sorting_result = SortingResult( sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, format="memory", sparsity=sparsity ) - return sortres + return sorting_result @classmethod def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attributes): @@ -382,7 +382,7 @@ def load_from_binary_folder(cls, folder, recording=None): else: random_spikes_indices = None - sortres = SortingResult( + sorting_result = SortingResult( sorting=sorting, recording=recording, rec_attributes=rec_attributes, @@ -391,7 +391,7 @@ def load_from_binary_folder(cls, folder, recording=None): random_spikes_indices=random_spikes_indices, ) - return sortres + return sorting_result def _get_zarr_root(self, mode="r+"): import zarr @@ -528,7 +528,7 @@ def load_from_zarr(cls, folder, recording=None): else: random_spikes_indices = None - sortres = SortingResult( + sorting_result = SortingResult( sorting=sorting, recording=recording, rec_attributes=rec_attributes, @@ -537,7 +537,7 @@ def load_from_zarr(cls, folder, recording=None): random_spikes_indices=random_spikes_indices, ) - return sortres + return sorting_result def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingResult": """ @@ -570,27 +570,27 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> if format == "memory": # This make a copy of actual SortingResult - new_sortres = SortingResult.create_memory(sorting_provenance, recording, sparsity, self.rec_attributes) + new_sorting_result = SortingResult.create_memory(sorting_provenance, recording, sparsity, self.rec_attributes) elif format == "binary_folder": # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" SortingResult.create_binary_folder(folder, sorting_provenance, recording, sparsity, self.rec_attributes) - new_sortres = SortingResult.load_from_binary_folder(folder, recording=recording) - new_sortres.folder = folder + new_sorting_result = SortingResult.load_from_binary_folder(folder, recording=recording) + new_sorting_result.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" SortingResult.create_zarr(folder, sorting_provenance, recording, sparsity, self.rec_attributes) - new_sortres = SortingResult.load_from_zarr(folder, recording=recording) - new_sortres.folder = folder + new_sorting_result = SortingResult.load_from_zarr(folder, recording=recording) + new_sorting_result.folder = folder else: raise ValueError("SortingResult.save: wrong format") # propagate random_spikes_indices is already done if self.random_spikes_indices is not None: if unit_ids is None: - new_sortres.random_spikes_indices = self.random_spikes_indices.copy() + new_sorting_result.random_spikes_indices = self.random_spikes_indices.copy() else: # more tricky spikes = self.sorting.to_spike_vector() @@ -601,17 +601,17 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> selected_mask = np.zeros(spikes.size, dtype=bool) selected_mask[self.random_spikes_indices] = True - new_sortres.random_spikes_indices = np.flatnonzero(selected_mask[keep_spike_mask]) + new_sorting_result.random_spikes_indices = np.flatnonzero(selected_mask[keep_spike_mask]) # save it - new_sortres._save_random_spikes_indices() + new_sorting_result._save_random_spikes_indices() # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing for extension_name, extension in self.extensions.items(): - new_ext = new_sortres.extensions[extension_name] = extension.copy(new_sortres, unit_ids=unit_ids) + new_ext = new_sorting_result.extensions[extension_name] = extension.copy(new_sorting_result, unit_ids=unit_ids) - return new_sortres + return new_sorting_result def save_as(self, format="memory", folder=None) -> "SortingResult": """ @@ -811,15 +811,18 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): Returns ------- - sorting_result: SortingResult - The SortingResult object + result_extension: ResultExtension + Return the extension instance. Examples -------- - >>> extension = sortres.compute("waveforms", **some_params) - >>> extension = sortres.compute_one_extension("waveforms", **some_params) + >>> Note that the return is the instance extension. + >>> extension = sorting_result.compute("waveforms", **some_params) + >>> extension = sorting_result.compute_one_extension("waveforms", **some_params) >>> wfs = extension.data["waveforms"] + >>> # Note this can be be done in the old way style BUT the return is not the same it return directly data + >>> wfs = compute_waveforms(sorting_result, **some_params) """ @@ -848,10 +851,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): self.extensions[extension_name] = extension_instance - # TODO : need discussion return extension_instance - # OR - return extension_instance.data def compute_several_extensions(self, extensions, save=True, **job_kwargs): """ @@ -873,8 +873,8 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): Examples -------- - >>> sortres.compute({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std", ]} }) - >>> sortres.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}}) + >>> sorting_result.compute({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std", ]} }) + >>> sorting_result.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}}) """ # TODO this is a simple implementation diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortingresult.py index a3c204364d..e6d7872396 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortingresult.py @@ -30,11 +30,11 @@ def get_dataset(): def test_SortingResult_memory(): recording, sorting = get_dataset() - sortres = start_sorting_result(sorting, recording, format="memory", sparse=False, sparsity=None) - _check_sorting_results(sortres, sorting) + sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_results(sorting_result, sorting) - sortres = start_sorting_result(sorting, recording, format="memory", sparse=True, sparsity=None) - _check_sorting_results(sortres, sorting) + sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True, sparsity=None) + _check_sorting_results(sorting_result, sorting) def test_SortingResult_binary_folder(): @@ -44,11 +44,11 @@ def test_SortingResult_binary_folder(): if folder.exists(): shutil.rmtree(folder) - sortres = start_sorting_result( + sorting_result = start_sorting_result( sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None ) - sortres = load_sorting_result(folder, format="auto") - _check_sorting_results(sortres, sorting) + sorting_result = load_sorting_result(folder, format="auto") + _check_sorting_results(sorting_result, sorting) def test_SortingResult_zarr(): @@ -58,46 +58,46 @@ def test_SortingResult_zarr(): if folder.exists(): shutil.rmtree(folder) - sortres = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) - sortres = load_sorting_result(folder, format="auto") - _check_sorting_results(sortres, sorting) + sorting_result = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) + sorting_result = load_sorting_result(folder, format="auto") + _check_sorting_results(sorting_result, sorting) -def _check_sorting_results(sortres, original_sorting): +def _check_sorting_results(sorting_result, original_sorting): print() - print(sortres) + print(sorting_result) register_result_extension(DummyResultExtension) - assert "channel_ids" in sortres.rec_attributes - assert "sampling_frequency" in sortres.rec_attributes - assert "num_samples" in sortres.rec_attributes + assert "channel_ids" in sorting_result.rec_attributes + assert "sampling_frequency" in sorting_result.rec_attributes + assert "num_samples" in sorting_result.rec_attributes - probe = sortres.get_probe() - sparsity = sortres.sparsity + probe = sorting_result.get_probe() + sparsity = sorting_result.sparsity # compute - sortres.compute("dummy", param1=5.5) + sorting_result.compute("dummy", param1=5.5) # equivalent - compute_dummy(sortres, param1=5.5) - ext = sortres.get_extension("dummy") + compute_dummy(sorting_result, param1=5.5) + ext = sorting_result.get_extension("dummy") assert ext is not None assert ext.params["param1"] == 5.5 - print(sortres) + print(sorting_result) # recompute - sortres.compute("dummy", param1=5.5) + sorting_result.compute("dummy", param1=5.5) # and delete - sortres.delete_extension("dummy") - ext = sortres.get_extension("dummy") + sorting_result.delete_extension("dummy") + ext = sorting_result.get_extension("dummy") assert ext is None - assert sortres.has_recording() + assert sorting_result.has_recording() - if sortres.random_spikes_indices is None: - sortres.select_random_spikes(max_spikes_per_unit=10, seed=2205) - assert sortres.random_spikes_indices is not None - assert sortres.random_spikes_indices.size == 10 * sortres.sorting.unit_ids.size + if sorting_result.random_spikes_indices is None: + sorting_result.select_random_spikes(max_spikes_per_unit=10, seed=2205) + assert sorting_result.random_spikes_indices is not None + assert sorting_result.random_spikes_indices.size == 10 * sorting_result.sorting.unit_ids.size # save to several format for format in ("memory", "binary_folder", "zarr"): @@ -112,13 +112,13 @@ def _check_sorting_results(sortres, original_sorting): folder = None # compute one extension to check the save - sortres.compute("dummy") + sorting_result.compute("dummy") - sortres2 = sortres.save_as(format=format, folder=folder) - ext = sortres2.get_extension("dummy") + sorting_result2 = sorting_result.save_as(format=format, folder=folder) + ext = sorting_result2.get_extension("dummy") assert ext is not None - data = sortres2.get_extension("dummy").data + data = sorting_result2.get_extension("dummy").data assert "result_one" in data assert data["result_two"].size == original_sorting.to_spike_vector().size @@ -134,19 +134,19 @@ def _check_sorting_results(sortres, original_sorting): else: folder = None # compute one extension to check the slice - sortres.compute("dummy") + sorting_result.compute("dummy") keep_unit_ids = original_sorting.unit_ids[::2] - sortres2 = sortres.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) + sorting_result2 = sorting_result.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) # check that random_spikes_indices are remmaped - assert sortres2.random_spikes_indices is not None - some_spikes = sortres2.sorting.to_spike_vector()[sortres2.random_spikes_indices] + assert sorting_result2.random_spikes_indices is not None + some_spikes = sorting_result2.sorting.to_spike_vector()[sorting_result2.random_spikes_indices] assert np.array_equal(np.unique(some_spikes["unit_index"]), np.arange(keep_unit_ids.size)) # check propagation of result data and correct sligin - assert np.array_equal(keep_unit_ids, sortres2.unit_ids) - data = sortres2.get_extension("dummy").data - assert data["result_one"] == sortres.get_extension("dummy").data["result_one"] + assert np.array_equal(keep_unit_ids, sorting_result2.unit_ids) + data = sorting_result2.get_extension("dummy").data + assert data["result_one"] == sorting_result.get_extension("dummy").data["result_one"] # unit 1, 3, ... should be removed assert np.all(~np.isin(data["result_two"], [1, 3])) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 214d0c3f16..3520231a9d 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -54,9 +54,9 @@ def get_sorting_result(recording, sorting, format="memory", sparsity=None, name= if folder and folder.exists(): shutil.rmtree(folder) - sortres = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity) + sorting_result = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity) - return sortres + return sorting_result class ResultExtensionCommonTestSuite: diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 6358b9fd31..be9d1fad82 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -59,7 +59,7 @@ class ComputeUnitLocations(ResultExtension): def __init__(self, sorting_result): ResultExtension.__init__(self, sorting_result) - def _set_params(self, method="center_of_mass", **method_kwargs): + def _set_params(self, method="monopolar_triangulation", **method_kwargs): params = dict(method=method, method_kwargs=method_kwargs) return params From 650f37b79980d799a839caaef7dcefe9e4345c93 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 16 Feb 2024 14:16:55 +0100 Subject: [PATCH 098/192] And finally after long debate : SortingResult is becoming SortingAnalyzer. Sorry for you Alessio! --- .../comparison/groundtruthstudy.py | 18 +- src/spikeinterface/comparison/hybrid.py | 2 +- .../comparison/paircomparisons.py | 36 +-- .../comparison/tests/test_groundtruthstudy.py | 2 +- .../comparison/tests/test_hybrid.py | 2 +- .../tests/test_templatecomparison.py | 18 +- src/spikeinterface/core/__init__.py | 6 +- ...ult_core.py => analyzer_extension_core.py} | 96 +++--- .../{sortingresult.py => sortinganalyzer.py} | 238 +++++++-------- src/spikeinterface/core/sparsity.py | 150 +++++----- src/spikeinterface/core/template_tools.py | 68 ++--- ...ore.py => test_analyzer_extension_core.py} | 82 ++--- .../core/tests/test_node_pipeline.py | 10 +- ...rtingresult.py => test_sortinganalyzer.py} | 114 +++---- .../core/tests/test_sparsity.py | 30 +- .../core/tests/test_template_tools.py | 62 ++-- ...forms_extractor_backwards_compatibility.py | 14 +- .../core/tests/test_zarrextractors.py | 2 +- src/spikeinterface/core/waveform_tools.py | 2 +- ...forms_extractor_backwards_compatibility.py | 144 ++++----- src/spikeinterface/core/zarrextractors.py | 2 +- src/spikeinterface/curation/auto_merge.py | 34 +-- .../curation/remove_redundant.py | 30 +- src/spikeinterface/curation/tests/common.py | 30 +- .../curation/tests/test_auto_merge.py | 24 +- .../curation/tests/test_remove_redundant.py | 24 +- .../tests/test_sortingview_curation.py | 18 +- src/spikeinterface/exporters/report.py | 48 +-- src/spikeinterface/exporters/tests/common.py | 44 +-- .../exporters/tests/test_export_to_phy.py | 44 +-- .../exporters/tests/test_report.py | 12 +- src/spikeinterface/exporters/to_phy.py | 84 +++--- .../postprocessing/amplitude_scalings.py | 56 ++-- .../postprocessing/correlograms.py | 28 +- src/spikeinterface/postprocessing/isi.py | 14 +- .../postprocessing/noise_level.py | 2 +- .../postprocessing/principal_component.py | 92 +++--- .../postprocessing/spike_amplitudes.py | 30 +- .../postprocessing/spike_locations.py | 30 +- .../postprocessing/template_metrics.py | 28 +- .../postprocessing/template_similarity.py | 22 +- .../tests/common_extension_tests.py | 36 +-- .../tests/test_amplitude_scalings.py | 10 +- .../tests/test_principal_component.py | 46 +-- .../tests/test_spike_amplitudes.py | 6 +- .../tests/test_template_similarity.py | 16 +- .../postprocessing/unit_localization.py | 72 ++--- .../qualitymetrics/misc_metrics.py | 280 +++++++++--------- .../qualitymetrics/pca_metrics.py | 90 +++--- .../quality_metric_calculator.py | 18 +- .../tests/test_metrics_functions.py | 184 ++++++------ .../qualitymetrics/tests/test_pca_metrics.py | 56 ++-- .../tests/test_quality_metric_calculator.py | 88 +++--- .../tests/test_template_matching.py | 44 +-- .../waveforms/temporal_pca.py | 14 +- .../widgets/all_amplitudes_distributions.py | 24 +- src/spikeinterface/widgets/amplitudes.py | 20 +- src/spikeinterface/widgets/base.py | 22 +- .../widgets/crosscorrelograms.py | 26 +- src/spikeinterface/widgets/quality_metrics.py | 14 +- src/spikeinterface/widgets/sorting_summary.py | 32 +- src/spikeinterface/widgets/spike_locations.py | 20 +- .../widgets/spikes_on_traces.py | 40 +-- .../widgets/template_metrics.py | 14 +- .../widgets/template_similarity.py | 14 +- .../widgets/tests/test_widgets.py | 126 ++++---- src/spikeinterface/widgets/unit_depths.py | 20 +- src/spikeinterface/widgets/unit_locations.py | 22 +- src/spikeinterface/widgets/unit_probe_map.py | 22 +- src/spikeinterface/widgets/unit_summary.py | 38 +-- src/spikeinterface/widgets/unit_waveforms.py | 66 ++--- .../widgets/unit_waveforms_density_map.py | 38 +-- 72 files changed, 1655 insertions(+), 1655 deletions(-) rename src/spikeinterface/core/{result_core.py => analyzer_extension_core.py} (80%) rename src/spikeinterface/core/{sortingresult.py => sortinganalyzer.py} (86%) rename src/spikeinterface/core/tests/{test_result_core.py => test_analyzer_extension_core.py} (63%) rename src/spikeinterface/core/tests/{test_sortingresult.py => test_sortinganalyzer.py} (53%) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 3c390434bb..fef06f4fe4 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -8,7 +8,7 @@ import numpy as np -from spikeinterface.core import load_extractor, start_sorting_result, load_sorting_result +from spikeinterface.core import load_extractor, create_sorting_analyzer, load_sorting_analyzer from spikeinterface.core.core_tools import SIJsonEncoder from spikeinterface.core.job_tools import split_job_kwargs @@ -284,13 +284,13 @@ def get_run_times(self, case_keys=None): return pd.Series(run_times, name="run_time") - def start_sorting_result_gt(self, case_keys=None, **kwargs): + def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): if case_keys is None: case_keys = self.cases.keys() select_params, job_kwargs = split_job_kwargs(kwargs) - base_folder = self.folder / "sorting_result" + base_folder = self.folder / "sorting_analyzer" base_folder.mkdir(exist_ok=True) dataset_keys = [self.cases[key]["dataset"] for key in case_keys] @@ -299,17 +299,17 @@ def start_sorting_result_gt(self, case_keys=None, **kwargs): # the waveforms depend on the dataset key folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] - sorting_result = start_sorting_result(gt_sorting, recording, format="binray_folder", folder=folder) - sorting_result.select_random_spikes(**select_params) - sorting_result.compute("fast_templates", **job_kwargs) + sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binray_folder", folder=folder) + sorting_analyzer.select_random_spikes(**select_params) + sorting_analyzer.compute("fast_templates", **job_kwargs) def get_waveform_extractor(self, case_key=None, dataset_key=None): if case_key is not None: dataset_key = self.cases[case_key]["dataset"] - folder = self.folder / "sorting_result" / self.key_to_str(dataset_key) - sorting_result = load_sorting_result(folder) - return sorting_result + folder = self.folder / "sorting_analyzer" / self.key_to_str(dataset_key) + sorting_analyzer = load_sorting_analyzer(folder) + return sorting_analyzer def get_templates(self, key, mode="average"): we = self.get_waveform_extractor(case_key=key) diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index aaf2898987..75812bad17 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -16,7 +16,7 @@ generate_sorting_to_inject, ) -# TODO aurelien : this is still using the WaveformExtractor!!! can you change it to use SortingResult ? +# TODO aurelien : this is still using the WaveformExtractor!!! can you change it to use SortingAnalyzer ? class HybridUnitsRecording(InjectTemplatesRecording): diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index cba47a6b67..50c3ee4071 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -696,14 +696,14 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): Parameters ---------- - sorting_result_1 : SortingResult - The first SortingResult to get templates to compare - sorting_result_2 : SortingResult - The second SortingResult to get templates to compare + sorting_analyzer_1 : SortingAnalyzer + The first SortingAnalyzer to get templates to compare + sorting_analyzer_2 : SortingAnalyzer + The second SortingAnalyzer to get templates to compare unit_ids1 : list, default: None - List of units from sorting_result_1 to compare + List of units from sorting_analyzer_1 to compare unit_ids2 : list, default: None - List of units from sorting_result_2 to compare + List of units from sorting_analyzer_2 to compare similarity_method : str, default: "cosine_similarity" Method for the similaroty matrix sparsity_dict : dict, default: None @@ -719,8 +719,8 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): def __init__( self, - sorting_result_1, - sorting_result_2, + sorting_analyzer_1, + sorting_analyzer_2, name1=None, name2=None, unit_ids1=None, @@ -737,8 +737,8 @@ def __init__( name2 = "sess2" BasePairComparison.__init__( self, - object1=sorting_result_1, - object2=sorting_result_2, + object1=sorting_analyzer_1, + object2=sorting_analyzer_2, name1=name1, name2=name2, match_score=match_score, @@ -747,13 +747,13 @@ def __init__( ) MixinTemplateComparison.__init__(self, similarity_method=similarity_method, sparsity_dict=sparsity_dict) - self.sorting_result_1 = sorting_result_1 - self.sorting_result_2 = sorting_result_2 - channel_ids1 = sorting_result_1.recording.get_channel_ids() - channel_ids2 = sorting_result_2.recording.get_channel_ids() + self.sorting_analyzer_1 = sorting_analyzer_1 + self.sorting_analyzer_2 = sorting_analyzer_2 + channel_ids1 = sorting_analyzer_1.recording.get_channel_ids() + channel_ids2 = sorting_analyzer_2.recording.get_channel_ids() # two options: all channels are shared or partial channels are shared - if sorting_result_1.recording.get_num_channels() != sorting_result_2.recording.get_num_channels(): + if sorting_analyzer_1.recording.get_num_channels() != sorting_analyzer_2.recording.get_num_channels(): raise NotImplementedError if np.any([ch1 != ch2 for (ch1, ch2) in zip(channel_ids1, channel_ids2)]): # TODO: here we can check location and run it on the union. Might be useful for reconfigurable probes @@ -762,10 +762,10 @@ def __init__( self.matches = dict() if unit_ids1 is None: - unit_ids1 = sorting_result_1.sorting.get_unit_ids() + unit_ids1 = sorting_analyzer_1.sorting.get_unit_ids() if unit_ids2 is None: - unit_ids2 = sorting_result_2.sorting.get_unit_ids() + unit_ids2 = sorting_analyzer_2.sorting.get_unit_ids() self.unit_ids = [unit_ids1, unit_ids2] if sparsity_dict is not None: @@ -781,7 +781,7 @@ def _do_agreement(self): print("Agreement scores...") agreement_scores = compute_template_similarity_by_pair( - self.sorting_result_1, self.sorting_result_2, method=self.similarity_method + self.sorting_analyzer_1, self.sorting_analyzer_2, method=self.similarity_method ) import pandas as pd diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index ef79299795..b7df085fab 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -86,7 +86,7 @@ def test_GroundTruthStudy(): study.run_comparisons() print(study.comparisons) - study.start_sorting_result_gt(n_jobs=-1) + study.create_sorting_analyzer_gt(n_jobs=-1) study.compute_metrics() diff --git a/src/spikeinterface/comparison/tests/test_hybrid.py b/src/spikeinterface/comparison/tests/test_hybrid.py index d1da0005f9..8c392f7687 100644 --- a/src/spikeinterface/comparison/tests/test_hybrid.py +++ b/src/spikeinterface/comparison/tests/test_hybrid.py @@ -36,7 +36,7 @@ def setup_module(): def test_hybrid_units_recording(): wvf_extractor = load_waveforms(cache_folder / "wvf_extractor") print(wvf_extractor) - print(wvf_extractor.sorting_result) + print(wvf_extractor.sorting_analyzer) recording = wvf_extractor.recording templates = wvf_extractor.get_all_templates() diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index 90f35e4dbf..6e4c2d1714 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -3,7 +3,7 @@ from pathlib import Path import numpy as np -from spikeinterface.core import start_sorting_result +from spikeinterface.core import create_sorting_analyzer from spikeinterface.extractors import toy_example from spikeinterface.comparison import compare_templates, compare_multiple_templates @@ -41,16 +41,16 @@ def test_compare_multiple_templates(): sort3 = sort.frame_slice(start_frame=2 / 3 * duration * fs, end_frame=duration * fs) # compute waveforms - sorting_result_1 = start_sorting_result(sort1, rec1, format="memory") - sorting_result_2 = start_sorting_result(sort2, rec2, format="memory") - sorting_result_3 = start_sorting_result(sort3, rec3, format="memory") + sorting_analyzer_1 = create_sorting_analyzer(sort1, rec1, format="memory") + sorting_analyzer_2 = create_sorting_analyzer(sort2, rec2, format="memory") + sorting_analyzer_3 = create_sorting_analyzer(sort3, rec3, format="memory") - for sorting_result in (sorting_result_1, sorting_result_2, sorting_result_3): - sorting_result.select_random_spikes() - sorting_result.compute("fast_templates") + for sorting_analyzer in (sorting_analyzer_1, sorting_analyzer_2, sorting_analyzer_3): + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("fast_templates") # paired comparison - temp_cmp = compare_templates(sorting_result_1, sorting_result_2) + temp_cmp = compare_templates(sorting_analyzer_1, sorting_analyzer_2) for u1 in temp_cmp.hungarian_match_12.index.values: u2 = temp_cmp.hungarian_match_12[u1] @@ -58,7 +58,7 @@ def test_compare_multiple_templates(): assert u1 == u2 # multi-comparison - temp_mcmp = compare_multiple_templates([sorting_result_1, sorting_result_2, sorting_result_3]) + temp_mcmp = compare_multiple_templates([sorting_analyzer_1, sorting_analyzer_2, sorting_analyzer_3]) # assert unit ids are the same across sessions (because of initial slicing) for unit_dict in temp_mcmp.units.values(): unit_ids = unit_dict["unit_ids"].values() diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 4f386c645c..bdc29cd17a 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -143,9 +143,9 @@ from .template import Templates -# SortingResult and ResultExtension -from .sortingresult import SortingResult, ResultExtension, start_sorting_result, load_sorting_result -from .result_core import ( +# SortingAnalyzer and ResultExtension +from .sortinganalyzer import SortingAnalyzer, ResultExtension, create_sorting_analyzer, load_sorting_analyzer +from .analyzer_extension_core import ( ComputeWaveforms, compute_waveforms, ComputeTemplates, diff --git a/src/spikeinterface/core/result_core.py b/src/spikeinterface/core/analyzer_extension_core.py similarity index 80% rename from src/spikeinterface/core/result_core.py rename to src/spikeinterface/core/analyzer_extension_core.py index c6be619319..f6d7399c4e 100644 --- a/src/spikeinterface/core/result_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -11,7 +11,7 @@ import numpy as np -from .sortingresult import ResultExtension, register_result_extension +from .sortinganalyzer import ResultExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_average from .recording_tools import get_noise_levels from .template import Templates @@ -21,7 +21,7 @@ class ComputeWaveforms(ResultExtension): """ ResultExtension that extract some waveforms of each units. - The sparsity is controlled by the SortingResult sparsity. + The sparsity is controlled by the SortingAnalyzer sparsity. """ extension_name = "waveforms" @@ -32,25 +32,25 @@ class ComputeWaveforms(ResultExtension): @property def nbefore(self): - return int(self.params["ms_before"] * self.sorting_result.sampling_frequency / 1000.0) + return int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) @property def nafter(self): - return int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) + return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) def _run(self, **job_kwargs): self.data.clear() - if self.sorting_result.random_spikes_indices is None: - raise ValueError("compute_waveforms need SortingResult.select_random_spikes() need to be run first") + if self.sorting_analyzer.random_spikes_indices is None: + raise ValueError("compute_waveforms need SortingAnalyzer.select_random_spikes() need to be run first") - recording = self.sorting_result.recording - sorting = self.sorting_result.sorting + recording = self.sorting_analyzer.recording + sorting = self.sorting_analyzer.sorting unit_ids = sorting.unit_ids # retrieve spike vector and the sampling spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] if self.format == "binary_folder": # in that case waveforms are extacted directly in files @@ -92,7 +92,7 @@ def _set_params( return_scaled: bool = True, dtype=None, ): - recording = self.sorting_result.recording + recording = self.sorting_analyzer.recording if dtype is None: dtype = recording.get_dtype() @@ -116,9 +116,9 @@ def _set_params( return params def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) - spikes = self.sorting_result.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) + spikes = self.sorting_analyzer.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices) new_data = dict() @@ -131,15 +131,15 @@ def get_waveforms_one_unit( unit_id, force_dense: bool = False, ): - sorting = self.sorting_result.sorting + sorting = self.sorting_analyzer.sorting unit_index = sorting.id_to_index(unit_id) spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] spike_mask = some_spikes["unit_index"] == unit_index wfs = self.data["waveforms"][spike_mask, :, :] - if self.sorting_result.sparsity is not None: - chan_inds = self.sorting_result.sparsity.unit_id_to_channel_indices[unit_id] + if self.sorting_analyzer.sparsity is not None: + chan_inds = self.sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id] wfs = wfs[:, :, : chan_inds.size] if force_dense: num_channels = self.get_num_channels() @@ -161,12 +161,12 @@ class ComputeTemplates(ResultExtension): """ ResultExtension that compute templates (average, str, median, percentile, ...) - This must be run after "waveforms" extension (`SortingResult.compute("waveforms")`) + This must be run after "waveforms" extension (`SortingAnalyzer.compute("waveforms")`) Note that when "waveforms" is already done, then the recording is not needed anymore for this extension. Note: by default only the average is computed. Other operator (std, median, percentile) can be computed on demand - after the SortingResult.compute("templates") and then the data dict is updated on demand. + after the SortingAnalyzer.compute("templates") and then the data dict is updated on demand. """ @@ -187,7 +187,7 @@ def _set_params(self, operators=["average", "std"]): assert len(operator) == 2 assert operator[0] == "percentile" - waveforms_extension = self.sorting_result.get_extension("waveforms") + waveforms_extension = self.sorting_analyzer.get_extension("waveforms") params = dict( operators=operators, @@ -201,9 +201,9 @@ def _run(self): self._compute_and_append(self.params["operators"]) def _compute_and_append(self, operators): - unit_ids = self.sorting_result.unit_ids - channel_ids = self.sorting_result.channel_ids - waveforms_extension = self.sorting_result.get_extension("waveforms") + unit_ids = self.sorting_analyzer.unit_ids + channel_ids = self.sorting_analyzer.channel_ids + waveforms_extension = self.sorting_analyzer.get_extension("waveforms") waveforms = waveforms_extension.data["waveforms"] num_samples = waveforms.shape[1] @@ -219,8 +219,8 @@ def _compute_and_append(self, operators): raise ValueError(f"ComputeTemplates: wrong operator {operator}") self.data[key] = np.zeros((unit_ids.size, num_samples, channel_ids.size)) - spikes = self.sorting_result.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + spikes = self.sorting_analyzer.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index wfs = waveforms[spike_mask, :, :] @@ -257,7 +257,7 @@ def nafter(self): return self.params["nafter"] def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) new_data = dict() for key, arr in self.data.items(): @@ -279,11 +279,11 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): elif outputs == "Templates": return Templates( templates_array=templates_array, - sampling_frequency=self.sorting_result.sampling_frequency, + sampling_frequency=self.sorting_analyzer.sampling_frequency, nbefore=self.nbefore, - channel_ids=self.sorting_result.channel_ids, - unit_ids=self.sorting_result.unit_ids, - probe=self.sorting_result.get_probe(), + channel_ids=self.sorting_analyzer.channel_ids, + unit_ids=self.sorting_analyzer.unit_ids, + probe=self.sorting_analyzer.get_probe(), ) else: raise ValueError("outputs must be numpy or Templates") @@ -331,7 +331,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save self.save() if unit_ids is not None: - unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) templates = templates[unit_indices, :, :] return np.array(templates) @@ -355,25 +355,25 @@ class ComputeFastTemplates(ResultExtension): @property def nbefore(self): - return int(self.params["ms_before"] * self.sorting_result.sampling_frequency / 1000.0) + return int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) @property def nafter(self): - return int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) + return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) def _run(self, **job_kwargs): self.data.clear() - if self.sorting_result.random_spikes_indices is None: - raise ValueError("compute_waveforms need SortingResult.select_random_spikes() need to be run first") + if self.sorting_analyzer.random_spikes_indices is None: + raise ValueError("compute_waveforms need SortingAnalyzer.select_random_spikes() need to be run first") - recording = self.sorting_result.recording - sorting = self.sorting_result.sorting + recording = self.sorting_analyzer.recording + sorting = self.sorting_analyzer.sorting unit_ids = sorting.unit_ids # retrieve spike vector and the sampling spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] return_scaled = self.params["return_scaled"] @@ -403,17 +403,17 @@ def _get_data(self, outputs="numpy"): elif outputs == "Templates": return Templates( templates_array=templates_array, - sampling_frequency=self.sorting_result.sampling_frequency, + sampling_frequency=self.sorting_analyzer.sampling_frequency, nbefore=self.nbefore, - channel_ids=self.sorting_result.channel_ids, - unit_ids=self.sorting_result.unit_ids, - probe=self.sorting_result.get_probe(), + channel_ids=self.sorting_analyzer.channel_ids, + unit_ids=self.sorting_analyzer.unit_ids, + probe=self.sorting_analyzer.get_probe(), ) else: raise ValueError("outputs must be numpy or Templates") def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) new_data = dict() new_data["average"] = self.data["average"][keep_unit_indices, :, :] @@ -439,8 +439,8 @@ class ComputeNoiseLevels(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object **params: dict with additional parameters Returns @@ -455,8 +455,8 @@ class ComputeNoiseLevels(ResultExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, return_scaled=True, seed=None): params = dict( @@ -469,7 +469,7 @@ def _select_extension_data(self, unit_ids): return self.data def _run(self): - self.data["noise_levels"] = get_noise_levels(self.sorting_result.recording, **self.params) + self.data["noise_levels"] = get_noise_levels(self.sorting_analyzer.recording, **self.params) def _get_data(self): return self.data["noise_levels"] diff --git a/src/spikeinterface/core/sortingresult.py b/src/spikeinterface/core/sortinganalyzer.py similarity index 86% rename from src/spikeinterface/core/sortingresult.py rename to src/spikeinterface/core/sortinganalyzer.py index 4de643a0db..928dffa780 100644 --- a/src/spikeinterface/core/sortingresult.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -30,15 +30,15 @@ from .node_pipeline import run_node_pipeline -# TODO make some_spikes a method of SortingResult +# TODO make some_spikes a method of SortingAnalyzer # high level function -def start_sorting_result( +def create_sorting_analyzer( sorting, recording, format="memory", folder=None, sparse=True, sparsity=None, **sparsity_kwargs ): """ - Create a SortingResult by pairing a Sorting and the corresponding Recording. + Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. This object will handle a list of ResultExtension for all the post processing steps like: waveforms, templates, unit locations, spike locations, quality mertics ... @@ -68,29 +68,29 @@ def start_sorting_result( Returns ------- - sorting_result: SortingResult - The SortingResult object + sorting_analyzer: SortingAnalyzer + The SortingAnalyzer object Examples -------- >>> import spikeinterface as si >>> # Extract dense waveforms and save to disk with binary_folder format. - >>> sorting_result = si.start_sorting_result(sorting, recording, format="binary_folder", folder="/path/to_my/result") + >>> sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="binary_folder", folder="/path/to_my/result") >>> # Can be reload - >>> sorting_result = si.load_sorting_result(folder="/path/to_my/result") + >>> sorting_analyzer = si.load_sorting_analyzer(folder="/path/to_my/result") >>> # Can run extension - >>> sorting_result = si.compute("unit_locations", ...) + >>> sorting_analyzer = si.compute("unit_locations", ...) >>> # Can be copy to another format (extensions are propagated) - >>> sorting_result2 = sorting_result.save_as(format="memory") - >>> sorting_result3 = sorting_result.save_as(format="zarr", folder="/path/to_my/result.zarr") + >>> sorting_analyzer2 = sorting_analyzer.save_as(format="memory") + >>> sorting_analyzer3 = sorting_analyzer.save_as(format="zarr", folder="/path/to_my/result.zarr") >>> # Can make a copy with a subset of units (extensions are propagated for the unit subset) - >>> sorting_result4 = sorting_result.select_units(unit_ids=sorting.units_ids[:5], format="memory") - >>> sorting_result5 = sorting_result.select_units(unit_ids=sorting.units_ids[:5], format="binary_folder", folder="/result_5units") + >>> sorting_analyzer4 = sorting_analyzer.select_units(unit_ids=sorting.units_ids[:5], format="memory") + >>> sorting_analyzer5 = sorting_analyzer.select_units(unit_ids=sorting.units_ids[:5], format="binary_folder", folder="/result_5units") """ # handle sparsity @@ -99,23 +99,23 @@ def start_sorting_result( assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object" assert np.array_equal( sorting.unit_ids, sparsity.unit_ids - ), "start_sorting_result(): if external sparsity is given unit_ids must correspond" + ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" assert np.array_equal( recording.channel_ids, recording.channel_ids - ), "start_sorting_result(): if external sparsity is given unit_ids must correspond" + ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" elif sparse: sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs) else: sparsity = None - sorting_result = SortingResult.create(sorting, recording, format=format, folder=folder, sparsity=sparsity) + sorting_analyzer = SortingAnalyzer.create(sorting, recording, format=format, folder=folder, sparsity=sparsity) - return sorting_result + return sorting_analyzer -def load_sorting_result(folder, load_extensions=True, format="auto"): +def load_sorting_analyzer(folder, load_extensions=True, format="auto"): """ - Load a SortingResult object from disk. + Load a SortingAnalyzer object from disk. Parameters ---------- @@ -128,15 +128,15 @@ def load_sorting_result(folder, load_extensions=True, format="auto"): Returns ------- - sorting_result: SortingResult - The loaded SortingResult + sorting_analyzer: SortingAnalyzer + The loaded SortingAnalyzer """ - return SortingResult.load(folder, load_extensions=load_extensions, format=format) + return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format) -class SortingResult: +class SortingAnalyzer: """ Class to make a pair of Recording-Sorting which will be used used for all post postprocessing, visualization and quality metric computation. @@ -144,18 +144,18 @@ class SortingResult: This internaly maintain a list of computed ResultExtention (waveform, pca, unit position, spike poisition, ...). This can live in memory and/or can be be persistent to disk in 2 internal formats (folder/json/npz or zarr). - A SortingResult can be transfer to another format using `save_as()` + A SortingAnalyzer can be transfer to another format using `save_as()` This handle unit sparsity that can be propagated to ResultExtention. This handle spike sampling that can be propagated to ResultExtention : work only on a subset of spikes. This internally save a copy of the Sorting and extract main recording attributes (without traces) so - the SortingResult object can be reload even if references to the original sorting and/or to the original recording + the SortingAnalyzer object can be reload even if references to the original sorting and/or to the original recording are lost. - SortingResult() should not never be used directly for creating: use instead start_sorting_result(sorting, resording, ...) - or eventually SortingResult.create(...) + SortingAnalyzer() should not never be used directly for creating: use instead create_sorting_analyzer(sorting, resording, ...) + or eventually SortingAnalyzer.create(...) """ def __init__( @@ -209,19 +209,19 @@ def create( check_probe_do_not_overlap(all_probes) if format == "memory": - sorting_result = cls.create_memory(sorting, recording, sparsity, rec_attributes=None) + sorting_analyzer = cls.create_memory(sorting, recording, sparsity, rec_attributes=None) elif format == "binary_folder": cls.create_binary_folder(folder, sorting, recording, sparsity, rec_attributes=None) - sorting_result = cls.load_from_binary_folder(folder, recording=recording) - sorting_result.folder = folder + sorting_analyzer = cls.load_from_binary_folder(folder, recording=recording) + sorting_analyzer.folder = folder elif format == "zarr": cls.create_zarr(folder, sorting, recording, sparsity, rec_attributes=None) - sorting_result = cls.load_from_zarr(folder, recording=recording) - sorting_result.folder = folder + sorting_analyzer = cls.load_from_zarr(folder, recording=recording) + sorting_analyzer.folder = folder else: - raise ValueError("SortingResult.create: wrong format") + raise ValueError("SortingAnalyzer.create: wrong format") - return sorting_result + return sorting_analyzer @classmethod def load(cls, folder, recording=None, load_extensions=True, format="auto"): @@ -240,16 +240,16 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): format = "binary_folder" if format == "binary_folder": - sorting_result = SortingResult.load_from_binary_folder(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) elif format == "zarr": - sorting_result = SortingResult.load_from_zarr(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) - sorting_result.folder = folder + sorting_analyzer.folder = folder if load_extensions: - sorting_result.load_all_saved_extension() + sorting_analyzer.load_all_saved_extension() - return sorting_result + return sorting_analyzer @classmethod def create_memory(cls, sorting, recording, sparsity, rec_attributes): @@ -265,16 +265,16 @@ def create_memory(cls, sorting, recording, sparsity, rec_attributes): # a copy of sorting is created directly in shared memory format to avoid further duplication of spikes. sorting_copy = SharedMemorySorting.from_sorting(sorting, with_metadata=True) - sorting_result = SortingResult( + sorting_analyzer = SortingAnalyzer( sorting=sorting_copy, recording=recording, rec_attributes=rec_attributes, format="memory", sparsity=sparsity ) - return sorting_result + return sorting_analyzer @classmethod def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attributes): # used by create and save_as - assert recording is not None, "To create a SortingResult you need recording not None" + assert recording is not None, "To create a SortingAnalyzer you need recording not None" folder = Path(folder) if folder.is_dir(): @@ -285,7 +285,7 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attribut info = dict( version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, - object="SortingResult", + object="SortingAnalyzer", ) with open(info_file, mode="w") as f: json.dump(check_json(info), f, indent=4) @@ -354,7 +354,7 @@ def load_from_binary_folder(cls, folder, recording=None): # recording attributes rec_attributes_file = folder / "recording_info" / "recording_attributes.json" if not rec_attributes_file.exists(): - raise ValueError("This folder is not a SortingResult with format='binary_folder'") + raise ValueError("This folder is not a SortingAnalyzer with format='binary_folder'") with open(rec_attributes_file, "r") as f: rec_attributes = json.load(f) # the probe is handle ouside the main json @@ -382,7 +382,7 @@ def load_from_binary_folder(cls, folder, recording=None): else: random_spikes_indices = None - sorting_result = SortingResult( + sorting_analyzer = SortingAnalyzer( sorting=sorting, recording=recording, rec_attributes=rec_attributes, @@ -391,7 +391,7 @@ def load_from_binary_folder(cls, folder, recording=None): random_spikes_indices=random_spikes_indices, ) - return sorting_result + return sorting_analyzer def _get_zarr_root(self, mode="r+"): import zarr @@ -415,7 +415,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): zarr_root = zarr.open(folder, mode="w") - info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingResult") + info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) # the recording @@ -429,7 +429,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) else: warnings.warn( - "SortingResult with zarr : the Recording is not json serializable, the recording link will be lost for futur load" + "SortingAnalyzer with zarr : the Recording is not json serializable, the recording link will be lost for futur load" ) # sorting provenance @@ -444,7 +444,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) # else: - # warnings.warn("SortingResult with zarr : the sorting provenance is not json serializable, the sorting provenance link will be lost for futur load") + # warnings.warn("SortingAnalyzer with zarr : the sorting provenance is not json serializable, the sorting provenance link will be lost for futur load") recording_info = zarr_root.create_group("recording_info") @@ -528,7 +528,7 @@ def load_from_zarr(cls, folder, recording=None): else: random_spikes_indices = None - sorting_result = SortingResult( + sorting_analyzer = SortingAnalyzer( sorting=sorting, recording=recording, rec_attributes=rec_attributes, @@ -537,9 +537,9 @@ def load_from_zarr(cls, folder, recording=None): random_spikes_indices=random_spikes_indices, ) - return sorting_result + return sorting_analyzer - def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingResult": + def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": """ Internal used by both save_as(), copy() and select_units() which are more or less the same. """ @@ -569,28 +569,28 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> sorting_provenance = sorting_provenance.select_units(unit_ids) if format == "memory": - # This make a copy of actual SortingResult - new_sorting_result = SortingResult.create_memory(sorting_provenance, recording, sparsity, self.rec_attributes) + # This make a copy of actual SortingAnalyzer + new_sorting_analyzer = SortingAnalyzer.create_memory(sorting_provenance, recording, sparsity, self.rec_attributes) elif format == "binary_folder": # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" - SortingResult.create_binary_folder(folder, sorting_provenance, recording, sparsity, self.rec_attributes) - new_sorting_result = SortingResult.load_from_binary_folder(folder, recording=recording) - new_sorting_result.folder = folder + SortingAnalyzer.create_binary_folder(folder, sorting_provenance, recording, sparsity, self.rec_attributes) + new_sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) + new_sorting_analyzer.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - SortingResult.create_zarr(folder, sorting_provenance, recording, sparsity, self.rec_attributes) - new_sorting_result = SortingResult.load_from_zarr(folder, recording=recording) - new_sorting_result.folder = folder + SortingAnalyzer.create_zarr(folder, sorting_provenance, recording, sparsity, self.rec_attributes) + new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) + new_sorting_analyzer.folder = folder else: - raise ValueError("SortingResult.save: wrong format") + raise ValueError("SortingAnalyzer.save: wrong format") # propagate random_spikes_indices is already done if self.random_spikes_indices is not None: if unit_ids is None: - new_sorting_result.random_spikes_indices = self.random_spikes_indices.copy() + new_sorting_analyzer.random_spikes_indices = self.random_spikes_indices.copy() else: # more tricky spikes = self.sorting.to_spike_vector() @@ -601,21 +601,21 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> selected_mask = np.zeros(spikes.size, dtype=bool) selected_mask[self.random_spikes_indices] = True - new_sorting_result.random_spikes_indices = np.flatnonzero(selected_mask[keep_spike_mask]) + new_sorting_analyzer.random_spikes_indices = np.flatnonzero(selected_mask[keep_spike_mask]) # save it - new_sorting_result._save_random_spikes_indices() + new_sorting_analyzer._save_random_spikes_indices() # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing for extension_name, extension in self.extensions.items(): - new_ext = new_sorting_result.extensions[extension_name] = extension.copy(new_sorting_result, unit_ids=unit_ids) + new_ext = new_sorting_analyzer.extensions[extension_name] = extension.copy(new_sorting_analyzer, unit_ids=unit_ids) - return new_sorting_result + return new_sorting_analyzer - def save_as(self, format="memory", folder=None) -> "SortingResult": + def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": """ - Save SortingResult object into another format. + Save SortingAnalyzer object into another format. Uselfull for memory to zarr or memory to binray. Note that the recording provenance or sorting provenance can be lost. @@ -631,7 +631,7 @@ def save_as(self, format="memory", folder=None) -> "SortingResult": """ return self._save_or_select(format=format, folder=folder, unit_ids=None) - def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult": + def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ This method is equivalent to `save_as()`but with a subset of units. Filters units by creating a new waveform extractor object in a new folder. @@ -641,14 +641,14 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult Parameters ---------- unit_ids : list or array - The unit ids to keep in the new SortingResult object + The unit ids to keep in the new SortingAnalyzer object folder : Path or None The new folder where selected waveforms are copied format: a Returns ------- - we : SortingResult + we : SortingAnalyzer The newly create waveform extractor with the selected units """ # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! @@ -656,7 +656,7 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingResult def copy(self): """ - Create a a copy of SortingResult with format "memory". + Create a a copy of SortingAnalyzer with format "memory". """ return self._save_or_select(format="memory", folder=None, unit_ids=None) @@ -670,7 +670,7 @@ def is_read_only(self) -> bool: @property def recording(self) -> BaseRecording: if not self.has_recording(): - raise ValueError("SortingResult could not load the recording") + raise ValueError("SortingAnalyzer could not load the recording") return self._recording @property @@ -789,7 +789,7 @@ def compute(self, input, save=True, **kwargs): return self.compute_one_extension(extension_name=input, save=save, **kwargs) elif isinstance(input, dict): params_, job_kwargs = split_job_kwargs(kwargs) - assert len(params_) == 0, "Too many arguments for SortingResult.compute_several_extensions()" + assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()" self.compute_several_extensions(extensions=input, save=save, **job_kwargs) def compute_one_extension(self, extension_name, save=True, **kwargs): @@ -818,11 +818,11 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): -------- >>> Note that the return is the instance extension. - >>> extension = sorting_result.compute("waveforms", **some_params) - >>> extension = sorting_result.compute_one_extension("waveforms", **some_params) + >>> extension = sorting_analyzer.compute("waveforms", **some_params) + >>> extension = sorting_analyzer.compute_one_extension("waveforms", **some_params) >>> wfs = extension.data["waveforms"] >>> # Note this can be be done in the old way style BUT the return is not the same it return directly data - >>> wfs = compute_waveforms(sorting_result, **some_params) + >>> wfs = compute_waveforms(sorting_analyzer, **some_params) """ @@ -873,8 +873,8 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): Examples -------- - >>> sorting_result.compute({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std", ]} }) - >>> sorting_result.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}}) + >>> sorting_analyzer.compute({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std", ]} }) + >>> sorting_analyzer.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}}) """ # TODO this is a simple implementation @@ -997,7 +997,7 @@ def load_extension(self, extension_name: str): """ assert ( self.format != "memory" - ), "SortingResult.load_extension() do not work for format='memory' use SortingResult.get_extension()instead" + ), "SortingAnalyzer.load_extension() do not work for format='memory' use SortingAnalyzer.get_extension()instead" extension_class = get_extension_class(extension_name) @@ -1136,7 +1136,7 @@ def get_extension_class(extension_name: str): class ResultExtension: """ - This the base class to extend the SortingResult. + This the base class to extend the SortingAnalyzer. It can handle persistency to disk any computations related For instance: @@ -1146,7 +1146,7 @@ class ResultExtension: * quality metrics Possible extension can be register on the fly at import time with register_result_extension() mechanism. - It also enables any custum computation on top on SortingResult to be implemented by the user. + It also enables any custum computation on top on SortingAnalyzer to be implemented by the user. An extension needs to inherit from this class and implement some attributes and abstract methods: * extension_name @@ -1165,7 +1165,7 @@ class ResultExtension: The subclass must also hanle an attribute `data` which is a dict contain the results after the `run()`. All ResultExtension will have a function associate for instance (this use the function_factory): - comptute_unit_location(sorting_result, ...) will be equivalent to sorting_result.compute("unit_location", ...) + comptute_unit_location(sorting_analyzer, ...) will be equivalent to sorting_analyzer.compute("unit_location", ...) """ @@ -1177,15 +1177,15 @@ class ResultExtension: nodepipeline_variables = None need_job_kwargs = False - def __init__(self, sorting_result): - self._sorting_result = weakref.ref(sorting_result) + def __init__(self, sorting_analyzer): + self._sorting_analyzer = weakref.ref(sorting_analyzer) self.params = None self.data = dict() ####### # This 3 methods must be implemented in the subclass!!! - # See DummyResultExtension in test_sortingresult.py as a simple example + # See DummyResultExtension in test_sortinganalyzer.py as a simple example def _run(self, **kwargs): # must be implemented in subclass # must populate the self.data dictionary @@ -1214,7 +1214,7 @@ def _get_data(self): @classmethod def function_factory(cls): # make equivalent - # comptute_unit_location(sorting_result, ...) <> sorting_result.compute("unit_location", ...) + # comptute_unit_location(sorting_analyzer, ...) <> sorting_analyzer.compute("unit_location", ...) # this also make backcompatibility # comptute_unit_location(we, ...) @@ -1222,15 +1222,15 @@ class FuncWrapper: def __init__(self, extension_name): self.extension_name = extension_name - def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): + def __call__(self, sorting_analyzer, load_if_exists=None, *args, **kwargs): from .waveforms_extractor_backwards_compatibility import MockWaveformExtractor - if isinstance(sorting_result, MockWaveformExtractor): + if isinstance(sorting_analyzer, MockWaveformExtractor): # backward compatibility with WaveformsExtractor - sorting_result = sorting_result.sorting_result + sorting_analyzer = sorting_analyzer.sorting_analyzer - if not isinstance(sorting_result, SortingResult): - raise ValueError(f"compute_{self.extension_name}() need a SortingResult instance") + if not isinstance(sorting_analyzer, SortingAnalyzer): + raise ValueError(f"compute_{self.extension_name}() need a SortingAnalyzer instance") if load_if_exists is not None: # backward compatibility with "load_if_exists" @@ -1239,10 +1239,10 @@ def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): ) assert isinstance(load_if_exists, bool) if load_if_exists: - ext = sorting_result.get_extension(self.extension_name) + ext = sorting_analyzer.get_extension(self.extension_name) return ext - ext = sorting_result.compute(cls.extension_name, *args, **kwargs) + ext = sorting_analyzer.compute(cls.extension_name, *args, **kwargs) return ext.get_data() func = FuncWrapper(cls.extension_name) @@ -1250,42 +1250,42 @@ def __call__(self, sorting_result, load_if_exists=None, *args, **kwargs): return func @property - def sorting_result(self): - # Important : to avoid the SortingResult referencing a ResultExtension - # and ResultExtension referencing a SortingResult we need a weakref. + def sorting_analyzer(self): + # Important : to avoid the SortingAnalyzer referencing a ResultExtension + # and ResultExtension referencing a SortingAnalyzer we need a weakref. # Otherwise the garbage collector is not working properly. - # and so the SortingResult + its recording are still alive even after deleting explicitly - # the SortingResult which makes it impossible to delete the folder when using memmap. - sorting_result = self._sorting_result() - if sorting_result is None: - raise ValueError(f"The extension {self.extension_name} has lost its SortingResult") - return sorting_result - - # some attribuites come from sorting_result + # and so the SortingAnalyzer + its recording are still alive even after deleting explicitly + # the SortingAnalyzer which makes it impossible to delete the folder when using memmap. + sorting_analyzer = self._sorting_analyzer() + if sorting_analyzer is None: + raise ValueError(f"The extension {self.extension_name} has lost its SortingAnalyzer") + return sorting_analyzer + + # some attribuites come from sorting_analyzer @property def format(self): - return self.sorting_result.format + return self.sorting_analyzer.format @property def sparsity(self): - return self.sorting_result.sparsity + return self.sorting_analyzer.sparsity @property def folder(self): - return self.sorting_result.folder + return self.sorting_analyzer.folder def _get_binary_extension_folder(self): extension_folder = self.folder / "extensions" / self.extension_name return extension_folder def _get_zarr_extension_group(self, mode="r+"): - zarr_root = self.sorting_result._get_zarr_root(mode=mode) + zarr_root = self.sorting_analyzer._get_zarr_root(mode=mode) extension_group = zarr_root["extensions"][self.extension_name] return extension_group @classmethod - def load(cls, sorting_result): - ext = cls(sorting_result) + def load(cls, sorting_analyzer): + ext = cls(sorting_analyzer) ext.load_params() ext.load_data() return ext @@ -1333,7 +1333,7 @@ def load_data(self): elif self.format == "zarr": # Alessio # TODO: we need decide if we make a copy to memory or keep the lazy loading. For binary_folder it used to be lazy with memmap - # but this make the garbage complicated when a data is hold by a plot but the o SortingResult is delete + # but this make the garbage complicated when a data is hold by a plot but the o SortingAnalyzer is delete # lets talk extension_group = self._get_zarr_extension_group(mode="r") for ext_data_name in extension_group.keys(): @@ -1353,9 +1353,9 @@ def load_data(self): ext_data = ext_data_ self.data[ext_data_name] = ext_data - def copy(self, new_sorting_result, unit_ids=None): + def copy(self, new_sorting_analyzer, unit_ids=None): # alessio : please note that this also replace the old select_units!!! - new_extension = self.__class__(new_sorting_result) + new_extension = self.__class__(new_sorting_analyzer) new_extension.params = self.params.copy() if unit_ids is None: new_extension.data = self.data @@ -1365,13 +1365,13 @@ def copy(self, new_sorting_result, unit_ids=None): return new_extension def run(self, save=True, **kwargs): - if save and not self.sorting_result.is_read_only(): + if save and not self.sorting_analyzer.is_read_only(): # this also reset the folder or zarr group self._save_params() self._run(**kwargs) - if save and not self.sorting_result.is_read_only(): + if save and not self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) def save(self, **kwargs): @@ -1382,8 +1382,8 @@ def _save_data(self, **kwargs): if self.format == "memory": return - if self.sorting_result.is_read_only(): - raise ValueError(f"The SortingResult is read only save extension {self.extension_name} is not possible") + if self.sorting_analyzer.is_read_only(): + raise ValueError(f"The SortingAnalyzer is read only save extension {self.extension_name} is not possible") if self.format == "binary_folder": import pandas as pd @@ -1396,7 +1396,7 @@ def _save_data(self, **kwargs): elif isinstance(ext_data, np.ndarray): data_file = extension_folder / f"{ext_data_name}.npy" if isinstance(ext_data, np.memmap) and data_file.exists(): - # important some SortingResult like ComputeWaveforms already run the computation with memmap + # important some SortingAnalyzer like ComputeWaveforms already run the computation with memmap # so no need to save theses array pass else: @@ -1483,7 +1483,7 @@ def set_params(self, save=True, **params): params = self._set_params(**params) self.params = params - if self.sorting_result.is_read_only(): + if self.sorting_analyzer.is_read_only(): return if save: diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 06e900a3c5..48c4cdf1e0 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -71,26 +71,26 @@ class ChannelSparsity: Examples -------- - The class can also be used to construct/estimate the sparsity from a SortingResult or a Templates + The class can also be used to construct/estimate the sparsity from a SortingAnalyzer or a Templates with several methods: Using the N best channels (largest template amplitude): - >>> sparsity = ChannelSparsity.from_best_channels(sorting_result, num_channels, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels, peak_sign="neg") Using a neighborhood by radius: - >>> sparsity = ChannelSparsity.from_radius(sorting_result, radius_um, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um, peak_sign="neg") Using a SNR threshold: - >>> sparsity = ChannelSparsity.from_snr(sorting_result, threshold, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold, peak_sign="neg") Using a template energy threshold: - >>> sparsity = ChannelSparsity.from_energy(sorting_result, threshold) + >>> sparsity = ChannelSparsity.from_energy(sorting_analyzer, threshold) Using a recording/sorting property (e.g. "group"): - >>> sparsity = ChannelSparsity.from_property(sorting_result, by_property="group") + >>> sparsity = ChannelSparsity.from_property(sorting_analyzer, by_property="group") """ @@ -269,7 +269,7 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, templates_or_sorting_result, num_channels, peak_sign="neg"): + def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. @@ -277,17 +277,17 @@ def from_best_channels(cls, templates_or_sorting_result, num_channels, peak_sign from .template_tools import get_template_amplitudes mask = np.zeros( - (templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool" + (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - peak_values = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign) - for unit_ind, unit_id in enumerate(templates_or_sorting_result.unit_ids): + peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign) + for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_result.unit_ids, templates_or_sorting_result.channel_ids) + return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_radius(cls, templates_or_sorting_result, radius_um, peak_sign="neg"): + def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. Use the "radius_um" argument to specify the radius in um @@ -295,48 +295,48 @@ def from_radius(cls, templates_or_sorting_result, radius_um, peak_sign="neg"): from .template_tools import get_template_extremum_channel mask = np.zeros( - (templates_or_sorting_result.unit_ids.size, templates_or_sorting_result.channel_ids.size), dtype="bool" + (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - channel_locations = templates_or_sorting_result.get_channel_locations() + channel_locations = templates_or_sorting_analyzer.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_sorting_result, peak_sign=peak_sign, outputs="index") - for unit_ind, unit_id in enumerate(templates_or_sorting_result.unit_ids): + best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") + for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_ind = best_chan[unit_id] (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_result.unit_ids, templates_or_sorting_result.channel_ids) + return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_snr(cls, templates_or_sorting_result, threshold, noise_levels=None, peak_sign="neg"): + def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, peak_sign="neg"): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. """ from .template_tools import get_template_amplitudes - from .sortingresult import SortingResult + from .sortinganalyzer import SortingAnalyzer from .template import Templates assert ( - templates_or_sorting_result.sparsity is None - ), "To compute sparsity you need a dense SortingResult or Templates" + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" - unit_ids = templates_or_sorting_result.unit_ids - channel_ids = templates_or_sorting_result.channel_ids + unit_ids = templates_or_sorting_analyzer.unit_ids + channel_ids = templates_or_sorting_analyzer.channel_ids - if isinstance(templates_or_sorting_result, SortingResult): - ext = templates_or_sorting_result.get_extension("noise_levels") + if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + ext = templates_or_sorting_analyzer.get_extension("noise_levels") assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" assert ext.params[ "return_scaled" ], "To compute sparsity from snr you need return_scaled=True for extensions" noise_levels = ext.data["noise_levels"] - elif isinstance(templates_or_sorting_result, Templates): + elif isinstance(templates_or_sorting_analyzer, Templates): assert noise_levels is not None mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_result, peak_sign=peak_sign, mode="extremum", return_scaled=True + templates_or_sorting_analyzer, peak_sign=peak_sign, mode="extremum", return_scaled=True ) for unit_ind, unit_id in enumerate(unit_ids): @@ -345,38 +345,38 @@ def from_snr(cls, templates_or_sorting_result, threshold, noise_levels=None, pea return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_result, threshold, noise_levels=None): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the SNR threshold. """ assert ( - templates_or_sorting_result.sparsity is None - ), "To compute sparsity you need a dense SortingResult or Templates" + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" from .template_tools import get_template_amplitudes - from .sortingresult import SortingResult + from .sortinganalyzer import SortingAnalyzer from .template import Templates - unit_ids = templates_or_sorting_result.unit_ids - channel_ids = templates_or_sorting_result.channel_ids + unit_ids = templates_or_sorting_analyzer.unit_ids + channel_ids = templates_or_sorting_analyzer.channel_ids - if isinstance(templates_or_sorting_result, SortingResult): - ext = templates_or_sorting_result.get_extension("noise_levels") + if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + ext = templates_or_sorting_analyzer.get_extension("noise_levels") assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" assert ext.params[ "return_scaled" ], "To compute sparsity from snr you need return_scaled=True for extensions" noise_levels = ext.data["noise_levels"] - elif isinstance(templates_or_sorting_result, Templates): + elif isinstance(templates_or_sorting_analyzer, Templates): assert noise_levels is not None from .template_tools import _get_dense_templates_array mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - templates_array = _get_dense_templates_array(templates_or_sorting_result, return_scaled=True) + templates_array = _get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=True) templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): @@ -385,71 +385,71 @@ def from_ptp(cls, templates_or_sorting_result, threshold, noise_levels=None): return cls(mask, unit_ids, channel_ids) @classmethod - def from_energy(cls, sorting_result, threshold): + def from_energy(cls, sorting_analyzer, threshold): """ Construct sparsity from a threshold based on per channel energy ratio. Use the "threshold" argument to specify the SNR threshold. """ - assert sorting_result.sparsity is None, "To compute sparsity with energy you need a dense SortingResult" + assert sorting_analyzer.sparsity is None, "To compute sparsity with energy you need a dense SortingAnalyzer" - mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") + mask = np.zeros((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") # noise_levels - ext = sorting_result.get_extension("noise_levels") + ext = sorting_analyzer.get_extension("noise_levels") assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" assert ext.params["return_scaled"], "To compute sparsity from snr you need return_scaled=True for extensions" noise_levels = ext.data["noise_levels"] # waveforms - ext_waveforms = sorting_result.get_extension("waveforms") + ext_waveforms = sorting_analyzer.get_extension("waveforms") assert ext_waveforms is not None, "To compute sparsity from energy you need to compute 'waveforms' first" namples = ext_waveforms.nbefore + ext_waveforms.nafter noise = np.sqrt(namples) * noise_levels - for unit_ind, unit_id in enumerate(sorting_result.unit_ids): + for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): wfs = ext_waveforms.get_waveforms_one_unit(unit_id, force_dense=True) energies = np.linalg.norm(wfs, axis=(0, 1)) chan_inds = np.nonzero(energies / (noise * np.sqrt(len(wfs))) >= threshold) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) + return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) @classmethod - def from_property(cls, sorting_result, by_property): + def from_property(cls, sorting_analyzer, by_property): """ Construct sparsity witha property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. """ # check consistency assert ( - by_property in sorting_result.recording.get_property_keys() + by_property in sorting_analyzer.recording.get_property_keys() ), f"Property {by_property} is not a recording property" assert ( - by_property in sorting_result.sorting.get_property_keys() + by_property in sorting_analyzer.sorting.get_property_keys() ), f"Property {by_property} is not a sorting property" - mask = np.zeros((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") - rec_by = sorting_result.recording.split_by(by_property) - for unit_ind, unit_id in enumerate(sorting_result.unit_ids): - unit_property = sorting_result.sorting.get_property(by_property)[unit_ind] + mask = np.zeros((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") + rec_by = sorting_analyzer.recording.split_by(by_property) + for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): + unit_property = sorting_analyzer.sorting.get_property(by_property)[unit_ind] assert ( unit_property in rec_by.keys() ), f"Unit property {unit_property} cannot be found in the recording properties" - chan_inds = sorting_result.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) + chan_inds = sorting_analyzer.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) + return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) @classmethod - def create_dense(cls, sorting_result): + def create_dense(cls, sorting_analyzer): """ Create a sparsity object with all selected channel for all units. """ - mask = np.ones((sorting_result.unit_ids.size, sorting_result.channel_ids.size), dtype="bool") - return cls(mask, sorting_result.unit_ids, sorting_result.channel_ids) + mask = np.ones((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") + return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) def compute_sparsity( - templates_or_sorting_result, + templates_or_sorting_analyzer, noise_levels=None, method="radius", peak_sign="neg", @@ -463,10 +463,10 @@ def compute_sparsity( Parameters ---------- - templates_or_sorting_result: Templates | SortingResult - A Templates or a SortingResult object. + templates_or_sorting_analyzer: Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. Some methods accept both objects ("best_channels", "radius", ) - Other methods require only SortingResult because internally the recording is needed. + Other methods require only SortingAnalyzer because internally the recording is needed. {} @@ -479,50 +479,50 @@ def compute_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates from .waveforms_extractor_backwards_compatibility import MockWaveformExtractor - from .sortingresult import SortingResult + from .sortinganalyzer import SortingAnalyzer - if isinstance(templates_or_sorting_result, MockWaveformExtractor): + if isinstance(templates_or_sorting_analyzer, MockWaveformExtractor): # to keep backward compatibility - templates_or_sorting_result = templates_or_sorting_result.sorting_result + templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer if method in ("best_channels", "radius", "snr", "ptp"): assert isinstance( - templates_or_sorting_result, (Templates, SortingResult) - ), f"compute_sparsity(method='{method}') need Templates or SortingResult" + templates_or_sorting_analyzer, (Templates, SortingAnalyzer) + ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" else: assert isinstance( - templates_or_sorting_result, SortingResult - ), f"compute_sparsity(method='{method}') need SortingResult" + templates_or_sorting_analyzer, SortingAnalyzer + ), f"compute_sparsity(method='{method}') need SortingAnalyzer" - if method in ("snr", "ptp") and isinstance(templates_or_sorting_result, Templates): + if method in ("snr", "ptp") and isinstance(templates_or_sorting_analyzer, Templates): assert ( noise_levels is not None ), f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" - sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_result, num_channels, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(templates_or_sorting_result, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates_or_sorting_analyzer, radius_um, peak_sign=peak_sign) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( - templates_or_sorting_result, threshold, noise_levels=noise_levels, peak_sign=peak_sign + templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, peak_sign=peak_sign ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_ptp( - templates_or_sorting_result, + templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_energy(templates_or_sorting_result, threshold) + sparsity = ChannelSparsity.from_energy(templates_or_sorting_analyzer, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(templates_or_sorting_result, by_property) + sparsity = ChannelSparsity.from_property(templates_or_sorting_analyzer, by_property) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -545,7 +545,7 @@ def estimate_sparsity( **job_kwargs, ): """ - Estimate the sparsity without needing a SortingResult or Templates object + Estimate the sparsity without needing a SortingAnalyzer or Templates object This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it traverses the recording to compute the average templates for each unit. diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 098b8d7237..509c810d94 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -4,14 +4,14 @@ from .template import Templates from .sparsity import _sparsity_doc -from .sortingresult import SortingResult +from .sortinganalyzer import SortingAnalyzer # TODO make this function a non private function def _get_dense_templates_array(one_object, return_scaled=True): if isinstance(one_object, Templates): templates_array = one_object.get_dense_templates() - elif isinstance(one_object, SortingResult): + elif isinstance(one_object, SortingAnalyzer): ext = one_object.get_extension("templates") if ext is not None: templates_array = ext.data["average"] @@ -26,9 +26,9 @@ def _get_dense_templates_array(one_object, return_scaled=True): if ext is not None: templates_array = ext.data["average"] else: - raise ValueError("SortingResult need extension 'templates' or 'fast_templates' to be computed") + raise ValueError("SortingAnalyzer need extension 'templates' or 'fast_templates' to be computed") else: - raise ValueError("Input should be Templates or SortingResult or SortingResult") + raise ValueError("Input should be Templates or SortingAnalyzer or SortingAnalyzer") return templates_array @@ -36,20 +36,20 @@ def _get_dense_templates_array(one_object, return_scaled=True): def _get_nbefore(one_object): if isinstance(one_object, Templates): return one_object.nbefore - elif isinstance(one_object, SortingResult): + elif isinstance(one_object, SortingAnalyzer): ext = one_object.get_extension("templates") if ext is not None: return ext.nbefore ext = one_object.get_extension("fast_templates") if ext is not None: return ext.nbefore - raise ValueError("SortingResult need extension 'templates' or 'fast_templates' to be computed") + raise ValueError("SortingAnalyzer need extension 'templates' or 'fast_templates' to be computed") else: - raise ValueError("Input should be Templates or SortingResult or SortingResult") + raise ValueError("Input should be Templates or SortingAnalyzer or SortingAnalyzer") def get_template_amplitudes( - templates_or_sorting_result, + templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", return_scaled: bool = True, @@ -59,8 +59,8 @@ def get_template_amplitudes( Parameters ---------- - templates_or_sorting_result: Templates | SortingResult - A Templates or a SortingResult object + templates_or_sorting_analyzer: Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" @@ -77,12 +77,12 @@ def get_template_amplitudes( assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" - unit_ids = templates_or_sorting_result.unit_ids - before = _get_nbefore(templates_or_sorting_result) + unit_ids = templates_or_sorting_analyzer.unit_ids + before = _get_nbefore(templates_or_sorting_analyzer) peak_values = {} - templates_array = _get_dense_templates_array(templates_or_sorting_result, return_scaled=return_scaled) + templates_array = _get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] @@ -108,7 +108,7 @@ def get_template_amplitudes( def get_template_extremum_channel( - templates_or_sorting_result, + templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", outputs: "id" | "index" = "id", @@ -118,8 +118,8 @@ def get_template_extremum_channel( Parameters ---------- - templates_or_sorting_result: Templates | SortingResult - A Templates or a SortingResult object + templates_or_sorting_analyzer: Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "extremum" @@ -139,10 +139,10 @@ def get_template_extremum_channel( assert mode in ("extremum", "at_index") assert outputs in ("id", "index") - unit_ids = templates_or_sorting_result.unit_ids - channel_ids = templates_or_sorting_result.channel_ids + unit_ids = templates_or_sorting_analyzer.unit_ids + channel_ids = templates_or_sorting_analyzer.channel_ids - peak_values = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign, mode=mode) + peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode) extremum_channels_id = {} extremum_channels_index = {} for unit_id in unit_ids: @@ -156,7 +156,7 @@ def get_template_extremum_channel( return extremum_channels_index -def get_template_extremum_channel_peak_shift(templates_or_sorting_result, peak_sign: "neg" | "pos" | "both" = "neg"): +def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg"): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. @@ -164,8 +164,8 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_result, peak_s Parameters ---------- - templates_or_sorting_result: Templates | SortingResult - A Templates or a SortingResult object + templates_or_sorting_analyzer: Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels @@ -174,15 +174,15 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_result, peak_s shifts: dict Dictionary with unit ids as keys and shifts as values """ - unit_ids = templates_or_sorting_result.unit_ids - channel_ids = templates_or_sorting_result.channel_ids - nbefore = _get_nbefore(templates_or_sorting_result) + unit_ids = templates_or_sorting_analyzer.unit_ids + channel_ids = templates_or_sorting_analyzer.channel_ids + nbefore = _get_nbefore(templates_or_sorting_analyzer) - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_result, peak_sign=peak_sign) + extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign) shifts = {} - templates_array = _get_dense_templates_array(templates_or_sorting_result) + templates_array = _get_dense_templates_array(templates_or_sorting_analyzer) for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] @@ -203,7 +203,7 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_result, peak_s def get_template_extremum_amplitude( - templates_or_sorting_result, + templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index", ): @@ -212,8 +212,8 @@ def get_template_extremum_amplitude( Parameters ---------- - templates_or_sorting_result: Templates | SortingResult - A Templates or a SortingResult object + templates_or_sorting_analyzer: Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object peak_sign: "neg" | "pos" | "both" Sign of the template to compute best channels mode: "extremum" | "at_index", default: "at_index" @@ -228,12 +228,12 @@ def get_template_extremum_amplitude( """ assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" - unit_ids = templates_or_sorting_result.unit_ids - channel_ids = templates_or_sorting_result.channel_ids + unit_ids = templates_or_sorting_analyzer.unit_ids + channel_ids = templates_or_sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_result, peak_sign=peak_sign, mode=mode) + extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode) - extremum_amplitudes = get_template_amplitudes(templates_or_sorting_result, peak_sign=peak_sign, mode=mode) + extremum_amplitudes = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode) unit_amplitudes = {} for unit_id in unit_ids: diff --git a/src/spikeinterface/core/tests/test_result_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py similarity index 63% rename from src/spikeinterface/core/tests/test_result_core.py rename to src/spikeinterface/core/tests/test_analyzer_extension_core.py index 4c562ffb0d..cb7450d561 100644 --- a/src/spikeinterface/core/tests/test_result_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -4,7 +4,7 @@ import shutil from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import start_sorting_result +from spikeinterface.core import create_sorting_analyzer import numpy as np @@ -14,7 +14,7 @@ cache_folder = Path("cache_folder") / "core" -def get_sorting_result(format="memory", sparse=True): +def get_sorting_analyzer(format="memory", sparse=True): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, @@ -43,32 +43,32 @@ def get_sorting_result(format="memory", sparse=True): if folder and folder.exists(): shutil.rmtree(folder) - sorting_result = start_sorting_result( + sorting_analyzer = create_sorting_analyzer( sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=None ) - return sorting_result + return sorting_analyzer -def _check_result_extension(sorting_result, extension_name): +def _check_result_extension(sorting_analyzer, extension_name): # select unit_ids to several format for format in ("memory", "binary_folder", "zarr"): # for format in ("memory", ): if format != "memory": if format == "zarr": - folder = cache_folder / f"test_SortingResult_{extension_name}_select_units_with_{format}.zarr" + folder = cache_folder / f"test_SortingAnalyzer_{extension_name}_select_units_with_{format}.zarr" else: - folder = cache_folder / f"test_SortingResult_{extension_name}_select_units_with_{format}" + folder = cache_folder / f"test_SortingAnalyzer_{extension_name}_select_units_with_{format}" if folder.exists(): shutil.rmtree(folder) else: folder = None # check unit slice - keep_unit_ids = sorting_result.sorting.unit_ids[::2] - sorting_result2 = sorting_result.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) + keep_unit_ids = sorting_analyzer.sorting.unit_ids[::2] + sorting_analyzer2 = sorting_analyzer.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) - data = sorting_result2.get_extension(extension_name).data + data = sorting_analyzer2.get_extension(extension_name).data # for k, arr in data.items(): # print(k, arr.shape) @@ -76,31 +76,31 @@ def _check_result_extension(sorting_result, extension_name): @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeWaveforms(format, sparse): - sorting_result = get_sorting_result(format=format, sparse=sparse) + sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) - sorting_result.select_random_spikes(max_spikes_per_unit=50, seed=2205) - ext = sorting_result.compute("waveforms", **job_kwargs) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=50, seed=2205) + ext = sorting_analyzer.compute("waveforms", **job_kwargs) wfs = ext.data["waveforms"] - _check_result_extension(sorting_result, "waveforms") + _check_result_extension(sorting_analyzer, "waveforms") @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeTemplates(format, sparse): - sorting_result = get_sorting_result(format=format, sparse=sparse) + sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) - sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=20, seed=2205) with pytest.raises(AssertionError): # This require "waveforms first and should trig an error - sorting_result.compute("templates") + sorting_analyzer.compute("templates") job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) - sorting_result.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("waveforms", **job_kwargs) # compute some operators - sorting_result.compute( + sorting_analyzer.compute( "templates", operators=[ "average", @@ -110,20 +110,20 @@ def test_ComputeTemplates(format, sparse): ) # ask for more operator later - ext = sorting_result.get_extension("templates") + ext = sorting_analyzer.get_extension("templates") templated_median = ext.get_templates(operator="median") templated_per_5 = ext.get_templates(operator="percentile", percentile=5.0) # they all should be in data - data = sorting_result.get_extension("templates").data + data = sorting_analyzer.get_extension("templates").data for k in ["average", "std", "median", "pencentile_5.0", "pencentile_95.0"]: assert k in data.keys() - assert data[k].shape[0] == sorting_result.unit_ids.size - assert data[k].shape[2] == sorting_result.channel_ids.size + assert data[k].shape[0] == sorting_analyzer.unit_ids.size + assert data[k].shape[2] == sorting_analyzer.channel_ids.size assert np.any(data[k] > 0) # import matplotlib.pyplot as plt - # for unit_index, unit_id in enumerate(sorting_result.unit_ids): + # for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): # fig, ax = plt.subplots() # for k in data.keys(): # wf0 = data[k][unit_index, :, :] @@ -131,13 +131,13 @@ def test_ComputeTemplates(format, sparse): # ax.legend() # plt.show() - _check_result_extension(sorting_result, "templates") + _check_result_extension(sorting_analyzer, "templates") @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeFastTemplates(format, sparse): - sorting_result = get_sorting_result(format=format, sparse=sparse) + sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) # TODO check this because this is not passing with n_jobs=2 job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) @@ -145,29 +145,29 @@ def test_ComputeFastTemplates(format, sparse): ms_before = 1.0 ms_after = 2.5 - sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) - sorting_result.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=20, seed=2205) + sorting_analyzer.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) - _check_result_extension(sorting_result, "fast_templates") + _check_result_extension(sorting_analyzer, "fast_templates") # compare ComputeTemplates with dense and ComputeFastTemplates: should give the same on "average" - other_sorting_result = get_sorting_result(format=format, sparse=False) - other_sorting_result.select_random_spikes(max_spikes_per_unit=20, seed=2205) - other_sorting_result.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) - other_sorting_result.compute( + other_sorting_analyzer = get_sorting_analyzer(format=format, sparse=False) + other_sorting_analyzer.select_random_spikes(max_spikes_per_unit=20, seed=2205) + other_sorting_analyzer.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) + other_sorting_analyzer.compute( "templates", operators=[ "average", ], ) - templates0 = sorting_result.get_extension("fast_templates").data["average"] - templates1 = other_sorting_result.get_extension("templates").data["average"] + templates0 = sorting_analyzer.get_extension("fast_templates").data["average"] + templates1 = other_sorting_analyzer.get_extension("templates").data["average"] np.testing.assert_almost_equal(templates0, templates1) # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # for unit_index, unit_id in enumerate(sorting_result.unit_ids): + # for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): # wf0 = templates0[unit_index, :, :] # ax.plot(wf0.T.flatten(), label=f"{unit_id}") # wf1 = templates1[unit_index, :, :] @@ -179,13 +179,13 @@ def test_ComputeFastTemplates(format, sparse): @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeNoiseLevels(format, sparse): - sorting_result = get_sorting_result(format=format, sparse=sparse) + sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) - sorting_result.compute("noise_levels", return_scaled=True) - print(sorting_result) + sorting_analyzer.compute("noise_levels", return_scaled=True) + print(sorting_analyzer) - noise_levels = sorting_result.get_extension("noise_levels").data["noise_levels"] - assert noise_levels.shape[0] == sorting_result.channel_ids.size + noise_levels = sorting_analyzer.get_extension("noise_levels").data["noise_levels"] + assert noise_levels.shape[0] == sorting_analyzer.channel_ids.size if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index a4ae651de6..effd116d44 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface import start_sorting_result, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording # from spikeinterface.sortingcomponents.peak_detection import detect_peaks @@ -77,10 +77,10 @@ def test_run_node_pipeline(): spikes = sorting.to_spike_vector() # create peaks from spikes - sorting_result = start_sorting_result(sorting, recording, format="memory") - sorting_result.select_random_spikes() - sorting_result.compute("fast_templates") - extremum_channel_inds = get_template_extremum_channel(sorting_result, peak_sign="neg", outputs="index") + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("fast_templates") + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) diff --git a/src/spikeinterface/core/tests/test_sortingresult.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py similarity index 53% rename from src/spikeinterface/core/tests/test_sortingresult.py rename to src/spikeinterface/core/tests/test_sortinganalyzer.py index e6d7872396..1ae0d193e6 100644 --- a/src/spikeinterface/core/tests/test_sortingresult.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -4,8 +4,8 @@ import shutil from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import SortingResult, start_sorting_result, load_sorting_result -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core import SortingAnalyzer, create_sorting_analyzer, load_sorting_analyzer +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension import numpy as np @@ -28,97 +28,97 @@ def get_dataset(): return recording, sorting -def test_SortingResult_memory(): +def test_SortingAnalyzer_memory(): recording, sorting = get_dataset() - sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=False, sparsity=None) - _check_sorting_results(sorting_result, sorting) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting) - sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True, sparsity=None) - _check_sorting_results(sorting_result, sorting) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting) -def test_SortingResult_binary_folder(): +def test_SortingAnalyzer_binary_folder(): recording, sorting = get_dataset() - folder = cache_folder / "test_SortingResult_binary_folder" + folder = cache_folder / "test_SortingAnalyzer_binary_folder" if folder.exists(): shutil.rmtree(folder) - sorting_result = start_sorting_result( + sorting_analyzer = create_sorting_analyzer( sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None ) - sorting_result = load_sorting_result(folder, format="auto") - _check_sorting_results(sorting_result, sorting) + sorting_analyzer = load_sorting_analyzer(folder, format="auto") + _check_sorting_analyzers(sorting_analyzer, sorting) -def test_SortingResult_zarr(): +def test_SortingAnalyzer_zarr(): recording, sorting = get_dataset() - folder = cache_folder / "test_SortingResult_zarr.zarr" + folder = cache_folder / "test_SortingAnalyzer_zarr.zarr" if folder.exists(): shutil.rmtree(folder) - sorting_result = start_sorting_result(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) - sorting_result = load_sorting_result(folder, format="auto") - _check_sorting_results(sorting_result, sorting) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) + sorting_analyzer = load_sorting_analyzer(folder, format="auto") + _check_sorting_analyzers(sorting_analyzer, sorting) -def _check_sorting_results(sorting_result, original_sorting): +def _check_sorting_analyzers(sorting_analyzer, original_sorting): print() - print(sorting_result) + print(sorting_analyzer) register_result_extension(DummyResultExtension) - assert "channel_ids" in sorting_result.rec_attributes - assert "sampling_frequency" in sorting_result.rec_attributes - assert "num_samples" in sorting_result.rec_attributes + assert "channel_ids" in sorting_analyzer.rec_attributes + assert "sampling_frequency" in sorting_analyzer.rec_attributes + assert "num_samples" in sorting_analyzer.rec_attributes - probe = sorting_result.get_probe() - sparsity = sorting_result.sparsity + probe = sorting_analyzer.get_probe() + sparsity = sorting_analyzer.sparsity # compute - sorting_result.compute("dummy", param1=5.5) + sorting_analyzer.compute("dummy", param1=5.5) # equivalent - compute_dummy(sorting_result, param1=5.5) - ext = sorting_result.get_extension("dummy") + compute_dummy(sorting_analyzer, param1=5.5) + ext = sorting_analyzer.get_extension("dummy") assert ext is not None assert ext.params["param1"] == 5.5 - print(sorting_result) + print(sorting_analyzer) # recompute - sorting_result.compute("dummy", param1=5.5) + sorting_analyzer.compute("dummy", param1=5.5) # and delete - sorting_result.delete_extension("dummy") - ext = sorting_result.get_extension("dummy") + sorting_analyzer.delete_extension("dummy") + ext = sorting_analyzer.get_extension("dummy") assert ext is None - assert sorting_result.has_recording() + assert sorting_analyzer.has_recording() - if sorting_result.random_spikes_indices is None: - sorting_result.select_random_spikes(max_spikes_per_unit=10, seed=2205) - assert sorting_result.random_spikes_indices is not None - assert sorting_result.random_spikes_indices.size == 10 * sorting_result.sorting.unit_ids.size + if sorting_analyzer.random_spikes_indices is None: + sorting_analyzer.select_random_spikes(max_spikes_per_unit=10, seed=2205) + assert sorting_analyzer.random_spikes_indices is not None + assert sorting_analyzer.random_spikes_indices.size == 10 * sorting_analyzer.sorting.unit_ids.size # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": if format == "zarr": - folder = cache_folder / f"test_SortingResult_save_as_{format}.zarr" + folder = cache_folder / f"test_SortingAnalyzer_save_as_{format}.zarr" else: - folder = cache_folder / f"test_SortingResult_save_as_{format}" + folder = cache_folder / f"test_SortingAnalyzer_save_as_{format}" if folder.exists(): shutil.rmtree(folder) else: folder = None # compute one extension to check the save - sorting_result.compute("dummy") + sorting_analyzer.compute("dummy") - sorting_result2 = sorting_result.save_as(format=format, folder=folder) - ext = sorting_result2.get_extension("dummy") + sorting_analyzer2 = sorting_analyzer.save_as(format=format, folder=folder) + ext = sorting_analyzer2.get_extension("dummy") assert ext is not None - data = sorting_result2.get_extension("dummy").data + data = sorting_analyzer2.get_extension("dummy").data assert "result_one" in data assert data["result_two"].size == original_sorting.to_spike_vector().size @@ -126,27 +126,27 @@ def _check_sorting_results(sorting_result, original_sorting): for format in ("memory", "binary_folder", "zarr"): if format != "memory": if format == "zarr": - folder = cache_folder / f"test_SortingResult_select_units_with_{format}.zarr" + folder = cache_folder / f"test_SortingAnalyzer_select_units_with_{format}.zarr" else: - folder = cache_folder / f"test_SortingResult_select_units_with_{format}" + folder = cache_folder / f"test_SortingAnalyzer_select_units_with_{format}" if folder.exists(): shutil.rmtree(folder) else: folder = None # compute one extension to check the slice - sorting_result.compute("dummy") + sorting_analyzer.compute("dummy") keep_unit_ids = original_sorting.unit_ids[::2] - sorting_result2 = sorting_result.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) + sorting_analyzer2 = sorting_analyzer.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) # check that random_spikes_indices are remmaped - assert sorting_result2.random_spikes_indices is not None - some_spikes = sorting_result2.sorting.to_spike_vector()[sorting_result2.random_spikes_indices] + assert sorting_analyzer2.random_spikes_indices is not None + some_spikes = sorting_analyzer2.sorting.to_spike_vector()[sorting_analyzer2.random_spikes_indices] assert np.array_equal(np.unique(some_spikes["unit_index"]), np.arange(keep_unit_ids.size)) # check propagation of result data and correct sligin - assert np.array_equal(keep_unit_ids, sorting_result2.unit_ids) - data = sorting_result2.get_extension("dummy").data - assert data["result_one"] == sorting_result.get_extension("dummy").data["result_one"] + assert np.array_equal(keep_unit_ids, sorting_analyzer2.unit_ids) + data = sorting_analyzer2.get_extension("dummy").data + assert data["result_one"] == sorting_analyzer.get_extension("dummy").data["result_one"] # unit 1, 3, ... should be removed assert np.all(~np.isin(data["result_two"], [1, 3])) @@ -167,13 +167,13 @@ def _run(self, **kwargs): self.data["result_one"] = "abcd" # the result two has the same size of the spike vector!! # and represent nothing (the trick is to use unit_index for testing slice) - spikes = self.sorting_result.sorting.to_spike_vector() + spikes = self.sorting_analyzer.sorting.to_spike_vector() self.data["result_two"] = spikes["unit_index"].copy() def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - spikes = self.sorting_result.sorting.to_spike_vector() + spikes = self.sorting_analyzer.sorting.to_spike_vector() keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) # here the first key do not depend on unit_id # but the second need to be sliced!! @@ -205,7 +205,7 @@ def test_extension(): if __name__ == "__main__": - test_SortingResult_memory() - test_SortingResult_binary_folder() - test_SortingResult_zarr() + test_SortingAnalyzer_memory() + test_SortingAnalyzer_binary_folder() + test_SortingAnalyzer_zarr() test_extension() diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 361a06ece0..d650932162 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -6,7 +6,7 @@ from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, Templates from spikeinterface.core.core_tools import check_json from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import start_sorting_result +from spikeinterface.core import create_sorting_analyzer def test_ChannelSparsity(): @@ -199,24 +199,24 @@ def test_estimate_sparsity(): def test_compute_sparsity(): recording, sorting = get_dataset() - sorting_result = start_sorting_result(sorting=sorting, recording=recording, sparse=False) - sorting_result.select_random_spikes() - sorting_result.compute("fast_templates", return_scaled=True) - sorting_result.compute("noise_levels", return_scaled=True) + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False) + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("fast_templates", return_scaled=True) + sorting_analyzer.compute("noise_levels", return_scaled=True) # this is needed for method="energy" - sorting_result.compute("waveforms", return_scaled=True) + sorting_analyzer.compute("waveforms", return_scaled=True) - # using object SortingResult - sparsity = compute_sparsity(sorting_result, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(sorting_result, method="radius", radius_um=50.0, peak_sign="neg") - sparsity = compute_sparsity(sorting_result, method="snr", threshold=5, peak_sign="neg") - sparsity = compute_sparsity(sorting_result, method="ptp", threshold=5) - sparsity = compute_sparsity(sorting_result, method="energy", threshold=5) - sparsity = compute_sparsity(sorting_result, method="by_property", by_property="group") + # using object SortingAnalyzer + sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) + sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) + sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") # using object Templates - templates = sorting_result.get_extension("fast_templates").get_data(outputs="Templates") - noise_levels = sorting_result.get_extension("noise_levels").get_data() + templates = sorting_analyzer.get_extension("fast_templates").get_data(outputs="Templates") + noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index db15bbfbea..d936674ed5 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -1,6 +1,6 @@ import pytest -from spikeinterface.core import generate_ground_truth_recording, start_sorting_result +from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer from spikeinterface import Templates @@ -12,7 +12,7 @@ ) -def get_sorting_result(): +def get_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[10.0, 5.0], sampling_frequency=10_000.0, @@ -25,52 +25,52 @@ def get_sorting_result(): recording.set_channel_groups([0, 0, 1, 1]) sorting.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) - sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=False) - sorting_result.select_random_spikes() - sorting_result.compute("fast_templates") + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("fast_templates") - return sorting_result + return sorting_analyzer @pytest.fixture(scope="module") -def sorting_result(): - return get_sorting_result() +def sorting_analyzer(): + return get_sorting_analyzer() -def _get_templates_object_from_sorting_result(sorting_result): - ext = sorting_result.get_extension("fast_templates") +def _get_templates_object_from_sorting_analyzer(sorting_analyzer): + ext = sorting_analyzer.get_extension("fast_templates") templates = Templates( templates_array=ext.data["average"], - sampling_frequency=sorting_result.sampling_frequency, + sampling_frequency=sorting_analyzer.sampling_frequency, nbefore=ext.nbefore, # this is dense sparsity_mask=None, - channel_ids=sorting_result.channel_ids, - unit_ids=sorting_result.unit_ids, + channel_ids=sorting_analyzer.channel_ids, + unit_ids=sorting_analyzer.unit_ids, ) return templates -def test_get_template_amplitudes(sorting_result): - peak_values = get_template_amplitudes(sorting_result) +def test_get_template_amplitudes(sorting_analyzer): + peak_values = get_template_amplitudes(sorting_analyzer) print(peak_values) - templates = _get_templates_object_from_sorting_result(sorting_result) + templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) peak_values = get_template_amplitudes(templates) print(peak_values) -def test_get_template_extremum_channel(sorting_result): - extremum_channels_ids = get_template_extremum_channel(sorting_result, peak_sign="both") +def test_get_template_extremum_channel(sorting_analyzer): + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign="both") print(extremum_channels_ids) - templates = _get_templates_object_from_sorting_result(sorting_result) + templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) extremum_channels_ids = get_template_extremum_channel(templates, peak_sign="both") print(extremum_channels_ids) -def test_get_template_extremum_channel_peak_shift(sorting_result): - shifts = get_template_extremum_channel_peak_shift(sorting_result, peak_sign="neg") +def test_get_template_extremum_channel_peak_shift(sorting_analyzer): + shifts = get_template_extremum_channel_peak_shift(sorting_analyzer, peak_sign="neg") print(shifts) - templates = _get_templates_object_from_sorting_result(sorting_result) + templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) shifts = get_template_extremum_channel_peak_shift(templates, peak_sign="neg") # DEBUG @@ -89,22 +89,22 @@ def test_get_template_extremum_channel_peak_shift(sorting_result): # plt.show() -def test_get_template_extremum_amplitude(sorting_result): +def test_get_template_extremum_amplitude(sorting_analyzer): - extremum_channels_ids = get_template_extremum_amplitude(sorting_result, peak_sign="both") + extremum_channels_ids = get_template_extremum_amplitude(sorting_analyzer, peak_sign="both") print(extremum_channels_ids) - templates = _get_templates_object_from_sorting_result(sorting_result) + templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) extremum_channels_ids = get_template_extremum_amplitude(templates, peak_sign="both") if __name__ == "__main__": # setup_module() - sorting_result = get_sorting_result() - print(sorting_result) + sorting_analyzer = get_sorting_analyzer() + print(sorting_analyzer) - test_get_template_amplitudes(sorting_result) - test_get_template_extremum_channel(sorting_result) - test_get_template_extremum_channel_peak_shift(sorting_result) - test_get_template_extremum_amplitude(sorting_result) + test_get_template_amplitudes(sorting_analyzer) + test_get_template_extremum_channel(sorting_analyzer) + test_get_template_extremum_channel_peak_shift(sorting_analyzer) + test_get_template_extremum_amplitude(sorting_analyzer) diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index 538a96343b..dcf16bb804 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -81,8 +81,8 @@ def test_extract_waveforms(): # test reading old WaveformsExtractor folder folder = cache_folder / "old_waveforms_extractor" - sorting_result_from_we = load_waveforms_backwards(folder, output="SortingResult") - print(sorting_result_from_we) + sorting_analyzer_from_we = load_waveforms_backwards(folder, output="SortingAnalyzer") + print(sorting_analyzer_from_we) mock_loaded_we_old = load_waveforms_backwards(folder, output="MockWaveformExtractor") print(mock_loaded_we_old) @@ -90,16 +90,16 @@ def test_extract_waveforms(): @pytest.mark.skip() def test_read_old_waveforms_extractor_binary(): folder = "/data_local/DataSpikeSorting/waveform_extractor_backward_compatibility/waveforms_extractor_1" - sorting_result = _read_old_waveforms_extractor_binary(folder) + sorting_analyzer = _read_old_waveforms_extractor_binary(folder) - print(sorting_result) + print(sorting_analyzer) - for ext_name in sorting_result.get_loaded_extension_names(): + for ext_name in sorting_analyzer.get_loaded_extension_names(): print() print(ext_name) - keys = sorting_result.get_extension(ext_name).data.keys() + keys = sorting_analyzer.get_extension(ext_name).data.keys() print(keys) - data = sorting_result.get_extension(ext_name).get_data() + data = sorting_analyzer.get_extension(ext_name).get_data() if isinstance(data, np.ndarray): print(data.shape) diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py index 72247cb42a..2d6de6c8a0 100644 --- a/src/spikeinterface/core/tests/test_zarrextractors.py +++ b/src/spikeinterface/core/tests/test_zarrextractors.py @@ -30,7 +30,7 @@ def test_ZarrSortingExtractor(): sorting = ZarrSortingExtractor(folder) sorting = load_extractor(sorting.to_dict()) - # store the sorting in a sub group (for instance SortingResult) + # store the sorting in a sub group (for instance SortingAnalyzer) folder = cache_folder / "zarr_sorting_sub_group" if folder.is_dir(): shutil.rmtree(folder) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index b50c991744..be68473ea1 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -1,7 +1,7 @@ """ This module contains low-level functions to extract snippets of traces (aka "spike waveforms"). -This is internally used by SortingResult, but can also be used as a sorting component. +This is internally used by SortingAnalyzer, but can also be used as a sorting component. It is a 2-step approach: 1. allocate buffers (shared file or memory) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 412028c94f..ecb87c967e 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -1,7 +1,7 @@ """ This backwards compatibility module aims to: - * load old WaveformsExtractor saved with folder or zarr (version <=0.100) into the SortingResult (version>0.100) - * mock the function extract_waveforms() and the class SortingResult() but based SortingResult + * load old WaveformsExtractor saved with folder or zarr (version <=0.100) into the SortingAnalyzer (version>0.100) + * mock the function extract_waveforms() and the class SortingAnalyzer() but based SortingAnalyzer """ from __future__ import annotations @@ -18,16 +18,16 @@ from .baserecording import BaseRecording from .basesorting import BaseSorting -from .sortingresult import start_sorting_result, get_extension_class +from .sortinganalyzer import create_sorting_analyzer, get_extension_class from .job_tools import split_job_kwargs from .sparsity import ChannelSparsity -from .sortingresult import SortingResult, load_sorting_result +from .sortinganalyzer import SortingAnalyzer, load_sorting_analyzer from .base import load_extractor -from .result_core import ComputeWaveforms, ComputeTemplates +from .analyzer_extension_core import ComputeWaveforms, ComputeTemplates _backwards_compatibility_msg = """#### -# extract_waveforms() and WaveformExtractor() have been replace by SortingResult since version 0.101 -# You should use start_sorting_result() instead. +# extract_waveforms() and WaveformExtractor() have been replace by SortingAnalyzer since version 0.101 +# You should use create_sorting_analyzer() instead. # extract_waveforms() is now mocking the old behavior for backwards compatibility only and will be removed after 0.103 ####""" @@ -57,7 +57,7 @@ def extract_waveforms( ): """ This mock the extract_waveforms() in version <= 0.100 to not break old codes but using - the SortingResult (version >0.100) internally. + the SortingAnalyzer (version >0.100) internally. This return a MockWaveformExtractor object that mock the old WaveformExtractor """ @@ -92,74 +92,74 @@ def extract_waveforms( **other_kwargs, **job_kwargs, ) - sorting_result = start_sorting_result( + sorting_analyzer = create_sorting_analyzer( sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=sparsity, **sparsity_kwargs ) # TODO propagate job_kwargs - sorting_result.select_random_spikes(max_spikes_per_unit=max_spikes_per_unit, seed=seed) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=max_spikes_per_unit, seed=seed) waveforms_params = dict(ms_before=ms_before, ms_after=ms_after, return_scaled=return_scaled, dtype=dtype) - sorting_result.compute("waveforms", **waveforms_params, **job_kwargs) + sorting_analyzer.compute("waveforms", **waveforms_params, **job_kwargs) templates_params = dict(operators=list(precompute_template)) - sorting_result.compute("templates", **templates_params) + sorting_analyzer.compute("templates", **templates_params) # this also done because some metrics need it - sorting_result.compute("noise_levels") + sorting_analyzer.compute("noise_levels") - we = MockWaveformExtractor(sorting_result) + we = MockWaveformExtractor(sorting_analyzer) return we class MockWaveformExtractor: - def __init__(self, sorting_result): - self.sorting_result = sorting_result + def __init__(self, sorting_analyzer): + self.sorting_analyzer = sorting_analyzer def __repr__(self): txt = "MockWaveformExtractor: mock the old WaveformExtractor with " - txt += self.sorting_result.__repr__() + txt += self.sorting_analyzer.__repr__() return txt def is_sparse(self) -> bool: - return self.sorting_result.is_sparse() + return self.sorting_analyzer.is_sparse() def has_waveforms(self) -> bool: - return self.sorting_result.get_extension("waveforms") is not None + return self.sorting_analyzer.get_extension("waveforms") is not None def delete_waveforms(self) -> None: - self.sorting_result.delete_extension("waveforms") + self.sorting_analyzer.delete_extension("waveforms") @property def recording(self) -> BaseRecording: - return self.sorting_result.recording + return self.sorting_analyzer.recording @property def sorting(self) -> BaseSorting: - return self.sorting_result.sorting + return self.sorting_analyzer.sorting @property def channel_ids(self) -> np.ndarray: - return self.sorting_result.channel_ids + return self.sorting_analyzer.channel_ids @property def sampling_frequency(self) -> float: - return self.sorting_result.sampling_frequency + return self.sorting_analyzer.sampling_frequency @property def unit_ids(self) -> np.ndarray: - return self.sorting_result.unit_ids + return self.sorting_analyzer.unit_ids @property def nbefore(self) -> int: - ms_before = self.sorting_result.get_extension("waveforms").params["ms_before"] + ms_before = self.sorting_analyzer.get_extension("waveforms").params["ms_before"] return int(ms_before * self.sampling_frequency / 1000.0) @property def nafter(self) -> int: - ms_after = self.sorting_result.get_extension("waveforms").params["ms_after"] + ms_after = self.sorting_analyzer.get_extension("waveforms").params["ms_after"] return int(ms_after * self.sampling_frequency / 1000.0) @property @@ -168,71 +168,71 @@ def nsamples(self) -> int: @property def return_scaled(self) -> bool: - return self.sorting_result.get_extension("waveforms").params["return_scaled"] + return self.sorting_analyzer.get_extension("waveforms").params["return_scaled"] @property def dtype(self): - return self.sorting_result.get_extension("waveforms").params["dtype"] + return self.sorting_analyzer.get_extension("waveforms").params["dtype"] def is_read_only(self) -> bool: - return self.sorting_result.is_read_only() + return self.sorting_analyzer.is_read_only() def has_recording(self) -> bool: - return self.sorting_result._recording is not None + return self.sorting_analyzer._recording is not None def get_num_samples(self, segment_index: Optional[int] = None) -> int: - return self.sorting_result.get_num_samples(segment_index) + return self.sorting_analyzer.get_num_samples(segment_index) def get_total_samples(self) -> int: - return self.sorting_result.get_total_samples() + return self.sorting_analyzer.get_total_samples() def get_total_duration(self) -> float: - return self.sorting_result.get_total_duration() + return self.sorting_analyzer.get_total_duration() def get_num_channels(self) -> int: - return self.sorting_result.get_num_channels() + return self.sorting_analyzer.get_num_channels() def get_num_segments(self) -> int: - return self.sorting_result.get_num_segments() + return self.sorting_analyzer.get_num_segments() def get_probegroup(self): - return self.sorting_result.get_probegroup() + return self.sorting_analyzer.get_probegroup() def get_probe(self): - return self.sorting_result.get_probe() + return self.sorting_analyzer.get_probe() def is_filtered(self) -> bool: - return self.sorting_result.rec_attributes["is_filtered"] + return self.sorting_analyzer.rec_attributes["is_filtered"] def get_channel_locations(self) -> np.ndarray: - return self.sorting_result.get_channel_locations() + return self.sorting_analyzer.get_channel_locations() def channel_ids_to_indices(self, channel_ids) -> np.ndarray: - return self.sorting_result.channel_ids_to_indices(channel_ids) + return self.sorting_analyzer.channel_ids_to_indices(channel_ids) def get_recording_property(self, key) -> np.ndarray: - return self.sorting_result.get_recording_property(key) + return self.sorting_analyzer.get_recording_property(key) def get_sorting_property(self, key) -> np.ndarray: - return self.sorting_result.get_sorting_property(key) + return self.sorting_analyzer.get_sorting_property(key) @property def sparsity(self): - return self.sorting_result.sparsity + return self.sorting_analyzer.sparsity @property def folder(self): - if self.sorting_result.format != "memory": - return self.sorting_result.folder + if self.sorting_analyzer.format != "memory": + return self.sorting_analyzer.folder def has_extension(self, extension_name: str) -> bool: - return self.sorting_result.has_extension(extension_name) + return self.sorting_analyzer.has_extension(extension_name) def get_sampled_indices(self, unit_id): # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) with a complex dtype as follow selected_spikes = [] for segment_index in range(self.get_num_segments()): - inds = self.sorting_result.get_selected_indices_in_spike_train(unit_id, segment_index) + inds = self.sorting_analyzer.get_selected_indices_in_spike_train(unit_id, segment_index) sampled_index = np.zeros(inds.size, dtype=[("spike_index", "int64"), ("segment_index", "int64")]) sampled_index["spike_index"] = inds sampled_index["segment_index"][:] = segment_index @@ -249,28 +249,28 @@ def get_waveforms( force_dense: bool = False, ): # lazy and cache are ingnored - ext = self.sorting_result.get_extension("waveforms") + ext = self.sorting_analyzer.get_extension("waveforms") unit_index = self.sorting.id_to_index(unit_id) spikes = self.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] spike_mask = some_spikes["unit_index"] == unit_index wfs = ext.data["waveforms"][spike_mask, :, :] if sparsity is not None: assert ( - self.sorting_result.sparsity is None + self.sorting_analyzer.sparsity is None ), "Waveforms are alreayd sparse! Cannot apply an additional sparsity." wfs = wfs[:, :, sparsity.mask[self.sorting.id_to_index(unit_id)]] if force_dense: assert sparsity is None - if self.sorting_result.sparsity is None: + if self.sorting_analyzer.sparsity is None: # nothing to do pass else: num_channels = self.get_num_channels() dense_wfs = np.zeros((wfs.shape[0], wfs.shape[1], num_channels), dtype=np.float32) - unit_sparsity = self.sorting_result.sparsity.mask[unit_index] + unit_sparsity = self.sorting_analyzer.sparsity.mask[unit_index] dense_wfs[:, :, unit_sparsity] = wfs wfs = dense_wfs @@ -283,7 +283,7 @@ def get_waveforms( def get_all_templates( self, unit_ids: list | np.array | tuple | None = None, mode="average", percentile: float | None = None ): - ext = self.sorting_result.get_extension("templates") + ext = self.sorting_analyzer.get_extension("templates") if mode == "percentile": key = f"pencentile_{percentile}" @@ -315,7 +315,7 @@ def load_waveforms( output="MockWaveformExtractor", ): """ - This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingResult or MockWaveformExtractor. + This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingAnalyzer or MockWaveformExtractor. It also mimic the old load_waveforms by opening a Sortingresult folder and return a MockWaveformExtractor. This later behavior is usefull to no break old code like this in versio >=0.101 @@ -334,23 +334,23 @@ def load_waveforms( if (folder / "spikeinterface_info.json").exists: with open(folder / "spikeinterface_info.json", mode="r") as f: info = json.load(f) - if info.get("object", None) == "SortingResult": + if info.get("object", None) == "SortingAnalyzer": # in this case the folder is already a sorting result from version >= 0.101.0 but create with the MockWaveformExtractor - sorting_result = load_sorting_result(folder) - sorting_result.load_all_saved_extension() - we = MockWaveformExtractor(sorting_result) + sorting_analyzer = load_sorting_analyzer(folder) + sorting_analyzer.load_all_saved_extension() + we = MockWaveformExtractor(sorting_analyzer) return we if folder.suffix == ".zarr": raise NotImplementedError # Alessio this is for you else: - sorting_result = _read_old_waveforms_extractor_binary(folder) + sorting_analyzer = _read_old_waveforms_extractor_binary(folder) - if output == "SortingResult": - return sorting_result + if output == "SortingAnalyzer": + return sorting_analyzer elif output in ("WaveformExtractor", "MockWaveformExtractor"): - return MockWaveformExtractor(sorting_result) + return MockWaveformExtractor(sorting_analyzer) def _read_old_waveforms_extractor_binary(folder): @@ -398,7 +398,7 @@ def _read_old_waveforms_extractor_binary(folder): elif (folder / "sorting.pickle").exists(): sorting = load_extractor(folder / "sorting.pickle", base_folder=folder) - sorting_result = SortingResult.create_memory(sorting, recording, sparsity, rec_attributes=rec_attributes) + sorting_analyzer = SortingAnalyzer.create_memory(sorting, recording, sparsity, rec_attributes=rec_attributes) # waveforms # need to concatenate all waveforms in one unique buffer @@ -439,9 +439,9 @@ def _read_old_waveforms_extractor_binary(folder): mask = some_spikes["unit_index"] == unit_index waveforms[:, :, : wfs.shape[2]][mask, :, :] = wfs - sorting_result.random_spikes_indices = random_spikes_indices + sorting_analyzer.random_spikes_indices = random_spikes_indices - ext = ComputeWaveforms(sorting_result) + ext = ComputeWaveforms(sorting_analyzer) ext.params = dict( ms_before=params["ms_before"], ms_after=params["ms_after"], @@ -449,7 +449,7 @@ def _read_old_waveforms_extractor_binary(folder): dtype=params["dtype"], ) ext.data["waveforms"] = waveforms - sorting_result.extensions["waveforms"] = ext + sorting_analyzer.extensions["waveforms"] = ext # templates saved dense # load cached templates @@ -459,13 +459,13 @@ def _read_old_waveforms_extractor_binary(folder): if template_file.is_file(): templates[mode] = np.load(template_file) if len(templates) > 0: - ext = ComputeTemplates(sorting_result) + ext = ComputeTemplates(sorting_analyzer) ext.params = dict( nbefore=nbefore, nafter=nafter, return_scaled=params["return_scaled"], operators=list(templates.keys()) ) for mode, arr in templates.items(): ext.data[mode] = arr - sorting_result.extensions["templates"] = ext + sorting_analyzer.extensions["templates"] = ext # old extensions with same names and equvalent data except similarity>template_similarity old_extension_to_new_class = { @@ -486,7 +486,7 @@ def _read_old_waveforms_extractor_binary(folder): if not ext_folder.is_dir(): continue new_class = get_extension_class(new_name) - ext = new_class(sorting_result) + ext = new_class(sorting_analyzer) with open(ext_folder / "params.json", "r") as f: params = json.load(f) ext.params = params @@ -523,6 +523,6 @@ def _read_old_waveforms_extractor_binary(folder): # elif new_name == "principal_components": # # TODO: alessio this is for you # pass - sorting_result.extensions[new_name] = ext + sorting_analyzer.extensions[new_name] = ext - return sorting_result + return sorting_analyzer diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 106f8ccc1e..3f7962f214 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -255,7 +255,7 @@ def read_zarr( The loaded extractor """ # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is. - # for the futur SortingResult we will have this 2 fields!!! + # for the futur SortingAnalyzer we will have this 2 fields!!! root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) if "channel_ids" in root.keys(): return read_zarr_recording(folder_path, storage_options=storage_options) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index c77176b520..f509ecd6bf 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -2,7 +2,7 @@ import numpy as np -from ..core import start_sorting_result +from ..core import create_sorting_analyzer from ..core.template_tools import get_template_extremum_channel from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates @@ -11,7 +11,7 @@ def get_potential_auto_merge( - sorting_result, + sorting_analyzer, minimum_spikes=1000, maximum_distance_um=150.0, peak_sign="neg", @@ -57,8 +57,8 @@ def get_potential_auto_merge( Parameters ---------- - sorting_result: SortingResult - The SortingResult + sorting_analyzer: SortingAnalyzer + The SortingAnalyzer minimum_spikes: int, default: 1000 Minimum number of spikes for each unit to consider a potential merge. Enough spikes are needed to estimate the correlogram @@ -113,7 +113,7 @@ def get_potential_auto_merge( """ import scipy - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids # to get fast computation we will not analyse pairs when: @@ -144,7 +144,7 @@ def get_potential_auto_merge( # STEP 2 : remove contaminated auto corr if "remove_contaminated" in steps: contaminations, nb_violations = compute_refrac_period_violations( - sorting_result, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) @@ -154,9 +154,9 @@ def get_potential_auto_merge( # STEP 3 : unit positions are estimated roughly with channel if "unit_positions" in steps: - chan_loc = sorting_result.get_channel_locations() + chan_loc = sorting_analyzer.get_channel_locations() unit_max_chan = get_template_extremum_channel( - sorting_result, peak_sign=peak_sign, mode="extremum", outputs="index" + sorting_analyzer, peak_sign=peak_sign, mode="extremum", outputs="index" ) unit_max_chan = list(unit_max_chan.values()) unit_locations = chan_loc[unit_max_chan, :] @@ -189,7 +189,7 @@ def get_potential_auto_merge( # STEP 5 : check if potential merge with CC also have template similarity if "template_similarity" in steps: - templates = sorting_result.get_extension("templates").get_templates(operator="average") + templates = sorting_analyzer.get_extension("templates").get_templates(operator="average") templates_diff = compute_templates_diff( sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask ) @@ -198,7 +198,7 @@ def get_potential_auto_merge( # STEP 6 : validate the potential merges with CC increase the contamination quality metrics if "check_increase_score" in steps: pair_mask, pairs_decreased_score = check_improve_contaminations_score( - sorting_result, + sorting_analyzer, pair_mask, contaminations, firing_contamination_balance, @@ -429,7 +429,7 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair def check_improve_contaminations_score( - sorting_result, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms + sorting_analyzer, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms ): """ Check that the score is improve afeter a potential merge @@ -441,12 +441,12 @@ def check_improve_contaminations_score( Check that the contamination score is improved (decrease) after a potential merge """ - recording = sorting_result.recording - sorting = sorting_result.sorting + recording = sorting_analyzer.recording + sorting = sorting_analyzer.sorting pair_mask = pair_mask.copy() pairs_removed = [] - firing_rates = list(compute_firing_rates(sorting_result).values()) + firing_rates = list(compute_firing_rates(sorting_analyzer).values()) inds1, inds2 = np.nonzero(pair_mask) for i in range(inds1.size): @@ -464,13 +464,13 @@ def check_improve_contaminations_score( sorting, [[unit_id1, unit_id2]], new_unit_ids=[unit_id1], delta_time_ms=censored_period_ms ).select_units([unit_id1]) - sorting_result_new = start_sorting_result(sorting_merged, recording, format="memory", sparse=False) + sorting_analyzer_new = create_sorting_analyzer(sorting_merged, recording, format="memory", sparse=False) new_contaminations, _ = compute_refrac_period_violations( - sorting_result_new, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer_new, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms ) c_new = new_contaminations[unit_id1] - f_new = compute_firing_rates(sorting_result_new)[unit_id1] + f_new = compute_firing_rates(sorting_analyzer_new)[unit_id1] # old and new scores k = 1 + firing_contamination_balance diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index 11bf6b15e2..1d6fdd3ac1 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -1,7 +1,7 @@ from __future__ import annotations import numpy as np -from spikeinterface import SortingResult +from spikeinterface import SortingAnalyzer from ..core.template_tools import get_template_extremum_channel_peak_shift, get_template_amplitudes from ..postprocessing import align_sorting @@ -11,7 +11,7 @@ def remove_redundant_units( - sorting_or_sorting_result, + sorting_or_sorting_analyzer, align=True, unit_peak_shifts=None, delta_time=0.4, @@ -33,12 +33,12 @@ def remove_redundant_units( Parameters ---------- - sorting_or_sorting_result : BaseSorting or SortingResult - If SortingResult, the spike trains can be optionally realigned using the peak shift in the + sorting_or_sorting_analyzer : BaseSorting or SortingAnalyzer + If SortingAnalyzer, the spike trains can be optionally realigned using the peak shift in the template to improve the matching procedure. If BaseSorting, the spike trains are not aligned. align : bool, default: False - If True, spike trains are aligned (if a SortingResult is used) + If True, spike trains are aligned (if a SortingAnalyzer is used) delta_time : float, default: 0.4 The time in ms to consider matching spikes agreement_threshold : float, default: 0.2 @@ -65,17 +65,17 @@ def remove_redundant_units( Sorting object without redundant units """ - if isinstance(sorting_or_sorting_result, SortingResult): - sorting = sorting_or_sorting_result.sorting - sorting_result = sorting_or_sorting_result + if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): + sorting = sorting_or_sorting_analyzer.sorting + sorting_analyzer = sorting_or_sorting_analyzer else: - assert not align, "The 'align' option is only available when a SortingResult is used as input" - sorting = sorting_or_sorting_result - sorting_result = None + assert not align, "The 'align' option is only available when a SortingAnalyzer is used as input" + sorting = sorting_or_sorting_analyzer + sorting_analyzer = None if align and unit_peak_shifts is None: - assert sorting_result is not None, "For align=True must give a SortingResult or explicit unit_peak_shifts" - unit_peak_shifts = get_template_extremum_channel_peak_shift(sorting_result) + assert sorting_analyzer is not None, "For align=True must give a SortingAnalyzer or explicit unit_peak_shifts" + unit_peak_shifts = get_template_extremum_channel_peak_shift(sorting_analyzer) if align: sorting_aligned = align_sorting(sorting, unit_peak_shifts) @@ -93,7 +93,7 @@ def remove_redundant_units( if remove_strategy in ("minimum_shift", "highest_amplitude"): # this is the values at spike index ! - peak_values = get_template_amplitudes(sorting_result, peak_sign=peak_sign, mode="at_index") + peak_values = get_template_amplitudes(sorting_analyzer, peak_sign=peak_sign, mode="at_index") peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()} if remove_strategy == "minimum_shift": @@ -125,7 +125,7 @@ def remove_redundant_units( elif remove_strategy == "with_metrics": # TODO # @aurelien @alessio - # here sorting_result can implement the choice of the best one given an external metrics table + # here sorting_analyzer can implement the choice of the best one given an external metrics table # this will be implemented in a futur PR by the first who need it! raise NotImplementedError() else: diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index f14e08c45a..40a6e28e10 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -3,7 +3,7 @@ import pytest from pathlib import Path -from spikeinterface.core import generate_ground_truth_recording, start_sorting_result +from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer from spikeinterface.qualitymetrics import compute_quality_metrics if hasattr(pytest, "global_test_folder"): @@ -15,7 +15,7 @@ job_kwargs = dict(n_jobs=-1) -def make_sorting_result(sparse=True): +def make_sorting_analyzer(sparse=True): recording, sorting = generate_ground_truth_recording( durations=[300.0], sampling_frequency=30000.0, @@ -26,23 +26,23 @@ def make_sorting_result(sparse=True): seed=2205, ) - sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) - sorting_result.select_random_spikes() - sorting_result.compute("waveforms", **job_kwargs) - sorting_result.compute("templates") - sorting_result.compute("noise_levels") - # sorting_result.compute("principal_components") - # sorting_result.compute("template_similarity") - # sorting_result.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("noise_levels") + # sorting_analyzer.compute("principal_components") + # sorting_analyzer.compute("template_similarity") + # sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) - return sorting_result + return sorting_analyzer @pytest.fixture(scope="module") -def sorting_result_for_curation(): - return make_sorting_result(sparse=True) +def sorting_analyzer_for_curation(): + return make_sorting_analyzer(sparse=True) if __name__ == "__main__": - sorting_result = make_sorting_result(sparse=False) - print(sorting_result) + sorting_analyzer = make_sorting_analyzer(sparse=False) + print(sorting_analyzer) diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 66f1d6602f..4dd62a3178 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -4,12 +4,12 @@ import numpy as np -from spikeinterface.core import start_sorting_result +from spikeinterface.core import create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units from spikeinterface.curation import get_potential_auto_merge -from spikeinterface.curation.tests.common import make_sorting_result, sorting_result_for_curation +from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation if hasattr(pytest, "global_test_folder"): @@ -18,10 +18,10 @@ cache_folder = Path("cache_folder") / "curation" -def test_get_auto_merge_list(sorting_result_for_curation): +def test_get_auto_merge_list(sorting_analyzer_for_curation): - sorting = sorting_result_for_curation.sorting - recording = sorting_result_for_curation.recording + sorting = sorting_analyzer_for_curation.sorting + recording = sorting_analyzer_for_curation.recording num_unit_splited = 1 num_split = 2 @@ -35,13 +35,13 @@ def test_get_auto_merge_list(sorting_result_for_curation): job_kwargs = dict(n_jobs=-1) - sorting_result = start_sorting_result(sorting_with_split, recording, format="memory") - sorting_result.select_random_spikes() - sorting_result.compute("waveforms", **job_kwargs) - sorting_result.compute("templates") + sorting_analyzer = create_sorting_analyzer(sorting_with_split, recording, format="memory") + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") potential_merges, outs = get_potential_auto_merge( - sorting_result, + sorting_analyzer, minimum_spikes=1000, maximum_distance_um=150.0, peak_sign="neg", @@ -119,5 +119,5 @@ def test_get_auto_merge_list(sorting_result_for_curation): if __name__ == "__main__": - sorting_result = make_sorting_result(sparse=True) - test_get_auto_merge_list(sorting_result) + sorting_analyzer = make_sorting_analyzer(sparse=True) + test_get_auto_merge_list(sorting_analyzer) diff --git a/src/spikeinterface/curation/tests/test_remove_redundant.py b/src/spikeinterface/curation/tests/test_remove_redundant.py index 5a0f15f6e4..2877442cef 100644 --- a/src/spikeinterface/curation/tests/test_remove_redundant.py +++ b/src/spikeinterface/curation/tests/test_remove_redundant.py @@ -6,37 +6,37 @@ import numpy as np -from spikeinterface import start_sorting_result +from spikeinterface import create_sorting_analyzer from spikeinterface.core.generate import inject_some_duplicate_units -from spikeinterface.curation.tests.common import make_sorting_result, sorting_result_for_curation +from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation from spikeinterface.curation import remove_redundant_units -def test_remove_redundant_units(sorting_result_for_curation): +def test_remove_redundant_units(sorting_analyzer_for_curation): - sorting = sorting_result_for_curation.sorting - recording = sorting_result_for_curation.recording + sorting = sorting_analyzer_for_curation.sorting + recording = sorting_analyzer_for_curation.recording sorting_with_dup = inject_some_duplicate_units(sorting, ratio=0.8, num=4, seed=2205) # print(sorting.unit_ids) # print(sorting_with_dup.unit_ids) job_kwargs = dict(n_jobs=-1) - sorting_result = start_sorting_result(sorting_with_dup, recording, format="memory") - sorting_result.select_random_spikes() - sorting_result.compute("waveforms", **job_kwargs) - sorting_result.compute("templates") + sorting_analyzer = create_sorting_analyzer(sorting_with_dup, recording, format="memory") + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") for remove_strategy in ("max_spikes", "minimum_shift", "highest_amplitude"): - sorting_clean = remove_redundant_units(sorting_result, remove_strategy=remove_strategy) + sorting_clean = remove_redundant_units(sorting_analyzer, remove_strategy=remove_strategy) # print(sorting_clean) # print(sorting_clean.unit_ids) assert np.array_equal(sorting_clean.unit_ids, sorting.unit_ids) if __name__ == "__main__": - sorting_result = make_sorting_result(sparse=True) - test_remove_redundant_units(sorting_result) + sorting_analyzer = make_sorting_analyzer(sparse=True) + test_remove_redundant_units(sorting_analyzer) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 44191a2bed..5e2d47fb60 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -37,17 +37,17 @@ # local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") # recording, sorting = read_mearec(local_path) -# sorting_result = si.start_sorting_result(sorting, recording, format="memory") -# sorting_result.select_random_spikes() -# sorting_result.compute("waveforms") -# sorting_result.compute("templates") -# sorting_result.compute("noise_levels") -# sorting_result.compute("spike_amplitudes") -# sorting_result.compute("template_similarity") -# sorting_result.compute("unit_locations") +# sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="memory") +# sorting_analyzer.select_random_spikes() +# sorting_analyzer.compute("waveforms") +# sorting_analyzer.compute("templates") +# sorting_analyzer.compute("noise_levels") +# sorting_analyzer.compute("spike_amplitudes") +# sorting_analyzer.compute("template_similarity") +# sorting_analyzer.compute("unit_locations") # # plot_sorting_summary with curation -# w = sw.plot_sorting_summary(sorting_result, curation=True, backend="sortingview") +# w = sw.plot_sorting_summary(sorting_analyzer, curation=True, backend="sortingview") # # curation_link: # # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 9c85a3c860..d375996945 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -11,7 +11,7 @@ def export_report( - sorting_result, + sorting_analyzer, output_folder, remove_if_exists=False, format="png", @@ -28,8 +28,8 @@ def export_report( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object output_folder: str The output folder where the report files are saved remove_if_exists: bool, default: False @@ -48,15 +48,15 @@ def export_report( import matplotlib.pyplot as plt job_kwargs = fix_job_kwargs(job_kwargs) - sorting = sorting_result.sorting - unit_ids = sorting_result.unit_ids + sorting = sorting_analyzer.sorting + unit_ids = sorting_analyzer.unit_ids # load or compute spike_amplitudes - if sorting_result.has_extension("spike_amplitudes"): - spike_amplitudes = sorting_result.get_extension("spike_amplitudes").get_data(outputs="by_unit") + if sorting_analyzer.has_extension("spike_amplitudes"): + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit") elif force_computation: - sorting_result.compute("spike_amplitudes", **job_kwargs) - spike_amplitudes = sorting_result.get_extension("spike_amplitudes").get_data(outputs="by_unit") + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit") else: spike_amplitudes = None print( @@ -64,11 +64,11 @@ def export_report( ) # load or compute quality_metrics - if sorting_result.has_extension("quality_metrics"): - metrics = sorting_result.get_extension("quality_metrics").get_data() + if sorting_analyzer.has_extension("quality_metrics"): + metrics = sorting_analyzer.get_extension("quality_metrics").get_data() elif force_computation: - sorting_result.compute("quality_metrics") - metrics = sorting_result.get_extension("quality_metrics").get_data() + sorting_analyzer.compute("quality_metrics") + metrics = sorting_analyzer.get_extension("quality_metrics").get_data() else: metrics = None print( @@ -76,10 +76,10 @@ def export_report( ) # load or compute correlograms - if sorting_result.has_extension("correlograms"): - correlograms, bins = sorting_result.get_extension("correlograms").get_data() + if sorting_analyzer.has_extension("correlograms"): + correlograms, bins = sorting_analyzer.get_extension("correlograms").get_data() elif force_computation: - correlograms, bins = compute_correlograms(sorting_result, window_ms=100.0, bin_ms=1.0) + correlograms, bins = compute_correlograms(sorting_analyzer, window_ms=100.0, bin_ms=1.0) else: correlograms = None print( @@ -87,8 +87,8 @@ def export_report( ) # pre-compute unit locations if not done - if not sorting_result.has_extension("unit_locations"): - sorting_result.compute("unit_locations") + if not sorting_analyzer.has_extension("unit_locations"): + sorting_analyzer.compute("unit_locations") output_folder = Path(output_folder).absolute() if output_folder.is_dir(): @@ -101,28 +101,28 @@ def export_report( # unit list units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" - units["max_on_channel_id"] = pd.Series(get_template_extremum_channel(sorting_result, peak_sign="neg", outputs="id")) - units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_result, peak_sign="neg")) + units["max_on_channel_id"] = pd.Series(get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="id")) + units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign="neg")) units.to_csv(output_folder / "unit list.csv", sep="\t") unit_colors = sw.get_unit_colors(sorting) # global figures fig = plt.figure(figsize=(20, 10)) - w = sw.plot_unit_locations(sorting_result, figure=fig, unit_colors=unit_colors) + w = sw.plot_unit_locations(sorting_analyzer, figure=fig, unit_colors=unit_colors) fig.savefig(output_folder / f"unit_localization.{format}") if not show_figures: plt.close(fig) fig, ax = plt.subplots(figsize=(20, 10)) - sw.plot_unit_depths(sorting_result, ax=ax, unit_colors=unit_colors) + sw.plot_unit_depths(sorting_analyzer, ax=ax, unit_colors=unit_colors) fig.savefig(output_folder / f"unit_depths.{format}") if not show_figures: plt.close(fig) if spike_amplitudes and len(unit_ids) < 100: fig = plt.figure(figsize=(20, 10)) - sw.plot_all_amplitudes_distributions(sorting_result, figure=fig, unit_colors=unit_colors) + sw.plot_all_amplitudes_distributions(sorting_analyzer, figure=fig, unit_colors=unit_colors) fig.savefig(output_folder / f"amplitudes_distribution.{format}") if not show_figures: plt.close(fig) @@ -139,7 +139,7 @@ def export_report( constrained_layout=False, figsize=(15, 7), ) - sw.plot_unit_summary(sorting_result, unit_id, figure=fig) + sw.plot_unit_summary(sorting_analyzer, unit_id, figure=fig) fig.suptitle(f"unit {unit_id}") fig.savefig(units_folder / f"{unit_id}.{format}") if not show_figures: diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index 981fc1c465..2b5a813591 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -3,7 +3,7 @@ import pytest from pathlib import Path -from spikeinterface.core import generate_ground_truth_recording, start_sorting_result, compute_sparsity +from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer, compute_sparsity if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "exporters" @@ -11,7 +11,7 @@ cache_folder = Path("cache_folder") / "exporters" -def make_sorting_result(sparse=True, with_group=False): +def make_sorting_analyzer(sparse=True, with_group=False): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=28000.0, @@ -33,43 +33,43 @@ def make_sorting_result(sparse=True, with_group=False): recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1]) sorting.set_property("group", [0, 0, 1, 1]) - sorting_result_unused = start_sorting_result( + sorting_analyzer_unused = create_sorting_analyzer( sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=None ) - sparsity_group = compute_sparsity(sorting_result_unused, method="by_property", by_property="group") + sparsity_group = compute_sparsity(sorting_analyzer_unused, method="by_property", by_property="group") - sorting_result = start_sorting_result( + sorting_analyzer = create_sorting_analyzer( sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=sparsity_group ) else: - sorting_result = start_sorting_result(sorting=sorting, recording=recording, format="memory", sparse=sparse) + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) - sorting_result.select_random_spikes() - sorting_result.compute("waveforms") - sorting_result.compute("templates") - sorting_result.compute("noise_levels") - sorting_result.compute("principal_components") - sorting_result.compute("template_similarity") - sorting_result.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("waveforms") + sorting_analyzer.compute("templates") + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("principal_components") + sorting_analyzer.compute("template_similarity") + sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) - return sorting_result + return sorting_analyzer @pytest.fixture(scope="session") -def sorting_result_dense_for_export(): - return make_sorting_result(sparse=False) +def sorting_analyzer_dense_for_export(): + return make_sorting_analyzer(sparse=False) @pytest.fixture(scope="session") -def sorting_result_with_group_for_export(): - return make_sorting_result(sparse=False, with_group=True) +def sorting_analyzer_with_group_for_export(): + return make_sorting_analyzer(sparse=False, with_group=True) @pytest.fixture(scope="session") -def sorting_result_sparse_for_export(): - return make_sorting_result(sparse=True) +def sorting_analyzer_sparse_for_export(): + return make_sorting_analyzer(sparse=True) if __name__ == "__main__": - sorting_result = make_sorting_result(sparse=False) - print(sorting_result) + sorting_analyzer = make_sorting_analyzer(sparse=False) + print(sorting_analyzer) diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 3394a9a6d1..18ba15b975 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -12,23 +12,23 @@ from spikeinterface.exporters.tests.common import ( cache_folder, - make_sorting_result, - sorting_result_sparse_for_export, - sorting_result_with_group_for_export, - sorting_result_dense_for_export, + make_sorting_analyzer, + sorting_analyzer_sparse_for_export, + sorting_analyzer_with_group_for_export, + sorting_analyzer_dense_for_export, ) -def test_export_to_phy_dense(sorting_result_dense_for_export): +def test_export_to_phy_dense(sorting_analyzer_dense_for_export): output_folder1 = cache_folder / "phy_output_dense" for f in (output_folder1,): if f.is_dir(): shutil.rmtree(f) - sorting_result = sorting_result_dense_for_export + sorting_analyzer = sorting_analyzer_dense_for_export export_to_phy( - sorting_result, + sorting_analyzer, output_folder1, compute_pc_features=True, compute_amplitudes=True, @@ -38,17 +38,17 @@ def test_export_to_phy_dense(sorting_result_dense_for_export): ) -def test_export_to_phy_sparse(sorting_result_sparse_for_export): +def test_export_to_phy_sparse(sorting_analyzer_sparse_for_export): output_folder1 = cache_folder / "phy_output_1" output_folder2 = cache_folder / "phy_output_2" for f in (output_folder1, output_folder2): if f.is_dir(): shutil.rmtree(f) - sorting_result = sorting_result_sparse_for_export + sorting_analyzer = sorting_analyzer_sparse_for_export export_to_phy( - sorting_result, + sorting_analyzer, output_folder1, compute_pc_features=True, compute_amplitudes=True, @@ -59,7 +59,7 @@ def test_export_to_phy_sparse(sorting_result_sparse_for_export): # Test for previous crash when copy_binary=False. export_to_phy( - sorting_result, + sorting_analyzer, output_folder2, compute_pc_features=False, compute_amplitudes=False, @@ -70,18 +70,18 @@ def test_export_to_phy_sparse(sorting_result_sparse_for_export): ) -def test_export_to_phy_by_property(sorting_result_with_group_for_export): +def test_export_to_phy_by_property(sorting_analyzer_with_group_for_export): output_folder = cache_folder / "phy_output_property" for f in (output_folder,): if f.is_dir(): shutil.rmtree(f) - sorting_result = sorting_result_with_group_for_export - print(sorting_result.sparsity) + sorting_analyzer = sorting_analyzer_with_group_for_export + print(sorting_analyzer.sparsity) export_to_phy( - sorting_result, + sorting_analyzer, output_folder, compute_pc_features=True, compute_amplitudes=True, @@ -91,14 +91,14 @@ def test_export_to_phy_by_property(sorting_result_with_group_for_export): ) template_inds = np.load(output_folder / "template_ind.npy") - assert template_inds.shape == (sorting_result.unit_ids.size, 4) + assert template_inds.shape == (sorting_analyzer.unit_ids.size, 4) if __name__ == "__main__": - sorting_result_sparse = make_sorting_result(sparse=True) - sorting_result_group = make_sorting_result(sparse=False, with_group=True) - sorting_result_dense = make_sorting_result(sparse=False) + sorting_analyzer_sparse = make_sorting_analyzer(sparse=True) + sorting_analyzer_group = make_sorting_analyzer(sparse=False, with_group=True) + sorting_analyzer_dense = make_sorting_analyzer(sparse=False) - test_export_to_phy_dense(sorting_result_dense) - test_export_to_phy_sparse(sorting_result_sparse) - test_export_to_phy_by_property(sorting_result_group) + test_export_to_phy_dense(sorting_analyzer_dense) + test_export_to_phy_sparse(sorting_analyzer_sparse) + test_export_to_phy_by_property(sorting_analyzer_group) diff --git a/src/spikeinterface/exporters/tests/test_report.py b/src/spikeinterface/exporters/tests/test_report.py index 8ed6feb1df..c89a0d70c6 100644 --- a/src/spikeinterface/exporters/tests/test_report.py +++ b/src/spikeinterface/exporters/tests/test_report.py @@ -5,20 +5,20 @@ from spikeinterface.exporters import export_report -from spikeinterface.exporters.tests.common import cache_folder, make_sorting_result, sorting_result_sparse_for_export +from spikeinterface.exporters.tests.common import cache_folder, make_sorting_analyzer, sorting_analyzer_sparse_for_export -def test_export_report(sorting_result_sparse_for_export): +def test_export_report(sorting_analyzer_sparse_for_export): report_folder = cache_folder / "report" if report_folder.exists(): shutil.rmtree(report_folder) - sorting_result = sorting_result_sparse_for_export + sorting_analyzer = sorting_analyzer_sparse_for_export job_kwargs = dict(n_jobs=1, chunk_size=30000, progress_bar=True) - export_report(sorting_result, report_folder, force_computation=True, **job_kwargs) + export_report(sorting_analyzer, report_folder, force_computation=True, **job_kwargs) if __name__ == "__main__": - sorting_result = make_sorting_result(sparse=True) - test_export_report(sorting_result) + sorting_analyzer = make_sorting_analyzer(sparse=True) + test_export_report(sorting_analyzer) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 4d66c2769a..30f74e584b 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -13,7 +13,7 @@ BinaryRecordingExtractor, BinaryFolderRecording, ChannelSparsity, - SortingResult, + SortingAnalyzer, ) from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs from spikeinterface.postprocessing import ( @@ -24,7 +24,7 @@ def export_to_phy( - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, output_folder: str | Path, compute_pc_features: bool = True, compute_amplitudes: bool = True, @@ -43,8 +43,8 @@ def export_to_phy( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object output_folder: str | Path The output folder where the phy template-gui files are saved compute_pc_features: bool, default: True @@ -60,7 +60,7 @@ def export_to_phy( peak_sign: "neg" | "pos" | "both", default: "neg" Used by compute_spike_amplitudes template_mode: str, default: "median" - Parameter "mode" to be given to SortingResult.get_template() + Parameter "mode" to be given to SortingAnalyzer.get_template() dtype: dtype or None, default: None Dtype to save binary data verbose: bool, default: True @@ -73,34 +73,34 @@ def export_to_phy( """ import pandas as pd - assert isinstance(sorting_result, SortingResult), "sorting_result must be a SortingResult object" - sorting = sorting_result.sorting + assert isinstance(sorting_analyzer, SortingAnalyzer), "sorting_analyzer must be a SortingAnalyzer object" + sorting = sorting_analyzer.sorting assert ( - sorting_result.get_num_segments() == 1 - ), f"Export to phy only works with one segment, your extractor has {sorting_result.get_num_segments()} segments" - num_chans = sorting_result.get_num_channels() - fs = sorting_result.sampling_frequency + sorting_analyzer.get_num_segments() == 1 + ), f"Export to phy only works with one segment, your extractor has {sorting_analyzer.get_num_segments()} segments" + num_chans = sorting_analyzer.get_num_channels() + fs = sorting_analyzer.sampling_frequency job_kwargs = fix_job_kwargs(job_kwargs) # check sparsity - if (num_chans > 64) and (sparsity is None and not sorting_result.is_sparse()): + if (num_chans > 64) and (sparsity is None and not sorting_analyzer.is_sparse()): warnings.warn( "Exporting to Phy with many channels and without sparsity might result in a heavy and less " - "informative visualization. You can use use a sparse SortingResult or you can use the 'sparsity' " + "informative visualization. You can use use a sparse SortingAnalyzer or you can use the 'sparsity' " "argument to enforce sparsity (see compute_sparsity())" ) save_sparse = True - if sorting_result.is_sparse(): - used_sparsity = sorting_result.sparsity + if sorting_analyzer.is_sparse(): + used_sparsity = sorting_analyzer.sparsity if sparsity is not None: - warnings.warn("If the sorting_result is sparse the 'sparsity' argument is ignored") + warnings.warn("If the sorting_analyzer is sparse the 'sparsity' argument is ignored") elif sparsity is not None: used_sparsity = sparsity else: - used_sparsity = ChannelSparsity.create_dense(sorting_result) + used_sparsity = ChannelSparsity.create_dense(sorting_analyzer) save_sparse = False # convenient sparsity dict for the 3 cases to retrieve channl_inds sparse_dict = used_sparsity.unit_id_to_channel_indices @@ -130,19 +130,19 @@ def export_to_phy( # save dat file if dtype is None: - dtype = sorting_result.get_dtype() + dtype = sorting_analyzer.get_dtype() - if sorting_result.has_recording(): + if sorting_analyzer.has_recording(): if copy_binary: rec_path = output_folder / "recording.dat" - write_binary_recording(sorting_result.recording, file_paths=rec_path, dtype=dtype, **job_kwargs) - elif isinstance(sorting_result.recording, BinaryRecordingExtractor): - if isinstance(sorting_result.recording, BinaryFolderRecording): - bin_kwargs = sorting_result.recording._bin_kwargs + write_binary_recording(sorting_analyzer.recording, file_paths=rec_path, dtype=dtype, **job_kwargs) + elif isinstance(sorting_analyzer.recording, BinaryRecordingExtractor): + if isinstance(sorting_analyzer.recording, BinaryFolderRecording): + bin_kwargs = sorting_analyzer.recording._bin_kwargs else: - bin_kwargs = sorting_result.recording._kwargs + bin_kwargs = sorting_analyzer.recording._kwargs rec_path = bin_kwargs["file_paths"][0] - dtype = sorting_result.recording.get_dtype() + dtype = sorting_analyzer.recording.get_dtype() else: rec_path = "None" else: # don't save recording.dat @@ -167,7 +167,7 @@ def export_to_phy( f.write(f"dtype = '{dtype_str}'\n") f.write(f"offset = 0\n") f.write(f"sample_rate = {fs}\n") - f.write(f"hp_filtered = {sorting_result.recording.is_filtered()}") + f.write(f"hp_filtered = {sorting_analyzer.recording.is_filtered()}") # export spike_times/spike_templates/spike_clusters # here spike_labels is a remapping to unit_index @@ -180,8 +180,8 @@ def export_to_phy( # export templates/templates_ind/similar_templates # shape (num_units, num_samples, max_num_channels) - templates_ext = sorting_result.get_extension("templates") - templates_ext is not None, "export_to_phy need SortingResult with extension 'templates'" + templates_ext = sorting_analyzer.get_extension("templates") + templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) num_samples = dense_templates.shape[1] @@ -194,9 +194,9 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if not sorting_result.has_extension("template_similarity"): - sorting_result.compute("template_similarity") - template_similarity = sorting_result.get_extension("template_similarity").get_data() + if not sorting_analyzer.has_extension("template_similarity"): + sorting_analyzer.compute("template_similarity") + template_similarity = sorting_analyzer.get_extension("template_similarity").get_data() np.save(str(output_folder / "templates.npy"), templates) if save_sparse: @@ -204,9 +204,9 @@ def export_to_phy( np.save(str(output_folder / "similar_templates.npy"), template_similarity) channel_maps = np.arange(num_chans, dtype="int32") - channel_map_si = sorting_result.channel_ids - channel_positions = sorting_result.get_channel_locations().astype("float32") - channel_groups = sorting_result.get_recording_property("group") + channel_map_si = sorting_analyzer.channel_ids + channel_positions = sorting_analyzer.get_channel_locations().astype("float32") + channel_groups = sorting_analyzer.get_recording_property("group") if channel_groups is None: channel_groups = np.zeros(num_chans, dtype="int32") np.save(str(output_folder / "channel_map.npy"), channel_maps) @@ -215,17 +215,17 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if not sorting_result.has_extension("spike_amplitudes"): - sorting_result.compute("spike_amplitudes", **job_kwargs) - amplitudes = sorting_result.get_extension("spike_amplitudes").get_data() + if not sorting_analyzer.has_extension("spike_amplitudes"): + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() amplitudes = amplitudes[:, np.newaxis] np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if not sorting_result.has_extension("principal_components"): - sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + if not sorting_analyzer.has_extension("principal_components"): + sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - pca_extension = sorting_result.get_extension("principal_components") + pca_extension = sorting_analyzer.get_extension("principal_components") pca_extension.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs) @@ -250,8 +250,8 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if sorting_result.has_extension("quality_metrics"): - qm_data = sorting_result.get_extension("quality_metrics").get_data() + if sorting_analyzer.has_extension("quality_metrics"): + qm_data = sorting_analyzer.get_extension("quality_metrics").get_data() for column_name in qm_data.columns: # already computed by phy if column_name not in ["num_spikes", "firing_rate"]: diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 18526613de..a02c437483 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -7,7 +7,7 @@ from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type @@ -21,12 +21,12 @@ class ComputeAmplitudeScalings(ResultExtension): """ - Computes the amplitude scalings from a SortingResult. + Computes the amplitude scalings from a SortingAnalyzer. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object sparsity: ChannelSparsity or None, default: None If waveforms are not sparse, sparsity is required if the number of channels is greater than `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. @@ -35,10 +35,10 @@ class ComputeAmplitudeScalings(ResultExtension): dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. ms_before : float or None, default: None The cut out to apply before the spike peak to extract local waveforms. - If None, the SortingResult ms_before is used. + If None, the SortingAnalyzer ms_before is used. ms_after : float or None, default: None The cut out to apply after the spike peak to extract local waveforms. - If None, the SortingResult ms_after is used. + If None, the SortingAnalyzer ms_after is used. handle_collisions: bool, default: True Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a @@ -68,8 +68,8 @@ class ComputeAmplitudeScalings(ResultExtension): nodepipeline_variables = ["amplitude_scalings", "collision_mask"] need_job_kwargs = True - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) self.collisions = None @@ -93,9 +93,9 @@ def _set_params( return params def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - spikes = self.sorting_result.sorting.to_spike_vector() + spikes = self.sorting_analyzer.sorting.to_spike_vector() keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) new_data = dict() @@ -106,19 +106,19 @@ def _select_extension_data(self, unit_ids): def _get_pipeline_nodes(self): - recording = self.sorting_result.recording - sorting = self.sorting_result.sorting + recording = self.sorting_analyzer.recording + sorting = self.sorting_analyzer.sorting - # TODO return_scaled is not any more a property of SortingResult this is hard coded for now + # TODO return_scaled is not any more a property of SortingAnalyzer this is hard coded for now return_scaled = True - all_templates = _get_dense_templates_array(self.sorting_result, return_scaled=return_scaled) - nbefore = _get_nbefore(self.sorting_result) + all_templates = _get_dense_templates_array(self.sorting_analyzer, return_scaled=return_scaled) + nbefore = _get_nbefore(self.sorting_analyzer) nafter = all_templates.shape[1] - nbefore # if ms_before / ms_after are set in params then the original templates are shorten if self.params["ms_before"] is not None: - cut_out_before = int(self.params["ms_before"] * self.sorting_result.sampling_frequency / 1000.0) + cut_out_before = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) assert ( cut_out_before <= nbefore ), f"`ms_before` must be smaller than `ms_before` used in ComputeTemplates: {nbefore}" @@ -126,7 +126,7 @@ def _get_pipeline_nodes(self): cut_out_before = nbefore if self.params["ms_after"] is not None: - cut_out_after = int(self.params["ms_after"] * self.sorting_result.sampling_frequency / 1000.0) + cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) assert ( cut_out_after <= nafter ), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}" @@ -135,29 +135,29 @@ def _get_pipeline_nodes(self): peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" extremum_channels_indices = get_template_extremum_channel( - self.sorting_result, peak_sign=peak_sign, outputs="index" + self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) # collisions handle_collisions = self.params["handle_collisions"] delta_collision_ms = self.params["delta_collision_ms"] - delta_collision_samples = int(delta_collision_ms / 1000 * self.sorting_result.sampling_frequency) + delta_collision_samples = int(delta_collision_ms / 1000 * self.sorting_analyzer.sampling_frequency) - if self.sorting_result.is_sparse() and self.params["sparsity"] is None: - sparsity = self.sorting_result.sparsity - elif self.sorting_result.is_sparse() and self.params["sparsity"] is not None: + if self.sorting_analyzer.is_sparse() and self.params["sparsity"] is None: + sparsity = self.sorting_analyzer.sparsity + elif self.sorting_analyzer.is_sparse() and self.params["sparsity"] is not None: sparsity = self.params["sparsity"] # assert provided sparsity is sparser than the one in the waveform extractor - waveform_sparsity = self.sorting_result.sparsity + waveform_sparsity = self.sorting_analyzer.sparsity assert np.all( np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0 ), "The provided sparsity needs to be sparser than the one in the waveform extractor!" - elif not self.sorting_result.is_sparse() and self.params["sparsity"] is not None: + elif not self.sorting_analyzer.is_sparse() and self.params["sparsity"] is not None: sparsity = self.params["sparsity"] else: if self.params["max_dense_channels"] is not None: assert recording.get_num_channels() <= self.params["max_dense_channels"], "" - sparsity = ChannelSparsity.create_dense(self.sorting_result) + sparsity = ChannelSparsity.create_dense(self.sorting_analyzer) sparsity_mask = sparsity.mask spike_retriever_node = SpikeRetriever( @@ -188,7 +188,7 @@ def _run(self, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() amp_scalings, collision_mask = run_node_pipeline( - self.sorting_result.recording, + self.sorting_analyzer.recording, nodes, job_kwargs=job_kwargs, job_name="amplitude_scalings", @@ -542,8 +542,8 @@ def fit_collision( # Parameters # ---------- -# we : SortingResult -# The SortingResult object. +# we : SortingAnalyzer +# The SortingAnalyzer object. # sparsity : ChannelSparsity, default=None # The ChannelSparsity. If None, only main channels are plotted. # num_collisions : int, default=None diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 5be43ac05a..3becdd1b16 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension, SortingResult +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension, SortingAnalyzer try: import numba @@ -18,8 +18,8 @@ class ComputeCorrelograms(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object window_ms : float, default: 100.0 The window in ms bin_ms : float, default: 5 @@ -52,8 +52,8 @@ class ComputeCorrelograms(ResultExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) @@ -62,14 +62,14 @@ def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str def _select_extension_data(self, unit_ids): # filter metrics dataframe - unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) new_ccgs = self.data["ccgs"][unit_indices][:, unit_indices] new_bins = self.data["bins"] new_data = dict(ccgs=new_ccgs, bins=new_bins) return new_data def _run(self): - ccgs, bins = compute_correlograms_on_sorting(self.sorting_result.sorting, **self.params) + ccgs, bins = compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) self.data["ccgs"] = ccgs self.data["bins"] = bins @@ -78,26 +78,26 @@ def _get_data(self): register_result_extension(ComputeCorrelograms) -compute_correlograms_sorting_result = ComputeCorrelograms.function_factory() +compute_correlograms_sorting_analyzer = ComputeCorrelograms.function_factory() def compute_correlograms( - sorting_result_or_sorting, + sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", ): - if isinstance(sorting_result_or_sorting, SortingResult): - return compute_correlograms_sorting_result( - sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method + if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): + return compute_correlograms_sorting_analyzer( + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method ) else: return compute_correlograms_on_sorting( - sorting_result_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method ) -compute_correlograms.__doc__ = compute_correlograms_sorting_result.__doc__ +compute_correlograms.__doc__ = compute_correlograms_sorting_analyzer.__doc__ def _make_bins(sorting, window_ms, bin_ms): diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index fffe793655..367db16533 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension try: import numba @@ -17,8 +17,8 @@ class ComputeISIHistograms(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object window_ms : float, default: 50 The window in ms bin_ms : float, default: 1 @@ -40,8 +40,8 @@ class ComputeISIHistograms(ResultExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) @@ -50,14 +50,14 @@ def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str def _select_extension_data(self, unit_ids): # filter metrics dataframe - unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) new_isi_hists = self.data["isi_histograms"][unit_indices, :] new_bins = self.data["bins"] new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) return new_extension_data def _run(self): - isi_histograms, bins = _compute_isi_histograms(self.sorting_result.sorting, **self.params) + isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params) self.data["isi_histograms"] = isi_histograms self.data["bins"] = bins diff --git a/src/spikeinterface/postprocessing/noise_level.py b/src/spikeinterface/postprocessing/noise_level.py index abd47f574f..a168f34c7b 100644 --- a/src/spikeinterface/postprocessing/noise_level.py +++ b/src/spikeinterface/postprocessing/noise_level.py @@ -1,3 +1,3 @@ # "noise_levels" extensions is now in core # this is kept name space compatibility but should be removed soon -from ..core.result_core import ComputeNoiseLevels, compute_noise_levels +from ..core.analyzer_extension_core import ComputeNoiseLevels, compute_noise_levels diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 025afec105..155d072b26 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -9,7 +9,7 @@ import numpy as np -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs @@ -23,8 +23,8 @@ class ComputePrincipalComponents(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object n_components: int, default: 5 Number of components fo PCA mode: "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" @@ -34,7 +34,7 @@ class ComputePrincipalComponents(ResultExtension): - "concatenated": channels are concatenated and a global PCA is fitted sparsity: ChannelSparsity or None, default: None The sparsity to apply to waveforms. - If sorting_result is already sparse, the default sparsity will be used + If sorting_analyzer is already sparse, the default sparsity will be used whiten: bool, default: True If True, waveforms are pre-whitened dtype: dtype, default: "float32" @@ -42,9 +42,9 @@ class ComputePrincipalComponents(ResultExtension): Examples -------- - >>> sorting_result = start_sorting_result(sorting, recording) - >>> sorting_result.compute("principal_components", n_components=3, mode='by_channel_local') - >>> ext_pca = sorting_result.get_extension("principal_components") + >>> sorting_analyzer = create_sorting_analyzer(sorting, recording) + >>> sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_local') + >>> ext_pca = sorting_analyzer.get_extension("principal_components") >>> # get pre-computed projections for unit_id=1 >>> unit_projections = ext_pca.get_projections_one_unit(unit_id=1, sparse=False) >>> # get pre-computed projections for some units on some channels @@ -65,8 +65,8 @@ class ComputePrincipalComponents(ResultExtension): use_nodepipeline = False need_job_kwargs = True - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) def _set_params( self, @@ -77,7 +77,7 @@ def _set_params( ): assert mode in _possible_modes, "Invalid mode!" - # the sparsity in params is ONLY the injected sparsity and not the sorting_result one + # the sparsity in params is ONLY the injected sparsity and not the sorting_analyzer one params = dict( n_components=n_components, mode=mode, @@ -88,9 +88,9 @@ def _set_params( def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) - spikes = self.sorting_result.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) + spikes = self.sorting_analyzer.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices) new_data = dict() @@ -114,7 +114,7 @@ def get_pca_model(self): mode = self.params["mode"] if mode == "by_channel_local": pca_models = [] - for chan_id in self.sorting_result.channel_ids: + for chan_id in self.sorting_analyzer.channel_ids: pca_models.append(self.data[f"pca_model_{mode}_{chan_id}"]) else: pca_models = self.data[f"pca_model_{mode}"] @@ -129,7 +129,7 @@ def get_projections_one_unit(self, unit_id, sparse=False): unit_id : int or str The unit id to return PCA projections for sparse: bool, default: False - If True, and SortingResult must be sparse then only projections on sparse channels are returned. + If True, and SortingAnalyzer must be sparse then only projections on sparse channels are returned. Channel indices are also returned. Returns @@ -140,15 +140,15 @@ def get_projections_one_unit(self, unit_id, sparse=False): channel_indices: np.array """ - sparsity = self.sorting_result.sparsity - sorting = self.sorting_result.sorting + sparsity = self.sorting_analyzer.sparsity + sorting = self.sorting_analyzer.sorting if sparse: assert self.params["mode"] != "concatenated", "mode concatenated cannot retrieve sparse projection" - assert sparsity is not None, "sparse projection need SortingResult to be sparse" + assert sparsity is not None, "sparse projection need SortingAnalyzer to be sparse" spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] unit_index = sorting.id_to_index(unit_id) spike_mask = some_spikes["unit_index"] == unit_index @@ -162,7 +162,7 @@ def get_projections_one_unit(self, unit_id, sparse=False): if sparse: return projections, channel_indices else: - num_chans = self.sorting_result.get_num_channels() + num_chans = self.sorting_analyzer.get_num_channels() projections_ = np.zeros( (projections.shape[0], projections.shape[1], num_chans), dtype=projections.dtype ) @@ -189,24 +189,24 @@ def get_some_projections(self, channel_ids=None, unit_ids=None): spike_unit_indices: np.array Array a copy of with some_spikes["unit_index"] of returned PCA projections of shape (num_spikes, ) """ - sorting = self.sorting_result.sorting + sorting = self.sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids if channel_ids is None: - channel_ids = self.sorting_result.channel_ids + channel_ids = self.sorting_analyzer.channel_ids - channel_indices = self.sorting_result.channel_ids_to_indices(channel_ids) + channel_indices = self.sorting_analyzer.channel_ids_to_indices(channel_ids) # note : internally when sparse PCA are not aligned!! Exactly like waveforms. all_projections = self.data["pca_projection"] num_components = all_projections.shape[1] dtype = all_projections.dtype - sparsity = self.sorting_result.sparsity + sparsity = self.sorting_analyzer.sparsity spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] unit_indices = sorting.ids_to_indices(unit_ids) selected_inds = np.flatnonzero(np.isin(some_spikes["unit_index"], unit_indices)) @@ -263,7 +263,7 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True): def _run(self, **job_kwargs): """ Compute the PCs on waveforms extacted within the by ComputeWaveforms. - Projections are computed only on the waveforms sampled by the SortingResult. + Projections are computed only on the waveforms sampled by the SortingAnalyzer. """ p = self.params mode = p["mode"] @@ -277,7 +277,7 @@ def _run(self, **job_kwargs): # TODO : make parralel for by_channel_global and concatenated if mode == "by_channel_local": pca_models = self._fit_by_channel_local(n_jobs, progress_bar) - for chan_ind, chan_id in enumerate(self.sorting_result.channel_ids): + for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids): self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] pca_model = pca_models elif mode == "by_channel_global": @@ -288,10 +288,10 @@ def _run(self, **job_kwargs): self.data[f"pca_model_{mode}"] = pca_model # transform - waveforms_ext = self.sorting_result.get_extension("waveforms") + waveforms_ext = self.sorting_analyzer.get_extension("waveforms") some_waveforms = waveforms_ext.data["waveforms"] - spikes = self.sorting_result.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + spikes = self.sorting_analyzer.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] pca_projection = self._transform_waveforms(some_spikes, some_waveforms, pca_model, progress_bar) @@ -318,7 +318,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) p = self.params - we = self.sorting_result + we = self.sorting_analyzer sorting = we.sorting assert ( we.has_recording() @@ -331,7 +331,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): assert file_path is not None file_path = Path(file_path) - sparsity = self.sorting_result.sparsity + sparsity = self.sorting_analyzer.sparsity if sparsity is None: sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} max_channels_per_template = we.get_num_channels() @@ -350,7 +350,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) - waveforms_ext = self.sorting_result.get_extension("waveforms") + waveforms_ext = self.sorting_analyzer.get_extension("waveforms") # and run func = _all_pc_extractor_chunk @@ -373,8 +373,8 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): p = self.params - unit_ids = self.sorting_result.unit_ids - channel_ids = self.sorting_result.channel_ids + unit_ids = self.sorting_analyzer.unit_ids + channel_ids = self.sorting_analyzer.channel_ids # there is one PCA per channel for independent fit per channel pca_models = [IncrementalPCA(n_components=p["n_components"], whiten=p["whiten"]) for _ in channel_ids] @@ -406,10 +406,10 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): return pca_models def _fit_by_channel_global(self, progress_bar): - # we = self.sorting_result + # we = self.sorting_analyzer p = self.params # unit_ids = we.unit_ids - unit_ids = self.sorting_result.unit_ids + unit_ids = self.sorting_analyzer.unit_ids # there is one unique PCA accross channels from sklearn.decomposition import IncrementalPCA @@ -436,9 +436,9 @@ def _fit_by_channel_global(self, progress_bar): def _fit_concatenated(self, progress_bar): p = self.params - unit_ids = self.sorting_result.unit_ids + unit_ids = self.sorting_analyzer.unit_ids - assert self.sorting_result.sparsity is None, "For mode 'concatenated' waveforms need to be dense" + assert self.sorting_analyzer.sparsity is None, "For mode 'concatenated' waveforms need to be dense" # there is one unique PCA accross channels from sklearn.decomposition import IncrementalPCA @@ -475,7 +475,7 @@ def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar): shape = (waveforms.shape[0], n_components) pca_projection = np.zeros(shape, dtype="float32") - unit_ids = self.sorting_result.unit_ids + unit_ids = self.sorting_analyzer.unit_ids # transform units_loop = enumerate(unit_ids) @@ -525,26 +525,26 @@ def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar): def _get_slice_waveforms(self, unit_id, spikes, waveforms): # slice by mask waveforms from one unit - unit_index = self.sorting_result.sorting.id_to_index(unit_id) + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) spike_mask = spikes["unit_index"] == unit_index wfs = waveforms[spike_mask, :, :] - sparsity = self.sorting_result.sparsity + sparsity = self.sorting_analyzer.sparsity if sparsity is not None: channel_inds = sparsity.unit_id_to_channel_indices[unit_id] wfs = wfs[:, :, : channel_inds.size] else: - channel_inds = np.arange(self.sorting_result.channel_ids.size, dtype=int) + channel_inds = np.arange(self.sorting_analyzer.channel_ids.size, dtype=int) return wfs, channel_inds, spike_mask def _get_sparse_waveforms(self, unit_id): # get waveforms + channel_inds: dense or sparse - waveforms_ext = self.sorting_result.get_extension("waveforms") + waveforms_ext = self.sorting_analyzer.get_extension("waveforms") some_waveforms = waveforms_ext.data["waveforms"] - spikes = self.sorting_result.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_result.random_spikes_indices] + spikes = self.sorting_analyzer.sorting.to_spike_vector() + some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] return self._get_slice_waveforms(unit_id, some_spikes, some_waveforms) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index b9477f0a22..1899951dc5 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -7,7 +7,7 @@ from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type from spikeinterface.core.sorting_tools import spike_vector_to_indices @@ -22,8 +22,8 @@ class ComputeSpikeAmplitudes(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 @@ -63,8 +63,8 @@ class ComputeSpikeAmplitudes(ResultExtension): nodepipeline_variables = ["amplitudes"] need_job_kwargs = True - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) self._all_spikes = None @@ -73,9 +73,9 @@ def _set_params(self, peak_sign="neg", return_scaled=True): return params def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_result.unit_ids, unit_ids)) + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - spikes = self.sorting_result.sorting.to_spike_vector() + spikes = self.sorting_analyzer.sorting.to_spike_vector() keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) new_data = dict() @@ -85,16 +85,16 @@ def _select_extension_data(self, unit_ids): def _get_pipeline_nodes(self): - recording = self.sorting_result.recording - sorting = self.sorting_result.sorting + recording = self.sorting_analyzer.recording + sorting = self.sorting_analyzer.sorting peak_sign = self.params["peak_sign"] return_scaled = self.params["return_scaled"] extremum_channels_indices = get_template_extremum_channel( - self.sorting_result, peak_sign=peak_sign, outputs="index" + self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) - peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_result, peak_sign=peak_sign) + peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) if return_scaled: # check if has scaled values: @@ -119,7 +119,7 @@ def _run(self, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() amps = run_node_pipeline( - self.sorting_result.recording, + self.sorting_analyzer.recording, nodes, job_kwargs=job_kwargs, job_name="spike_amplitudes", @@ -132,11 +132,11 @@ def _get_data(self, outputs="numpy"): if outputs == "numpy": return all_amplitudes elif outputs == "by_unit": - unit_ids = self.sorting_result.unit_ids - spike_vector = self.sorting_result.sorting.to_spike_vector(concatenated=False) + unit_ids = self.sorting_analyzer.unit_ids + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) spike_indices = spike_vector_to_indices(spike_vector, unit_ids) amplitudes_by_units = {} - for segment_index in range(self.sorting_result.sorting.get_num_segments()): + for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): amplitudes_by_units[segment_index] = {} for unit_id in unit_ids: inds = spike_indices[segment_index][unit_id] diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index bc646676b8..76602e2763 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sorting_tools import spike_vector_to_indices @@ -17,8 +17,8 @@ class ComputeSpikeLocations(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 @@ -58,11 +58,11 @@ class ComputeSpikeLocations(ResultExtension): nodepipeline_variables = ["spike_locations"] need_job_kwargs = True - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) - extremum_channel_inds = get_template_extremum_channel(self.sorting_result, outputs="index") - self.spikes = self.sorting_result.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") + self.spikes = self.sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) def _set_params( self, @@ -89,7 +89,7 @@ def _set_params( return params def _select_extension_data(self, unit_ids): - old_unit_ids = self.sorting_result.unit_ids + old_unit_ids = self.sorting_analyzer.unit_ids unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) spike_mask = np.isin(self.spikes["unit_index"], unit_inds) @@ -99,11 +99,11 @@ def _select_extension_data(self, unit_ids): def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes - recording = self.sorting_result.recording - sorting = self.sorting_result.sorting + recording = self.sorting_analyzer.recording + sorting = self.sorting_analyzer.sorting peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] extremum_channels_indices = get_template_extremum_channel( - self.sorting_result, peak_sign=peak_sign, outputs="index" + self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) retriever = SpikeRetriever( @@ -126,7 +126,7 @@ def _run(self, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() spike_locations = run_node_pipeline( - self.sorting_result.recording, + self.sorting_analyzer.recording, nodes, job_kwargs=job_kwargs, job_name="spike_locations", @@ -139,11 +139,11 @@ def _get_data(self, outputs="numpy"): if outputs == "numpy": return all_spike_locations elif outputs == "by_unit": - unit_ids = self.sorting_result.unit_ids - spike_vector = self.sorting_result.sorting.to_spike_vector(concatenated=False) + unit_ids = self.sorting_analyzer.unit_ids + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) spike_indices = spike_vector_to_indices(spike_vector, unit_ids) spike_locations_by_units = {} - for segment_index in range(self.sorting_result.sorting.get_num_segments()): + for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): spike_locations_by_units[segment_index] = {} for unit_id in unit_ids: inds = spike_indices[segment_index][unit_id] diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 3f9817934e..e4fd456107 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -11,7 +11,7 @@ from typing import Optional from copy import deepcopy -from ..core.sortingresult import register_result_extension, ResultExtension +from ..core.sortinganalyzer import register_result_extension, ResultExtension from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.template_tools import _get_dense_templates_array @@ -50,8 +50,8 @@ class ComputeTemplateMetrics(ResultExtension): Parameters ---------- - sorting_result: SortingResult - The SortingResult object + sorting_analyzer: SortingAnalyzer + The SortingAnalyzer object metric_names : list or None, default: None List of metrics to compute (see si.postprocessing.get_template_metric_names()) peak_sign : {"neg", "pos"}, default: "neg" @@ -124,7 +124,7 @@ def _set_params( "so that each unit will correspond to 1 row of the output dataframe." ) assert ( - self.sorting_result.get_channel_locations().shape[1] == 2 + self.sorting_analyzer.get_channel_locations().shape[1] == 2 ), "If multi-channel metrics are computed, channel locations must be 2D." if metric_names is None: @@ -160,15 +160,15 @@ def _run(self): sparsity = self.params["sparsity"] peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] - unit_ids = self.sorting_result.unit_ids - sampling_frequency = self.sorting_result.sampling_frequency + unit_ids = self.sorting_analyzer.unit_ids + sampling_frequency = self.sorting_analyzer.sampling_frequency metrics_single_channel = [m for m in metric_names if m in get_single_channel_template_metric_names()] metrics_multi_channel = [m for m in metric_names if m in get_multi_channel_template_metric_names()] if sparsity is None: extremum_channels_ids = get_template_extremum_channel( - self.sorting_result, peak_sign=peak_sign, outputs="id" + self.sorting_analyzer, peak_sign=peak_sign, outputs="id" ) template_metrics = pd.DataFrame(index=unit_ids, columns=metric_names) @@ -184,16 +184,16 @@ def _run(self): ) template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) - all_templates = _get_dense_templates_array(self.sorting_result, return_scaled=True) + all_templates = _get_dense_templates_array(self.sorting_analyzer, return_scaled=True) - channel_locations = self.sorting_result.get_channel_locations() + channel_locations = self.sorting_analyzer.get_channel_locations() for unit_index, unit_id in enumerate(unit_ids): template_all_chans = all_templates[unit_index] chan_ids = np.array(extremum_channels_ids[unit_id]) if chan_ids.ndim == 0: chan_ids = [chan_ids] - chan_ind = self.sorting_result.channel_ids_to_indices(chan_ids) + chan_ind = self.sorting_analyzer.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] # compute single_channel metrics @@ -227,8 +227,8 @@ def _run(self): for metric_name in metrics_multi_channel: # retrieve template (with sparsity if waveform extractor is sparse) template = all_templates[unit_index, :, :] - if self.sorting_result.is_sparse(): - mask = self.sorting_result.sparsity.mask[unit_index, :] + if self.sorting_analyzer.is_sparse(): + mask = self.sorting_analyzer.sparsity.mask[unit_index, :] template = template[:, mask] if template.shape[1] < self.min_channels_for_multi_channel_warning: @@ -236,8 +236,8 @@ def _run(self): f"With less than {self.min_channels_for_multi_channel_warning} channels, " "multi-channel metrics might not be reliable." ) - if self.sorting_result.is_sparse(): - channel_locations_sparse = channel_locations[self.sorting_result.sparsity.mask[unit_index]] + if self.sorting_analyzer.is_sparse(): + channel_locations_sparse = channel_locations[self.sorting_analyzer.sparsity.mask[unit_index]] else: channel_locations_sparse = channel_locations diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 9bd28d5080..99f804b124 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension from ..core.template_tools import _get_dense_templates_array @@ -12,8 +12,8 @@ class ComputeTemplateSimilarity(ResultExtension): Parameters ---------- - sorting_result: SortingResult - The SortingResult object + sorting_analyzer: SortingAnalyzer + The SortingAnalyzer object method: str, default: "cosine_similarity" The method to compute the similarity @@ -31,8 +31,8 @@ class ComputeTemplateSimilarity(ResultExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) def _set_params(self, method="cosine_similarity"): params = dict(method=method) @@ -40,12 +40,12 @@ def _set_params(self, method="cosine_similarity"): def _select_extension_data(self, unit_ids): # filter metrics dataframe - unit_indices = self.sorting_result.sorting.ids_to_indices(unit_ids) + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) new_similarity = self.data["similarity"][unit_indices][:, unit_indices] return dict(similarity=new_similarity) def _run(self): - templates_array = _get_dense_templates_array(self.sorting_result, return_scaled=True) + templates_array = _get_dense_templates_array(self.sorting_analyzer, return_scaled=True) similarity = compute_similarity_with_templates_array( templates_array, templates_array, method=self.params["method"] ) @@ -55,7 +55,7 @@ def _get_data(self): return self.data["similarity"] -# @alessio: compute_template_similarity() is now one inner SortingResult only +# @alessio: compute_template_similarity() is now one inner SortingAnalyzer only register_result_extension(ComputeTemplateSimilarity) compute_template_similarity = ComputeTemplateSimilarity.function_factory() @@ -75,9 +75,9 @@ def compute_similarity_with_templates_array(templates_array, other_templates_arr return similarity -def compute_template_similarity_by_pair(sorting_result_1, sorting_result_2, method="cosine_similarity"): - templates_array_1 = _get_dense_templates_array(sorting_result_1, return_scaled=True) - templates_array_2 = _get_dense_templates_array(sorting_result_2, return_scaled=True) +def compute_template_similarity_by_pair(sorting_analyzer_1, sorting_analyzer_2, method="cosine_similarity"): + templates_array_1 = _get_dense_templates_array(sorting_analyzer_1, return_scaled=True) + templates_array_2 = _get_dense_templates_array(sorting_analyzer_2, return_scaled=True) similmarity = compute_similarity_with_templates_array(templates_array_1, templates_array_2, method) return similmarity diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 3520231a9d..5516fe592a 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -8,7 +8,7 @@ from pathlib import Path from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import start_sorting_result +from spikeinterface.core import create_sorting_analyzer from spikeinterface.core import estimate_sparsity @@ -43,7 +43,7 @@ def get_dataset(): return recording, sorting -def get_sorting_result(recording, sorting, format="memory", sparsity=None, name=""): +def get_sorting_analyzer(recording, sorting, format="memory", sparsity=None, name=""): sparse = sparsity is not None if format == "memory": folder = None @@ -54,9 +54,9 @@ def get_sorting_result(recording, sorting, format="memory", sparsity=None, name= if folder and folder.exists(): shutil.rmtree(folder) - sorting_result = start_sorting_result(sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity) - return sorting_result + return sorting_analyzer class ResultExtensionCommonTestSuite: @@ -83,20 +83,20 @@ def setUpClass(cls): def extension_name(self): return self.extension_class.extension_name - def _prepare_sorting_result(self, format, sparse): - # prepare a SortingResult object with depencies already computed + def _prepare_sorting_analyzer(self, format, sparse): + # prepare a SortingAnalyzer object with depencies already computed sparsity_ = self.sparsity if sparse else None - sorting_result = get_sorting_result( + sorting_analyzer = get_sorting_analyzer( self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name ) - sorting_result.select_random_spikes(max_spikes_per_unit=50, seed=2205) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=50, seed=2205) for dependency_name in self.extension_class.depend_on: if "|" in dependency_name: dependency_name = dependency_name.split("|")[0] - sorting_result.compute(dependency_name) - return sorting_result + sorting_analyzer.compute(dependency_name) + return sorting_analyzer - def _check_one(self, sorting_result): + def _check_one(self, sorting_analyzer): if self.extension_class.need_job_kwargs: job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) else: @@ -104,16 +104,16 @@ def _check_one(self, sorting_result): for params in self.extension_function_params_list: print(" params", params) - ext = sorting_result.compute(self.extension_name, **params, **job_kwargs) + ext = sorting_analyzer.compute(self.extension_name, **params, **job_kwargs) assert len(ext.data) > 0 main_data = ext.get_data() - ext = sorting_result.get_extension(self.extension_name) + ext = sorting_analyzer.get_extension(self.extension_name) assert ext is not None - some_unit_ids = sorting_result.unit_ids[::2] - sliced = sorting_result.select_units(some_unit_ids, format="memory") - assert np.array_equal(sliced.unit_ids, sorting_result.unit_ids[::2]) + some_unit_ids = sorting_analyzer.unit_ids[::2] + sliced = sorting_analyzer.select_units(some_unit_ids, format="memory") + assert np.array_equal(sliced.unit_ids, sorting_analyzer.unit_ids[::2]) # print(sliced) def test_extension(self): @@ -121,5 +121,5 @@ def test_extension(self): for format in ("memory", "binary_folder", "zarr"): print() print("sparse", sparse, format) - sorting_result = self._prepare_sorting_result(format, sparse) - self._check_one(sorting_result) + sorting_analyzer = self._prepare_sorting_analyzer(format, sparse) + self._check_one(sorting_analyzer) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index eda88a5f1d..40034b7363 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -15,14 +15,14 @@ class AmplitudeScalingsExtensionTest(ResultExtensionCommonTestSuite, unittest.Te ] def test_scaling_values(self): - sorting_result = self._prepare_sorting_result("memory", True) - sorting_result.compute("amplitude_scalings", handle_collisions=False) + sorting_analyzer = self._prepare_sorting_analyzer("memory", True) + sorting_analyzer.compute("amplitude_scalings", handle_collisions=False) - spikes = sorting_result.sorting.to_spike_vector() + spikes = sorting_analyzer.sorting.to_spike_vector() - ext = sorting_result.get_extension("amplitude_scalings") + ext = sorting_analyzer.get_extension("amplitude_scalings") - for unit_index, unit_id in enumerate(sorting_result.unit_ids): + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): mask = spikes["unit_index"] == unit_index scalings = ext.data["amplitude_scalings"][mask] median_scaling = np.median(scalings) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index b686e078ee..c7e9942f2d 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -22,11 +22,11 @@ class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest. def test_mode_concatenated(self): # this is tested outside "extension_function_params_list" because it do not support sparsity! - sorting_result = self._prepare_sorting_result(format="memory", sparse=False) + sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=False) n_components = 3 - sorting_result.compute("principal_components", mode="concatenated", n_components=n_components) - ext = sorting_result.get_extension("principal_components") + sorting_analyzer.compute("principal_components", mode="concatenated", n_components=n_components) + ext = sorting_analyzer.get_extension("principal_components") assert ext is not None assert len(ext.data) > 0 pca = ext.data["pca_projection"] @@ -37,14 +37,14 @@ def test_get_projections(self): for sparse in (False, True): - sorting_result = self._prepare_sorting_result(format="memory", sparse=sparse) - num_chans = sorting_result.get_num_channels() + sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=sparse) + num_chans = sorting_analyzer.get_num_channels() n_components = 2 - sorting_result.compute("principal_components", mode="by_channel_global", n_components=n_components) - ext = sorting_result.get_extension("principal_components") + sorting_analyzer.compute("principal_components", mode="by_channel_global", n_components=n_components) + ext = sorting_analyzer.get_extension("principal_components") - for unit_id in sorting_result.unit_ids: + for unit_id in sorting_analyzer.unit_ids: if not sparse: one_proj = ext.get_projections_one_unit(unit_id, sparse=False) assert one_proj.shape[1] == n_components @@ -59,20 +59,20 @@ def test_get_projections(self): assert one_proj.shape[2] < num_chans assert one_proj.shape[2] == chan_inds.size - some_unit_ids = sorting_result.unit_ids[::2] - some_channel_ids = sorting_result.channel_ids[::2] + some_unit_ids = sorting_analyzer.unit_ids[::2] + some_channel_ids = sorting_analyzer.channel_ids[::2] # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == sorting_result.random_spikes_indices.size + assert spike_unit_index.shape[0] == sorting_analyzer.random_spikes_indices.size assert some_projections.shape[1] == n_components assert some_projections.shape[2] == num_chans # this should be some spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < sorting_result.random_spikes_indices.size + assert spike_unit_index.shape[0] < sorting_analyzer.random_spikes_indices.size assert some_projections.shape[1] == n_components assert some_projections.shape[2] == num_chans assert 1 not in spike_unit_index @@ -82,7 +82,7 @@ def test_get_projections(self): channel_ids=some_channel_ids, unit_ids=some_unit_ids ) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < sorting_result.random_spikes_indices.size + assert spike_unit_index.shape[0] < sorting_analyzer.random_spikes_indices.size assert some_projections.shape[1] == n_components assert some_projections.shape[2] == some_channel_ids.size assert 1 not in spike_unit_index @@ -90,13 +90,13 @@ def test_get_projections(self): def test_compute_for_all_spikes(self): for sparse in (True, False): - sorting_result = self._prepare_sorting_result(format="memory", sparse=sparse) + sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=sparse) - num_spikes = sorting_result.sorting.to_spike_vector().size + num_spikes = sorting_analyzer.sorting.to_spike_vector().size n_components = 3 - sorting_result.compute("principal_components", mode="by_channel_local", n_components=n_components) - ext = sorting_result.get_extension("principal_components") + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) + ext = sorting_analyzer.get_extension("principal_components") pc_file1 = cache_folder / "all_pc1.npy" ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) @@ -112,16 +112,16 @@ def test_compute_for_all_spikes(self): def test_project_new(self): from sklearn.decomposition import IncrementalPCA - sorting_result = self._prepare_sorting_result(format="memory", sparse=False) + sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=False) - waveforms = sorting_result.get_extension("waveforms").data["waveforms"] + waveforms = sorting_analyzer.get_extension("waveforms").data["waveforms"] n_components = 3 - sorting_result.compute("principal_components", mode="by_channel_local", n_components=n_components) - ext_pca = sorting_result.get_extension(self.extension_name) + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) + ext_pca = sorting_analyzer.get_extension(self.extension_name) num_spike = 100 - new_spikes = sorting_result.sorting.to_spike_vector()[:num_spike] + new_spikes = sorting_analyzer.sorting.to_spike_vector()[:num_spike] new_waveforms = np.random.randn(num_spike, waveforms.shape[1], waveforms.shape[2]) new_proj = ext_pca.project_new(new_spikes, new_waveforms) @@ -139,7 +139,7 @@ def test_project_new(self): test.test_compute_for_all_spikes() test.test_project_new() - # ext = test.sorting_results["sparseTrue_memory"].get_extension("principal_components") + # ext = test.sorting_analyzers["sparseTrue_memory"].get_extension("principal_components") # pca = ext.data["pca_projection"] # import matplotlib.pyplot as plt # fig, ax = plt.subplots() diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index a7cca70363..bee4816e80 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -18,6 +18,6 @@ class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCa test.setUpClass() test.test_extension() - # for k, sorting_result in test.sorting_results.items(): - # print(sorting_result) - # print(sorting_result.get_extension("spike_amplitudes").data["amplitudes"].shape) + # for k, sorting_analyzer in test.sorting_analyzers.items(): + # print(sorting_analyzer) + # print(sorting_analyzer.get_extension("spike_amplitudes").data["amplitudes"].shape) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 7f48ccb525..26a065dc29 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,6 +1,6 @@ import unittest -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite, get_sorting_result, get_dataset +from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite, get_sorting_analyzer, get_dataset from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity @@ -16,16 +16,16 @@ def test_check_equal_template_with_distribution_overlap(): recording, sorting = get_dataset() - sorting_result = get_sorting_result(recording, sorting, sparsity=None) - sorting_result.select_random_spikes() - sorting_result.compute("waveforms") - sorting_result.compute("templates") + sorting_analyzer = get_sorting_analyzer(recording, sorting, sparsity=None) + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("waveforms") + sorting_analyzer.compute("templates") - wf_ext = sorting_result.get_extension("waveforms") + wf_ext = sorting_analyzer.get_extension("waveforms") - for unit_id0 in sorting_result.unit_ids: + for unit_id0 in sorting_analyzer.unit_ids: waveforms0 = wf_ext.get_waveforms_one_unit(unit_id0) - for unit_id1 in sorting_result.unit_ids: + for unit_id1 in sorting_analyzer.unit_ids: if unit_id0 == unit_id1: continue waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index be9d1fad82..d0f9830091 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -12,7 +12,7 @@ except ImportError: HAVE_NUMBA = False -from ..core.sortingresult import register_result_extension, ResultExtension +from ..core.sortinganalyzer import register_result_extension, ResultExtension from ..core import compute_sparsity from ..core.template_tools import get_template_extremum_channel, _get_nbefore, _get_dense_templates_array @@ -33,8 +33,8 @@ class ComputeUnitLocations(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object method: "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The method to use for localization outputs: "numpy" | "by_unit", default: "numpy" @@ -56,15 +56,15 @@ class ComputeUnitLocations(ResultExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_result): - ResultExtension.__init__(self, sorting_result) + def __init__(self, sorting_analyzer): + ResultExtension.__init__(self, sorting_analyzer) def _set_params(self, method="monopolar_triangulation", **method_kwargs): params = dict(method=method, method_kwargs=method_kwargs) return params def _select_extension_data(self, unit_ids): - unit_inds = self.sorting_result.sorting.ids_to_indices(unit_ids) + unit_inds = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) new_unit_location = self.data["unit_locations"][unit_inds] return dict(unit_locations=new_unit_location) @@ -75,11 +75,11 @@ def _run(self): assert method in possible_localization_methods if method == "center_of_mass": - unit_location = compute_center_of_mass(self.sorting_result, **method_kwargs) + unit_location = compute_center_of_mass(self.sorting_analyzer, **method_kwargs) elif method == "grid_convolution": - unit_location = compute_grid_convolution(self.sorting_result, **method_kwargs) + unit_location = compute_grid_convolution(self.sorting_analyzer, **method_kwargs) elif method == "monopolar_triangulation": - unit_location = compute_monopolar_triangulation(self.sorting_result, **method_kwargs) + unit_location = compute_monopolar_triangulation(self.sorting_analyzer, **method_kwargs) self.data["unit_locations"] = unit_location def get_data(self, outputs="numpy"): @@ -87,7 +87,7 @@ def get_data(self, outputs="numpy"): return self.data["unit_locations"] elif outputs == "by_unit": locations_by_unit = {} - for unit_ind, unit_id in enumerate(self.sorting_result.unit_ids): + for unit_ind, unit_id in enumerate(self.sorting_analyzer.unit_ids): locations_by_unit[unit_id] = self.data["unit_locations"][unit_ind] return locations_by_unit @@ -184,7 +184,7 @@ def estimate_distance_error_with_log(vec, wf_data, local_contact_locations, max_ def compute_monopolar_triangulation( - sorting_result, + sorting_analyzer, optimizer="minimize_with_log_penality", radius_um=75, max_distance_um=1000, @@ -211,8 +211,8 @@ def compute_monopolar_triangulation( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object method: "least_square" | "minimize_with_log_penality", default: "least_square" The optimizer to use radius_um: float, default: 75 @@ -238,13 +238,13 @@ def compute_monopolar_triangulation( assert optimizer in ("least_square", "minimize_with_log_penality") assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids - contact_locations = sorting_result.get_channel_locations() + contact_locations = sorting_analyzer.get_channel_locations() - sparsity = compute_sparsity(sorting_result, method="radius", radius_um=radius_um) - templates = _get_dense_templates_array(sorting_result) - nbefore = _get_nbefore(sorting_result) + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=radius_um) + templates = _get_dense_templates_array(sorting_analyzer) + nbefore = _get_nbefore(sorting_analyzer) if enforce_decrease: neighbours_mask = np.zeros((templates.shape[0], templates.shape[2]), dtype=bool) @@ -252,7 +252,7 @@ def compute_monopolar_triangulation( chan_inds = sparsity.unit_id_to_channel_indices[unit_id] neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = get_template_extremum_channel(sorting_result, outputs="index") + best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index") unit_location = np.zeros((unit_ids.size, 4), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -281,14 +281,14 @@ def compute_monopolar_triangulation( return unit_location -def compute_center_of_mass(sorting_result, peak_sign="neg", radius_um=75, feature="ptp"): +def compute_center_of_mass(sorting_analyzer, peak_sign="neg", radius_um=75, feature="ptp"): """ Computes the center of mass (COM) of a unit based on the template amplitudes. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um: float @@ -300,15 +300,15 @@ def compute_center_of_mass(sorting_result, peak_sign="neg", radius_um=75, featur ------- unit_location: np.array """ - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids - contact_locations = sorting_result.get_channel_locations() + contact_locations = sorting_analyzer.get_channel_locations() assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity(sorting_result, peak_sign=peak_sign, method="radius", radius_um=radius_um) - templates = _get_dense_templates_array(sorting_result) - nbefore = _get_nbefore(sorting_result) + sparsity = compute_sparsity(sorting_analyzer, peak_sign=peak_sign, method="radius", radius_um=radius_um) + templates = _get_dense_templates_array(sorting_analyzer) + nbefore = _get_nbefore(sorting_analyzer) unit_location = np.zeros((unit_ids.size, 2), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -334,7 +334,7 @@ def compute_center_of_mass(sorting_result, peak_sign="neg", radius_um=75, featur def compute_grid_convolution( - sorting_result, + sorting_analyzer, peak_sign="neg", radius_um=40.0, upsampling_um=5, @@ -349,8 +349,8 @@ def compute_grid_convolution( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object peak_sign: "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um: float, default: 40.0 @@ -375,14 +375,14 @@ def compute_grid_convolution( unit_location: np.array """ - contact_locations = sorting_result.get_channel_locations() - unit_ids = sorting_result.unit_ids + contact_locations = sorting_analyzer.get_channel_locations() + unit_ids = sorting_analyzer.unit_ids - templates = _get_dense_templates_array(sorting_result) - nbefore = _get_nbefore(sorting_result) + templates = _get_dense_templates_array(sorting_analyzer) + nbefore = _get_nbefore(sorting_analyzer) nafter = templates.shape[1] - nbefore - fs = sorting_result.sampling_frequency + fs = sorting_analyzer.sampling_frequency percentile = 100 - percentile assert 0 <= percentile <= 100, "Percentile should be in [0, 100]" @@ -398,7 +398,7 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - peak_channels = get_template_extremum_channel(sorting_result, peak_sign, outputs="index") + peak_channels = get_template_extremum_channel(sorting_analyzer, peak_sign, outputs="index") weights_sparsity_mask = weights > 0 diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index e90add4067..9530f2bf9c 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -17,7 +17,7 @@ import warnings from ..postprocessing import correlogram_for_one_segment -from ..core import SortingResult, get_noise_levels +from ..core import SortingAnalyzer, get_noise_levels from ..core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, @@ -36,13 +36,13 @@ _default_params = dict() -def compute_num_spikes(sorting_result, unit_ids=None, **kwargs): +def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): """Compute the number of spike across segments. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object unit_ids : list or None The list of unit ids to compute the number of spikes. If None, all units are used. @@ -52,7 +52,7 @@ def compute_num_spikes(sorting_result, unit_ids=None, **kwargs): The number of spikes, across all segments, for each unit ID. """ - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids num_segs = sorting.get_num_segments() @@ -68,13 +68,13 @@ def compute_num_spikes(sorting_result, unit_ids=None, **kwargs): return num_spikes -def compute_firing_rates(sorting_result, unit_ids=None, **kwargs): +def compute_firing_rates(sorting_analyzer, unit_ids=None, **kwargs): """Compute the firing rate across segments. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object unit_ids : list or None The list of unit ids to compute the firing rate. If None, all units are used. @@ -84,25 +84,25 @@ def compute_firing_rates(sorting_result, unit_ids=None, **kwargs): The firing rate, across all segments, for each unit ID. """ - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids - total_duration = sorting_result.get_total_duration() + total_duration = sorting_analyzer.get_total_duration() firing_rates = {} - num_spikes = compute_num_spikes(sorting_result) + num_spikes = compute_num_spikes(sorting_analyzer) for unit_id in unit_ids: firing_rates[unit_id] = num_spikes[unit_id] / total_duration return firing_rates -def compute_presence_ratios(sorting_result, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None, **kwargs): +def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None, **kwargs): """Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object bin_duration_s : float, default: 60 The duration of each bin in seconds. If the duration is less than this value, presence_ratio is set to NaN @@ -122,15 +122,15 @@ def compute_presence_ratios(sorting_result, bin_duration_s=60.0, mean_fr_ratio_t The total duration, across all segments, is divided into "num_bins". To do so, spike trains across segments are concatenated to mimic a continuous segment. """ - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: - unit_ids = sorting_result.unit_ids - num_segs = sorting_result.get_num_segments() + unit_ids = sorting_analyzer.unit_ids + num_segs = sorting_analyzer.get_num_segments() - seg_lengths = [sorting_result.get_num_samples(i) for i in range(num_segs)] - total_length = sorting_result.get_total_samples() - total_duration = sorting_result.get_total_duration() - bin_duration_samples = int((bin_duration_s * sorting_result.sampling_frequency)) + seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_length = sorting_analyzer.get_total_samples() + total_duration = sorting_analyzer.get_total_duration() + bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) num_bin_edges = total_length // bin_duration_samples + 1 bin_edges = np.arange(num_bin_edges) * bin_duration_samples @@ -177,7 +177,7 @@ def compute_presence_ratios(sorting_result, bin_duration_s=60.0, mean_fr_ratio_t def compute_snrs( - sorting_result, + sorting_analyzer, peak_sign: str = "neg", peak_mode: str = "extremum", unit_ids=None, @@ -186,14 +186,14 @@ def compute_snrs( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the template to compute best channels. peak_mode: "extremum" | "at_index", default: "extremum" How to compute the amplitude. Extremum takes the maxima/minima - At_index takes the value at t=sorting_result.nbefore + At_index takes the value at t=sorting_analyzer.nbefore unit_ids : list or None The list of unit ids to compute the SNR. If None, all units are used. @@ -202,18 +202,18 @@ def compute_snrs( snrs : dict Computed signal to noise ratio for each unit. """ - assert sorting_result.has_extension("noise_levels") - noise_levels = sorting_result.get_extension("noise_levels").get_data() + assert sorting_analyzer.has_extension("noise_levels") + noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() assert peak_sign in ("neg", "pos", "both") assert peak_mode in ("extremum", "at_index") if unit_ids is None: - unit_ids = sorting_result.unit_ids - channel_ids = sorting_result.channel_ids + unit_ids = sorting_analyzer.unit_ids + channel_ids = sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(sorting_result, peak_sign=peak_sign, mode=peak_mode) - unit_amplitudes = get_template_extremum_amplitude(sorting_result, peak_sign=peak_sign, mode=peak_mode) + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) + unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) # make a dict to access by chan_id noise_levels = dict(zip(channel_ids, noise_levels)) @@ -231,7 +231,7 @@ def compute_snrs( _default_params["snr"] = dict(peak_sign="neg", peak_mode="extremum") -def compute_isi_violations(sorting_result, isi_threshold_ms=1.5, min_isi_ms=0, unit_ids=None): +def compute_isi_violations(sorting_analyzer, isi_threshold_ms=1.5, min_isi_ms=0, unit_ids=None): """Calculate Inter-Spike Interval (ISI) violations. It computes several metrics related to isi violations: @@ -241,8 +241,8 @@ def compute_isi_violations(sorting_result, isi_threshold_ms=1.5, min_isi_ms=0, u Parameters ---------- - sorting_result : SortingResult - The SortingResult object + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object isi_threshold_ms : float, default: 1.5 Threshold for classifying adjacent spikes as an ISI violation, in ms. This is the biophysical refractory period @@ -275,13 +275,13 @@ def compute_isi_violations(sorting_result, isi_threshold_ms=1.5, min_isi_ms=0, u """ res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: - unit_ids = sorting_result.unit_ids - num_segs = sorting_result.get_num_segments() + unit_ids = sorting_analyzer.unit_ids + num_segs = sorting_analyzer.get_num_segments() - total_duration_s = sorting_result.get_total_duration() - fs = sorting_result.sampling_frequency + total_duration_s = sorting_analyzer.get_total_duration() + fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 min_isi_s = min_isi_ms / 1000 @@ -313,7 +313,7 @@ def compute_isi_violations(sorting_result, isi_threshold_ms=1.5, min_isi_ms=0, u def compute_refrac_period_violations( - sorting_result, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, unit_ids=None + sorting_analyzer, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, unit_ids=None ): """Calculates the number of refractory period violations. @@ -324,8 +324,8 @@ def compute_refrac_period_violations( Parameters ---------- - sorting_result : SortingResult - The SortingResult object + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object refractory_period_ms : float, default: 1.0 The period (in ms) where no 2 good spikes can occur. censored_period_ms : float, default: 0.0 @@ -357,17 +357,17 @@ def compute_refrac_period_violations( print("compute_refrac_period_violations cannot run without numba.") return None - sorting = sorting_result.sorting - fs = sorting_result.sampling_frequency - num_units = len(sorting_result.unit_ids) - num_segments = sorting_result.get_num_segments() + sorting = sorting_analyzer.sorting + fs = sorting_analyzer.sampling_frequency + num_units = len(sorting_analyzer.unit_ids) + num_segments = sorting_analyzer.get_num_segments() spikes = sorting.to_spike_vector(concatenated=False) if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids - num_spikes = compute_num_spikes(sorting_result) + num_spikes = compute_num_spikes(sorting_analyzer) t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -378,7 +378,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - T = sorting_result.get_total_samples() + T = sorting_analyzer.get_total_samples() nb_violations = {} rp_contamination = {} @@ -402,7 +402,7 @@ def compute_refrac_period_violations( def compute_sliding_rp_violations( - sorting_result, + sorting_analyzer, min_spikes=0, bin_size_ms=0.25, window_size_s=1, @@ -417,8 +417,8 @@ def compute_sliding_rp_violations( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object min_spikes : int, default: 0 Contamination is set to np.nan if the unit has less than this many spikes across all segments. @@ -446,12 +446,12 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - duration = sorting_result.get_total_duration() - sorting = sorting_result.sorting + duration = sorting_analyzer.get_total_duration() + sorting = sorting_analyzer.sorting if unit_ids is None: - unit_ids = sorting_result.unit_ids - num_segs = sorting_result.get_num_segments() - fs = sorting_result.sampling_frequency + unit_ids = sorting_analyzer.unit_ids + num_segs = sorting_analyzer.get_num_segments() + fs = sorting_analyzer.sampling_frequency contamination = {} @@ -496,14 +496,14 @@ def compute_sliding_rp_violations( ) -def compute_synchrony_metrics(sorting_result, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): +def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of "synchrony_size" spikes at the exact same sample index. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object synchrony_sizes : list or tuple, default: (2, 4, 8) The synchrony sizes to compute. unit_ids : list or None, default: None @@ -521,17 +521,17 @@ def compute_synchrony_metrics(sorting_result, synchrony_sizes=(2, 4, 8), unit_id This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" - spike_counts = sorting_result.sorting.count_num_spikes_per_unit(outputs="dict") - sorting = sorting_result.sorting + spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="dict") + sorting = sorting_analyzer.sorting spikes = sorting.to_spike_vector(concatenated=False) if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids # Pre-allocate synchrony counts synchrony_counts = {} for synchrony_size in synchrony_sizes: - synchrony_counts[synchrony_size] = np.zeros(len(sorting_result.unit_ids), dtype=np.int64) + synchrony_counts[synchrony_size] = np.zeros(len(sorting_analyzer.unit_ids), dtype=np.int64) all_unit_ids = list(sorting.unit_ids) for segment_index in range(sorting.get_num_segments()): @@ -569,14 +569,14 @@ def compute_synchrony_metrics(sorting_result, synchrony_sizes=(2, 4, 8), unit_id _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) -def compute_firing_ranges(sorting_result, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): +def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object bin_size_s : float, default: 5 The size of the bin in seconds. percentiles : tuple, default: (5, 95) @@ -593,16 +593,16 @@ def compute_firing_ranges(sorting_result, bin_size_s=5, percentiles=(5, 95), uni ----- Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino. """ - sampling_frequency = sorting_result.sampling_frequency + sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids if all( [ - sorting_result.get_num_samples(segment_index) < bin_size_samples - for segment_index in range(sorting_result.get_num_segments()) + sorting_analyzer.get_num_samples(segment_index) < bin_size_samples + for segment_index in range(sorting_analyzer.get_num_segments()) ] ): warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. Firing ranges are set to NaN.") @@ -610,8 +610,8 @@ def compute_firing_ranges(sorting_result, bin_size_s=5, percentiles=(5, 95), uni # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - for segment_index in range(sorting_result.get_num_segments()): - num_samples = sorting_result.get_num_samples(segment_index) + for segment_index in range(sorting_analyzer.get_num_segments()): + num_samples = sorting_analyzer.get_num_samples(segment_index) edges = np.arange(0, num_samples + 1, bin_size_samples) for unit_id in unit_ids: @@ -634,7 +634,7 @@ def compute_firing_ranges(sorting_result, bin_size_s=5, percentiles=(5, 95), uni def compute_amplitude_cv_metrics( - sorting_result, + sorting_analyzer, average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, @@ -647,8 +647,8 @@ def compute_amplitude_cv_metrics( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object average_num_spikes_per_bin : int, default: 50 The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is @@ -677,15 +677,15 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" - sorting = sorting_result.sorting - total_duration = sorting_result.get_total_duration() + sorting = sorting_analyzer.sorting + total_duration = sorting_analyzer.get_total_duration() spikes = sorting.to_spike_vector() num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") if unit_ids is None: unit_ids = sorting.unit_ids - if sorting_result.has_extension(amplitude_extension): - amps = sorting_result.get_extension(amplitude_extension).get_data() + if sorting_analyzer.has_extension(amplitude_extension): + amps = sorting_analyzer.get_extension(amplitude_extension).get_data() else: warnings.warn("compute_amplitude_cv_metrics() need 'spike_amplitudes' or 'amplitude_scalings'") empty_dict = {unit_id: np.nan for unit_id in unit_ids} @@ -693,7 +693,7 @@ def compute_amplitude_cv_metrics( # precompute segment slice segment_slices = [] - for segment_index in range(sorting_result.get_num_segments()): + for segment_index in range(sorting_analyzer.get_num_segments()): i0 = np.searchsorted(spikes["segment_index"], segment_index) i1 = np.searchsorted(spikes["segment_index"], segment_index + 1) segment_slices.append(slice(i0, i1)) @@ -702,13 +702,13 @@ def compute_amplitude_cv_metrics( amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: firing_rate = num_spikes[unit_id] / total_duration - temporal_bin_size_samples = int((average_num_spikes_per_bin / firing_rate) * sorting_result.sampling_frequency) + temporal_bin_size_samples = int((average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency) amp_spreads = [] # bins and amplitude means are computed for each segment - for segment_index in range(sorting_result.get_num_segments()): + for segment_index in range(sorting_analyzer.get_num_segments()): sample_bin_edges = np.arange( - 0, sorting_result.get_num_samples(segment_index) + 1, temporal_bin_size_samples + 0, sorting_analyzer.get_num_samples(segment_index) + 1, temporal_bin_size_samples ) spikes_in_segment = spikes[segment_slices[segment_index]] amps_in_segment = amps[segment_slices[segment_index]] @@ -738,36 +738,36 @@ def compute_amplitude_cv_metrics( ) -def _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign): +def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): # used by compute_amplitude_cutoffs and compute_amplitude_medians amplitudes_by_units = {} - if sorting_result.has_extension("spike_amplitudes"): - spikes = sorting_result.sorting.to_spike_vector() - ext = sorting_result.get_extension("spike_amplitudes") + if sorting_analyzer.has_extension("spike_amplitudes"): + spikes = sorting_analyzer.sorting.to_spike_vector() + ext = sorting_analyzer.get_extension("spike_amplitudes") all_amplitudes = ext.get_data() for unit_id in unit_ids: - unit_index = sorting_result.sorting.id_to_index(unit_id) + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) spike_mask = spikes["unit_index"] == unit_index amplitudes_by_units[unit_id] = all_amplitudes[spike_mask] - elif sorting_result.has_extension("waveforms"): - waveforms_ext = sorting_result.get_extension("waveforms") + elif sorting_analyzer.has_extension("waveforms"): + waveforms_ext = sorting_analyzer.get_extension("waveforms") before = waveforms_ext.nbefore - extremum_channels_ids = get_template_extremum_channel(sorting_result, peak_sign=peak_sign) + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) for unit_id in unit_ids: waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) chan_id = extremum_channels_ids[unit_id] - if sorting_result.is_sparse(): - chan_ind = np.where(sorting_result.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] + if sorting_analyzer.is_sparse(): + chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] else: - chan_ind = sorting_result.channel_ids_to_indices([chan_id])[0] + chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] return amplitudes_by_units def compute_amplitude_cutoffs( - sorting_result, + sorting_analyzer, peak_sign="neg", num_histogram_bins=500, histogram_smoothing_value=3, @@ -778,8 +778,8 @@ def compute_amplitude_cutoffs( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the peaks. num_histogram_bins : int, default: 100 @@ -803,7 +803,7 @@ def compute_amplitude_cutoffs( ----- This approach assumes the amplitude histogram is symmetric (not valid in the presence of drift). If available, amplitudes are extracted from the "spike_amplitude" extension (recommended). - If the "spike_amplitude" extension is not available, the amplitudes are extracted from the SortingResult, + If the "spike_amplitude" extension is not available, the amplitudes are extracted from the SortingAnalyzer, which usually has waveforms for a small subset of spikes (500 by default). References @@ -815,21 +815,21 @@ def compute_amplitude_cutoffs( """ if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids all_fraction_missing = {} - if sorting_result.has_extension("spike_amplitudes") or sorting_result.has_extension("waveforms"): + if sorting_analyzer.has_extension("spike_amplitudes") or sorting_analyzer.has_extension("waveforms"): invert_amplitudes = False if ( - sorting_result.has_extension("spike_amplitudes") - and sorting_result.get_extension("spike_amplitudes").params["peak_sign"] == "pos" + sorting_analyzer.has_extension("spike_amplitudes") + and sorting_analyzer.get_extension("spike_amplitudes").params["peak_sign"] == "pos" ): invert_amplitudes = True - elif sorting_result.has_extension("waveforms") and peak_sign == "pos": + elif sorting_analyzer.has_extension("waveforms") and peak_sign == "pos": invert_amplitudes = True - amplitudes_by_units = _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign) + amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -856,13 +856,13 @@ def compute_amplitude_cutoffs( ) -def compute_amplitude_medians(sorting_result, peak_sign="neg", unit_ids=None): +def compute_amplitude_medians(sorting_analyzer, peak_sign="neg", unit_ids=None): """Compute median of the amplitude distributions (in absolute value). Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the peaks. unit_ids : list or None @@ -879,13 +879,13 @@ def compute_amplitude_medians(sorting_result, peak_sign="neg", unit_ids=None): This code is ported from: https://github.com/int-brain-lab/ibllib/blob/master/brainbox/metrics/single_units.py """ - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids all_amplitude_medians = {} - if sorting_result.has_extension("spike_amplitudes") or sorting_result.has_extension("waveforms"): - amplitudes_by_units = _get_amplitudes_by_units(sorting_result, unit_ids, peak_sign) + if sorting_analyzer.has_extension("spike_amplitudes") or sorting_analyzer.has_extension("waveforms"): + amplitudes_by_units = _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) else: @@ -900,7 +900,7 @@ def compute_amplitude_medians(sorting_result, peak_sign="neg", unit_ids=None): def compute_drift_metrics( - sorting_result, + sorting_analyzer, interval_s=60, min_spikes_per_interval=100, direction="y", @@ -924,8 +924,8 @@ def compute_drift_metrics( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object interval_s : int, default: 60 Interval length is seconds for computing spike depth min_spikes_per_interval : int, default: 100 @@ -961,12 +961,12 @@ def compute_drift_metrics( there are large displacements in between segments, the resulting metric values will be very high. """ res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids - if sorting_result.has_extension("spike_locations"): - spike_locations_ext = sorting_result.get_extension("spike_locations") + if sorting_analyzer.has_extension("spike_locations"): + spike_locations_ext = sorting_analyzer.get_extension("spike_locations") spike_locations = spike_locations_ext.get_data() # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") spikes = sorting.to_spike_vector() @@ -988,11 +988,11 @@ def compute_drift_metrics( else: return res(empty_dict, empty_dict, empty_dict) - interval_samples = int(interval_s * sorting_result.sampling_frequency) + interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - total_duration = sorting_result.get_total_duration() + total_duration = sorting_analyzer.get_total_duration() if total_duration < min_num_bins * interval_s: warnings.warn( "The recording is too short given the specified 'interval_s' and " @@ -1016,8 +1016,8 @@ def compute_drift_metrics( # now compute median positions and concatenate them over segments median_position_segments = None - for segment_index in range(sorting_result.get_num_segments()): - seg_length = sorting_result.get_num_samples(segment_index) + for segment_index in range(sorting_analyzer.get_num_segments()): + seg_length = sorting_analyzer.get_num_samples(segment_index) num_bin_edges = seg_length // interval_samples + 1 bins = np.arange(num_bin_edges) * interval_samples spike_vector = sorting.to_spike_vector() @@ -1363,7 +1363,7 @@ def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters, def compute_sd_ratio( - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, @@ -1377,8 +1377,8 @@ def compute_sd_ratio( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object censored_period_ms : float, default: 4.0 The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. correct_for_drift: bool, default: True @@ -1400,21 +1400,21 @@ def compute_sd_ratio( import numba from ..curation.curation_tools import _find_duplicated_spikes_keep_first_iterative - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting - censored_period = int(round(censored_period_ms * 1e-3 * sorting_result.sampling_frequency)) + censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids - if not sorting_result.has_recording(): + if not sorting_analyzer.has_recording(): warnings.warn( - "The `sd_ratio` metric cannot work with a recordless SortingResult object" + "The `sd_ratio` metric cannot work with a recordless SortingAnalyzer object" "SD ratio metric will be set to NaN" ) return {unit_id: np.nan for unit_id in unit_ids} - if sorting_result.has_extension("spike_amplitudes"): - amplitudes_ext = sorting_result.get_extension("spike_amplitudes") + if sorting_analyzer.has_extension("spike_amplitudes"): + amplitudes_ext = sorting_analyzer.get_extension("spike_amplitudes") # spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit") spike_amplitudes = amplitudes_ext.get_data() else: @@ -1426,23 +1426,23 @@ def compute_sd_ratio( return {unit_id: np.nan for unit_id in unit_ids} noise_levels = get_noise_levels( - sorting_result.recording, return_scaled=amplitudes_ext.params["return_scaled"], method="std" + sorting_analyzer.recording, return_scaled=amplitudes_ext.params["return_scaled"], method="std" ) - best_channels = get_template_extremum_channel(sorting_result, outputs="index", **kwargs) + best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs) n_spikes = sorting.count_num_spikes_per_unit() if correct_for_template_itself: - tamplates_array = _get_dense_templates_array(sorting_result, return_scaled=True) + tamplates_array = _get_dense_templates_array(sorting_analyzer, return_scaled=True) spikes = sorting.to_spike_vector() sd_ratio = {} for unit_id in unit_ids: - unit_index = sorting_result.sorting.id_to_index(unit_id) + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) spk_amp = [] - for segment_index in range(sorting_result.get_num_segments()): - # spike_train = sorting_result.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype( + for segment_index in range(sorting_analyzer.get_num_segments()): + # spike_train = sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype( # np.int64, copy=False # ) spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index) @@ -1472,14 +1472,14 @@ def compute_sd_ratio( std_noise = noise_levels[best_channel] if correct_for_template_itself: - # template = sorting_result.get_template(unit_id, force_dense=True)[:, best_channel] + # template = sorting_analyzer.get_template(unit_id, force_dense=True)[:, best_channel] template = tamplates_array[unit_index, :, :][:, best_channel] nsamples = template.shape[0] # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template. # TODO: Take into account that templates for different segments might differ. - p = nsamples * n_spikes[unit_id] / sorting_result.get_total_samples() + p = nsamples * n_spikes[unit_id] / sorting_analyzer.get_total_samples() total_variance = p * np.mean(template**2) - p**2 * np.mean(template) std_noise = np.sqrt(std_noise**2 - total_variance) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index f6ac46d24c..53984579d4 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -58,14 +58,14 @@ def get_quality_pca_metric_list(): def calculate_pc_metrics( - sorting_result, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False + sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False ): """Calculate principal component derived metrics. Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. @@ -85,21 +85,21 @@ def calculate_pc_metrics( pc_metrics : dict The computed PC metrics. """ - pca_ext = sorting_result.get_extension("principal_components") + pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if metric_names is None: metric_names = _possible_pc_metric_names if qm_params is None: qm_params = _default_params - extremum_channels = get_template_extremum_channel(sorting_result) + extremum_channels = get_template_extremum_channel(sorting_analyzer) if unit_ids is None: - unit_ids = sorting_result.unit_ids - channel_ids = sorting_result.channel_ids + unit_ids = sorting_analyzer.unit_ids + channel_ids = sorting_analyzer.channel_ids # create output dict of dict pc_metrics['metric_name'][unit_id] pc_metrics = {k: {} for k in metric_names} @@ -113,8 +113,8 @@ def calculate_pc_metrics( # Compute nspikes and firing rate outside of main loop for speed if any([n in metric_names for n in ["nn_isolation", "nn_noise_overlap"]]): - n_spikes_all_units = compute_num_spikes(sorting_result, unit_ids=unit_ids) - fr_all_units = compute_firing_rates(sorting_result, unit_ids=unit_ids) + n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) + fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids) else: n_spikes_all_units = None fr_all_units = None @@ -130,15 +130,15 @@ def calculate_pc_metrics( items = [] for unit_id in unit_ids: - if sorting_result.is_sparse(): - neighbor_channel_ids = sorting_result.sparsity.unit_id_to_channel_ids[unit_id] + if sorting_analyzer.is_sparse(): + neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] neighbor_unit_ids = [ other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids ] else: neighbor_channel_ids = channel_ids neighbor_unit_ids = unit_ids - neighbor_channel_indices = sorting_result.channel_ids_to_indices(neighbor_channel_ids) + neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] @@ -351,7 +351,7 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_n def nearest_neighbors_isolation( - sorting_result, + sorting_analyzer, this_unit_id: int | str, n_spikes_all_units: dict = None, fr_all_units: dict = None, @@ -369,8 +369,8 @@ def nearest_neighbors_isolation( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object this_unit_id : int | str The ID for the unit to calculate these metrics for. n_spikes_all_units: dict, default: None @@ -396,10 +396,10 @@ def nearest_neighbors_isolation( radius_um : float, default: 100 The radius, in um, that channels need to be within the peak channel to be included. peak_sign: "neg" | "pos" | "both", default: "neg" - The peak_sign used to compute sparsity and neighbor units. Used if sorting_result + The peak_sign used to compute sparsity and neighbor units. Used if sorting_analyzer is not sparse already. min_spatial_overlap : float, default: 100 - In case sorting_result is sparse, other units are selected if they share at least + In case sorting_analyzer is sparse, other units are selected if they share at least `min_spatial_overlap` times `n_target_unit_channels` with the target unit seed : int, default: None Seed for random subsampling of spikes. @@ -444,15 +444,15 @@ def nearest_neighbors_isolation( """ rng = np.random.default_rng(seed=seed) - waveforms_ext = sorting_result.get_extension("waveforms") + waveforms_ext = sorting_analyzer.get_extension("waveforms") assert waveforms_ext is not None, "nearest_neighbors_isolation() need extension 'waveforms'" - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting all_units_ids = sorting.get_unit_ids() if n_spikes_all_units is None: - n_spikes_all_units = compute_num_spikes(sorting_result) + n_spikes_all_units = compute_num_spikes(sorting_analyzer) if fr_all_units is None: - fr_all_units = compute_firing_rates(sorting_result) + fr_all_units = compute_firing_rates(sorting_analyzer) # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN if n_spikes_all_units[this_unit_id] < min_spikes: @@ -482,17 +482,17 @@ def nearest_neighbors_isolation( other_units_ids = np.setdiff1d(all_units_ids, this_unit_id) # get waveforms of target unit - # waveforms_target_unit = sorting_result.get_waveforms(unit_id=this_unit_id) + # waveforms_target_unit = sorting_analyzer.get_waveforms(unit_id=this_unit_id) waveforms_target_unit = waveforms_ext.get_waveforms_one_unit(unit_id=this_unit_id, force_dense=False) n_spikes_target_unit = waveforms_target_unit.shape[0] # find units whose signal channels (i.e. channels inside some radius around # the channel with largest amplitude) overlap with signal channels of the target unit - if sorting_result.is_sparse(): - sparsity = sorting_result.sparsity + if sorting_analyzer.is_sparse(): + sparsity = sorting_analyzer.sparsity else: - sparsity = compute_sparsity(sorting_result, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um) closest_chans_target_unit = sparsity.unit_id_to_channel_indices[this_unit_id] n_channels_target_unit = len(closest_chans_target_unit) # select other units that have a minimum spatial overlap with target unit @@ -513,7 +513,7 @@ def nearest_neighbors_isolation( len(other_units_ids), ) for other_unit_id in other_units_ids: - # waveforms_other_unit = sorting_result.get_waveforms(unit_id=other_unit_id) + # waveforms_other_unit = sorting_analyzer.get_waveforms(unit_id=other_unit_id) waveforms_other_unit = waveforms_ext.get_waveforms_one_unit(unit_id=other_unit_id, force_dense=False) n_spikes_other_unit = waveforms_other_unit.shape[0] @@ -528,7 +528,7 @@ def nearest_neighbors_isolation( # project this unit and other unit waveforms on common subspace common_channel_idxs = np.intersect1d(closest_chans_target_unit, closest_chans_other_unit) - if sorting_result.is_sparse(): + if sorting_analyzer.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ :, :, np.isin(closest_chans_target_unit, common_channel_idxs) @@ -565,7 +565,7 @@ def nearest_neighbors_isolation( def nearest_neighbors_noise_overlap( - sorting_result, + sorting_analyzer, this_unit_id: int | str, n_spikes_all_units: dict = None, fr_all_units: dict = None, @@ -582,8 +582,8 @@ def nearest_neighbors_noise_overlap( Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object this_unit_id : int | str The ID of the unit to calculate this metric on. n_spikes_all_units: dict, default: None @@ -607,7 +607,7 @@ def nearest_neighbors_noise_overlap( radius_um : float, default: 100 The radius, in um, that channels need to be within the peak channel to be included. peak_sign: "neg" | "pos" | "both", default: "neg" - The peak_sign used to compute sparsity and neighbor units. Used if sorting_result + The peak_sign used to compute sparsity and neighbor units. Used if sorting_analyzer is not sparse already. seed : int, default: 0 Random seed for subsampling spikes. @@ -638,16 +638,16 @@ def nearest_neighbors_noise_overlap( """ rng = np.random.default_rng(seed=seed) - waveforms_ext = sorting_result.get_extension("waveforms") + waveforms_ext = sorting_analyzer.get_extension("waveforms") assert waveforms_ext is not None, "nearest_neighbors_isolation() need extension 'waveforms'" - templates_ext = sorting_result.get_extension("templates") + templates_ext = sorting_analyzer.get_extension("templates") assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'" if n_spikes_all_units is None: - n_spikes_all_units = compute_num_spikes(sorting_result) + n_spikes_all_units = compute_num_spikes(sorting_analyzer) if fr_all_units is None: - fr_all_units = compute_firing_rates(sorting_result) + fr_all_units = compute_firing_rates(sorting_analyzer) # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN if n_spikes_all_units[this_unit_id] < min_spikes: @@ -665,7 +665,7 @@ def nearest_neighbors_noise_overlap( else: # get random snippets from the recording to create a noise cluster nsamples = waveforms_ext.nbefore + waveforms_ext.nafter - recording = sorting_result.recording + recording = sorting_analyzer.recording noise_cluster = get_random_data_chunks( recording, return_scaled=waveforms_ext.params["return_scaled"], @@ -676,7 +676,7 @@ def nearest_neighbors_noise_overlap( noise_cluster = np.reshape(noise_cluster, (max_spikes, nsamples, -1)) # get waveforms for target cluster - # waveforms = sorting_result.get_waveforms(unit_id=this_unit_id).copy() + # waveforms = sorting_analyzer.get_waveforms(unit_id=this_unit_id).copy() waveforms = waveforms_ext.get_waveforms_one_unit(unit_id=this_unit_id, force_dense=False).copy() # adjust the size of the target and noise clusters to be equal @@ -692,20 +692,20 @@ def nearest_neighbors_noise_overlap( n_snippets = max_spikes # restrict to channels with significant signal - if sorting_result.is_sparse(): - sparsity = sorting_result.sparsity + if sorting_analyzer.is_sparse(): + sparsity = sorting_analyzer.sparsity else: - sparsity = compute_sparsity(sorting_result, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um) noise_cluster = noise_cluster[:, :, sparsity.unit_id_to_channel_indices[this_unit_id]] # compute weighted noise snippet (Z) - # median_waveform = sorting_result.get_template(unit_id=this_unit_id, mode="median") + # median_waveform = sorting_analyzer.get_template(unit_id=this_unit_id, mode="median") all_templates = templates_ext.get_data(operator="median") - this_unit_index = sorting_result.sorting.id_to_index(this_unit_id) + this_unit_index = sorting_analyzer.sorting.id_to_index(this_unit_id) median_waveform = all_templates[this_unit_index, :, :] - # in case sorting_result is sparse, waveforms and templates are already sparse - if not sorting_result.is_sparse(): + # in case sorting_analyzer is sparse, waveforms and templates are already sparse + if not sorting_analyzer.is_sparse(): # @alessio : this next line is suspicious because the waveforms is already sparse no ? Am i wrong ? waveforms = waveforms[:, :, sparsity.unit_id_to_channel_indices[this_unit_id]] median_waveform = median_waveform[:, sparsity.unit_id_to_channel_indices[this_unit_id]] diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 0a7f9559e2..fb32280a3b 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -9,7 +9,7 @@ import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.sortingresult import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension from .quality_metric_list import calculate_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names @@ -23,8 +23,8 @@ class ComputeQualityMetrics(ResultExtension): Parameters ---------- - sorting_result: SortingResult - A SortingResult object + sorting_analyzer: SortingAnalyzer + A SortingAnalyzer object metric_names : list or None List of quality metrics to compute. qm_params : dict or None @@ -53,14 +53,14 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list - if self.sorting_result.has_extension("principal_components") and not skip_pc_metrics: + if self.sorting_analyzer.has_extension("principal_components") and not skip_pc_metrics: # by default 'nearest_neightbor' is removed because too slow pc_metrics = _possible_pc_metric_names.copy() pc_metrics.remove("nn_isolation") pc_metrics.remove("nn_noise_overlap") metric_names += pc_metrics # if spike_locations are not available, drift is removed from the list - if not self.sorting_result.has_extension("spike_locations"): + if not self.sorting_analyzer.has_extension("spike_locations"): if "drift" in metric_names: metric_names.remove("drift") @@ -100,7 +100,7 @@ def _run(self, verbose=False, **job_kwargs): n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] - sorting = self.sorting_result.sorting + sorting = self.sorting_analyzer.sorting unit_ids = sorting.unit_ids non_empty_unit_ids = sorting.get_non_empty_unit_ids() empty_unit_ids = unit_ids[~np.isin(unit_ids, non_empty_unit_ids)] @@ -126,7 +126,7 @@ def _run(self, verbose=False, **job_kwargs): func = _misc_metric_name_to_func[metric_name] params = qm_params[metric_name] if metric_name in qm_params else {} - res = func(self.sorting_result, unit_ids=non_empty_unit_ids, **params) + res = func(self.sorting_analyzer, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: if isinstance(res, dict): @@ -141,10 +141,10 @@ def _run(self, verbose=False, **job_kwargs): # metrics based on PCs pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: - if not self.sorting_result.has_extension("principal_components"): + if not self.sorting_analyzer.has_extension("principal_components"): raise ValueError("waveform_principal_component must be provied") pc_metrics = calculate_pc_metrics( - self.sorting_result, + self.sorting_analyzer, unit_ids=non_empty_unit_ids, metric_names=pc_metric_names, # sparsity=sparsity, diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 27a84da440..96920c08e5 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -7,7 +7,7 @@ synthetize_spike_train_bad_isi, add_synchrony_to_sorting, generate_ground_truth_recording, - start_sorting_result, + create_sorting_analyzer, ) # from spikeinterface.extractors.toy_example import toy_example @@ -47,7 +47,7 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _sorting_result_simple(): +def _sorting_analyzer_simple(): recording, sorting = generate_ground_truth_recording( durations=[ 50.0, @@ -60,21 +60,21 @@ def _sorting_result_simple(): seed=2205, ) - sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - sorting_result.select_random_spikes(max_spikes_per_unit=300, seed=2205) - sorting_result.compute("noise_levels") - sorting_result.compute("waveforms", **job_kwargs) - sorting_result.compute("templates") - sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_result.compute("spike_amplitudes", **job_kwargs) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - return sorting_result + return sorting_analyzer @pytest.fixture(scope="module") -def sorting_result_simple(): - return _sorting_result_simple() +def sorting_analyzer_simple(): + return _sorting_analyzer_simple() def _sorting_violation(): @@ -106,7 +106,7 @@ def _sorting_violation(): return sorting -def _sorting_result_violations(): +def _sorting_analyzer_violations(): sorting = _sorting_violation() duration = (sorting.to_spike_vector()["sample_index"][-1] + 1) / sorting.sampling_frequency @@ -119,14 +119,14 @@ def _sorting_result_violations(): noise_kwargs=dict(noise_level=5.0, strategy="tile_pregenerated"), seed=2205, ) - sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) # this used only for ISI metrics so no need to compute heavy extensions - return sorting_result + return sorting_analyzer @pytest.fixture(scope="module") -def sorting_result_violations(): - return _sorting_result_violations() +def sorting_analyzer_violations(): + return _sorting_analyzer_violations() def test_mahalanobis_metrics(): @@ -191,10 +191,10 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -def test_calculate_firing_rate_num_spikes(sorting_result_simple): - sorting_result = sorting_result_simple - firing_rates = compute_firing_rates(sorting_result) - num_spikes = compute_num_spikes(sorting_result) +def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + firing_rates = compute_firing_rates(sorting_analyzer) + num_spikes = compute_num_spikes(sorting_analyzer) # testing method accuracy with magic number is not a good pratcice, I remove this. # firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} @@ -203,20 +203,20 @@ def test_calculate_firing_rate_num_spikes(sorting_result_simple): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) -def test_calculate_firing_range(sorting_result_simple): - sorting_result = sorting_result_simple - firing_ranges = compute_firing_ranges(sorting_result) +def test_calculate_firing_range(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + firing_ranges = compute_firing_ranges(sorting_analyzer) print(firing_ranges) with pytest.warns(UserWarning) as w: - firing_ranges_nan = compute_firing_ranges(sorting_result, bin_size_s=sorting_result.get_total_duration() + 1) + firing_ranges_nan = compute_firing_ranges(sorting_analyzer, bin_size_s=sorting_analyzer.get_total_duration() + 1) assert np.all([np.isnan(f) for f in firing_ranges_nan.values()]) -def test_calculate_amplitude_cutoff(sorting_result_simple): - sorting_result = sorting_result_simple - # spike_amps = sorting_result.get_extension("spike_amplitudes").get_data() - amp_cuts = compute_amplitude_cutoffs(sorting_result, num_histogram_bins=10) +def test_calculate_amplitude_cutoff(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() + amp_cuts = compute_amplitude_cutoffs(sorting_analyzer, num_histogram_bins=10) # print(amp_cuts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -224,10 +224,10 @@ def test_calculate_amplitude_cutoff(sorting_result_simple): # assert np.allclose(list(amp_cuts_gt.values()), list(amp_cuts.values()), rtol=0.05) -def test_calculate_amplitude_median(sorting_result_simple): - sorting_result = sorting_result_simple - # spike_amps = sorting_result.get_extension("spike_amplitudes").get_data() - amp_medians = compute_amplitude_medians(sorting_result) +def test_calculate_amplitude_median(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() + amp_medians = compute_amplitude_medians(sorting_analyzer) # print(amp_medians) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -235,16 +235,16 @@ def test_calculate_amplitude_median(sorting_result_simple): # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) -def test_calculate_amplitude_cv_metrics(sorting_result_simple): - sorting_result = sorting_result_simple - amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_result, average_num_spikes_per_bin=20) +def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_analyzer, average_num_spikes_per_bin=20) print(amp_cv_median) print(amp_cv_range) - # amps_scalings = compute_amplitude_scalings(sorting_result) - sorting_result.compute("amplitude_scalings", **job_kwargs) + # amps_scalings = compute_amplitude_scalings(sorting_analyzer) + sorting_analyzer.compute("amplitude_scalings", **job_kwargs) amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( - sorting_result, + sorting_analyzer, average_num_spikes_per_bin=20, amplitude_extension="amplitude_scalings", min_num_bins=5, @@ -253,9 +253,9 @@ def test_calculate_amplitude_cv_metrics(sorting_result_simple): print(amp_cv_range_scalings) -def test_calculate_snrs(sorting_result_simple): - sorting_result = sorting_result_simple - snrs = compute_snrs(sorting_result) +def test_calculate_snrs(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + snrs = compute_snrs(sorting_analyzer) print(snrs) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -263,9 +263,9 @@ def test_calculate_snrs(sorting_result_simple): # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05) -def test_calculate_presence_ratio(sorting_result_simple): - sorting_result = sorting_result_simple - ratios = compute_presence_ratios(sorting_result, bin_duration_s=10) +def test_calculate_presence_ratio(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + ratios = compute_presence_ratios(sorting_analyzer, bin_duration_s=10) print(ratios) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -273,9 +273,9 @@ def test_calculate_presence_ratio(sorting_result_simple): # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) -def test_calculate_isi_violations(sorting_result_violations): - sorting_result = sorting_result_violations - isi_viol, counts = compute_isi_violations(sorting_result, isi_threshold_ms=1, min_isi_ms=0.0) +def test_calculate_isi_violations(sorting_analyzer_violations): + sorting_analyzer = sorting_analyzer_violations + isi_viol, counts = compute_isi_violations(sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0) print(isi_viol) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -285,9 +285,9 @@ def test_calculate_isi_violations(sorting_result_violations): # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) -def test_calculate_sliding_rp_violations(sorting_result_violations): - sorting_result = sorting_result_violations - contaminations = compute_sliding_rp_violations(sorting_result, bin_size_ms=0.25, window_size_s=1) +def test_calculate_sliding_rp_violations(sorting_analyzer_violations): + sorting_analyzer = sorting_analyzer_violations + contaminations = compute_sliding_rp_violations(sorting_analyzer, bin_size_ms=0.25, window_size_s=1) print(contaminations) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -295,10 +295,10 @@ def test_calculate_sliding_rp_violations(sorting_result_violations): # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) -def test_calculate_rp_violations(sorting_result_violations): - sorting_result = sorting_result_violations +def test_calculate_rp_violations(sorting_analyzer_violations): + sorting_analyzer = sorting_analyzer_violations rp_contamination, counts = compute_refrac_period_violations( - sorting_result, refractory_period_ms=1, censored_period_ms=0.0 + sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0 ) print(rp_contamination, counts) @@ -312,19 +312,19 @@ def test_calculate_rp_violations(sorting_result_violations): {0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000 ) # we.sorting = sorting - sorting_result2 = start_sorting_result(sorting, sorting_result.recording, format="memory", sparse=False) + sorting_analyzer2 = create_sorting_analyzer(sorting, sorting_analyzer.recording, format="memory", sparse=False) rp_contamination, counts = compute_refrac_period_violations( - sorting_result2, refractory_period_ms=1, censored_period_ms=0.0 + sorting_analyzer2, refractory_period_ms=1, censored_period_ms=0.0 ) assert np.isnan(rp_contamination[1]) -def test_synchrony_metrics(sorting_result_simple): - sorting_result = sorting_result_simple - sorting = sorting_result.sorting +def test_synchrony_metrics(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + sorting = sorting_analyzer.sorting synchrony_sizes = (2, 3, 4) - synchrony_metrics = compute_synchrony_metrics(sorting_result, synchrony_sizes=synchrony_sizes) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes) print(synchrony_metrics) # check returns @@ -333,13 +333,13 @@ def test_synchrony_metrics(sorting_result_simple): # here we test that increasing added synchrony is captured by syncrhony metrics added_synchrony_levels = (0.2, 0.5, 0.8) - previous_sorting_result = sorting_result + previous_sorting_analyzer = sorting_analyzer for sync_level in added_synchrony_levels: sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) - sorting_result_sync = start_sorting_result(sorting_sync, sorting_result.recording, format="memory") + sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory") - previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_result, synchrony_sizes=synchrony_sizes) - current_synchrony_metrics = compute_synchrony_metrics(sorting_result_sync, synchrony_sizes=synchrony_sizes) + previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer, synchrony_sizes=synchrony_sizes) + current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes) print(current_synchrony_metrics) # check that all values increased for i, col in enumerate(previous_synchrony_metrics._fields): @@ -351,16 +351,16 @@ def test_synchrony_metrics(sorting_result_simple): ) # set new previous waveform extractor - previous_sorting_result = sorting_result_sync + previous_sorting_analyzer = sorting_analyzer_sync @pytest.mark.sortingcomponents -def test_calculate_drift_metrics(sorting_result_simple): - sorting_result = sorting_result_simple - sorting_result.compute("spike_locations", **job_kwargs) +def test_calculate_drift_metrics(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + sorting_analyzer.compute("spike_locations", **job_kwargs) drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics( - sorting_result, interval_s=10, min_spikes_per_interval=10 + sorting_analyzer, interval_s=10, min_spikes_per_interval=10 ) # print(drifts_ptps, drifts_stds, drift_mads) @@ -374,35 +374,35 @@ def test_calculate_drift_metrics(sorting_result_simple): # assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05) -def test_calculate_sd_ratio(sorting_result_simple): +def test_calculate_sd_ratio(sorting_analyzer_simple): sd_ratio = compute_sd_ratio( - sorting_result_simple, + sorting_analyzer_simple, ) - assert np.all(list(sd_ratio.keys()) == sorting_result_simple.unit_ids) + assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) if __name__ == "__main__": - sorting_result = _sorting_result_simple() - print(sorting_result) - - # test_calculate_firing_rate_num_spikes(sorting_result) - # test_calculate_snrs(sorting_result) - test_calculate_amplitude_cutoff(sorting_result) - # test_calculate_presence_ratio(sorting_result) - # test_calculate_amplitude_median(sorting_result) - # test_calculate_sliding_rp_violations(sorting_result) - # test_calculate_drift_metrics(sorting_result) - # test_synchrony_metrics(sorting_result) - # test_calculate_firing_range(sorting_result) - # test_calculate_amplitude_cv_metrics(sorting_result) - test_calculate_sd_ratio(sorting_result) - - # sorting_result_violations = _sorting_result_violations() - # print(sorting_result_violations) - # test_calculate_isi_violations(sorting_result_violations) - # test_calculate_sliding_rp_violations(sorting_result_violations) - # test_calculate_rp_violations(sorting_result_violations) + sorting_analyzer = _sorting_analyzer_simple() + print(sorting_analyzer) + + # test_calculate_firing_rate_num_spikes(sorting_analyzer) + # test_calculate_snrs(sorting_analyzer) + test_calculate_amplitude_cutoff(sorting_analyzer) + # test_calculate_presence_ratio(sorting_analyzer) + # test_calculate_amplitude_median(sorting_analyzer) + # test_calculate_sliding_rp_violations(sorting_analyzer) + # test_calculate_drift_metrics(sorting_analyzer) + # test_synchrony_metrics(sorting_analyzer) + # test_calculate_firing_range(sorting_analyzer) + # test_calculate_amplitude_cv_metrics(sorting_analyzer) + test_calculate_sd_ratio(sorting_analyzer) + + # sorting_analyzer_violations = _sorting_analyzer_violations() + # print(sorting_analyzer_violations) + # test_calculate_isi_violations(sorting_analyzer_violations) + # test_calculate_sliding_rp_violations(sorting_analyzer_violations) + # test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 29b334e97f..6aa0ba73d6 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -8,7 +8,7 @@ synthetize_spike_train_bad_isi, add_synchrony_to_sorting, generate_ground_truth_recording, - start_sorting_result, + create_sorting_analyzer, ) # from spikeinterface.extractors.toy_example import toy_example @@ -24,7 +24,7 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _sorting_result_simple(): +def _sorting_analyzer_simple(): recording, sorting = generate_ground_truth_recording( durations=[ 50.0, @@ -37,29 +37,29 @@ def _sorting_result_simple(): seed=2205, ) - sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - sorting_result.select_random_spikes(max_spikes_per_unit=300, seed=2205) - sorting_result.compute("noise_levels") - sorting_result.compute("waveforms", **job_kwargs) - sorting_result.compute("templates", operators=["average", "std", "median"]) - sorting_result.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_result.compute("spike_amplitudes", **job_kwargs) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates", operators=["average", "std", "median"]) + sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - return sorting_result + return sorting_analyzer @pytest.fixture(scope="module") -def sorting_result_simple(): - return _sorting_result_simple() +def sorting_analyzer_simple(): + return _sorting_analyzer_simple() -def test_calculate_pc_metrics(sorting_result_simple): - sorting_result = sorting_result_simple - res1 = calculate_pc_metrics(sorting_result, n_jobs=1, progress_bar=True) +def test_calculate_pc_metrics(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + res1 = calculate_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True) res1 = pd.DataFrame(res1) - res2 = calculate_pc_metrics(sorting_result, n_jobs=2, progress_bar=True) + res2 = calculate_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True) res2 = pd.DataFrame(res2) for k in res1.columns: @@ -68,20 +68,20 @@ def test_calculate_pc_metrics(sorting_result_simple): assert np.array_equal(res1[k].values[mask], res2[k].values[mask]) -def test_nearest_neighbors_isolation(sorting_result_simple): - sorting_result = sorting_result_simple - this_unit_id = sorting_result.unit_ids[0] - nearest_neighbors_isolation(sorting_result, this_unit_id) +def test_nearest_neighbors_isolation(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + this_unit_id = sorting_analyzer.unit_ids[0] + nearest_neighbors_isolation(sorting_analyzer, this_unit_id) -def test_nearest_neighbors_noise_overlap(sorting_result_simple): - sorting_result = sorting_result_simple - this_unit_id = sorting_result.unit_ids[0] - nearest_neighbors_noise_overlap(sorting_result, this_unit_id) +def test_nearest_neighbors_noise_overlap(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + this_unit_id = sorting_analyzer.unit_ids[0] + nearest_neighbors_noise_overlap(sorting_analyzer, this_unit_id) if __name__ == "__main__": - sorting_result = _sorting_result_simple() - test_calculate_pc_metrics(sorting_result) - test_nearest_neighbors_isolation(sorting_result) - test_nearest_neighbors_noise_overlap(sorting_result) + sorting_analyzer = _sorting_analyzer_simple() + test_calculate_pc_metrics(sorting_analyzer) + test_nearest_neighbors_isolation(sorting_analyzer) + test_nearest_neighbors_noise_overlap(sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4d94371de5..4f83bc8986 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -7,7 +7,7 @@ from spikeinterface.core import ( generate_ground_truth_recording, - start_sorting_result, + create_sorting_analyzer, NumpySorting, aggregate_units, ) @@ -27,7 +27,7 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def get_sorting_result(seed=2205): +def get_sorting_analyzer(seed=2205): # we need high firing rate for amplitude_cutoff recording, sorting = generate_ground_truth_recording( durations=[ @@ -51,30 +51,30 @@ def get_sorting_result(seed=2205): seed=seed, ) - sorting_result = start_sorting_result(sorting, recording, format="memory", sparse=True) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - sorting_result.select_random_spikes(max_spikes_per_unit=300, seed=seed) - sorting_result.compute("noise_levels") - sorting_result.compute("waveforms", **job_kwargs) - sorting_result.compute("templates") - sorting_result.compute("spike_amplitudes", **job_kwargs) + sorting_analyzer.select_random_spikes(max_spikes_per_unit=300, seed=seed) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - return sorting_result + return sorting_analyzer @pytest.fixture(scope="module") -def sorting_result_simple(): - sorting_result = get_sorting_result(seed=2205) - return sorting_result +def sorting_analyzer_simple(): + sorting_analyzer = get_sorting_analyzer(seed=2205) + return sorting_analyzer -def test_compute_quality_metrics(sorting_result_simple): - sorting_result = sorting_result_simple - print(sorting_result) +def test_compute_quality_metrics(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + print(sorting_analyzer) # without PCs metrics = compute_quality_metrics( - sorting_result, + sorting_analyzer, metric_names=["snr"], qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, @@ -82,15 +82,15 @@ def test_compute_quality_metrics(sorting_result_simple): ) # print(metrics) - qm = sorting_result.get_extension("quality_metrics") + qm = sorting_analyzer.get_extension("quality_metrics") assert qm.params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 assert "snr" in metrics.columns assert "isolation_distance" not in metrics.columns # with PCs - sorting_result.compute("principal_components") + sorting_analyzer.compute("principal_components") metrics = compute_quality_metrics( - sorting_result, + sorting_analyzer, metric_names=None, qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, @@ -100,11 +100,11 @@ def test_compute_quality_metrics(sorting_result_simple): assert "isolation_distance" in metrics.columns -def test_compute_quality_metrics_recordingless(sorting_result_simple): +def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): - sorting_result = sorting_result_simple + sorting_analyzer = sorting_analyzer_simple metrics = compute_quality_metrics( - sorting_result, + sorting_analyzer, metric_names=None, qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, @@ -112,15 +112,15 @@ def test_compute_quality_metrics_recordingless(sorting_result_simple): ) # make a copy and make it recordingless - sorting_result_norec = sorting_result.save_as(format="memory") - sorting_result_norec.delete_extension("quality_metrics") - sorting_result_norec._recording = None - assert not sorting_result_norec.has_recording() + sorting_analyzer_norec = sorting_analyzer.save_as(format="memory") + sorting_analyzer_norec.delete_extension("quality_metrics") + sorting_analyzer_norec._recording = None + assert not sorting_analyzer_norec.has_recording() - print(sorting_result_norec) + print(sorting_analyzer_norec) metrics_norec = compute_quality_metrics( - sorting_result_norec, + sorting_analyzer_norec, metric_names=None, qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, @@ -134,26 +134,26 @@ def test_compute_quality_metrics_recordingless(sorting_result_simple): assert np.allclose(metrics[metric_name].values, metrics_norec[metric_name].values, rtol=1e-02) -def test_empty_units(sorting_result_simple): - sorting_result = sorting_result_simple +def test_empty_units(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple empty_spike_train = np.array([], dtype="int64") empty_sorting = NumpySorting.from_unit_dict( {100: empty_spike_train, 200: empty_spike_train, 300: empty_spike_train}, - sampling_frequency=sorting_result.sampling_frequency, + sampling_frequency=sorting_analyzer.sampling_frequency, ) - sorting_empty = aggregate_units([sorting_result.sorting, empty_sorting]) + sorting_empty = aggregate_units([sorting_analyzer.sorting, empty_sorting]) assert len(sorting_empty.get_empty_unit_ids()) == 3 - sorting_result_empty = start_sorting_result(sorting_empty, sorting_result.recording, format="memory") - sorting_result_empty.select_random_spikes(max_spikes_per_unit=300, seed=2205) - sorting_result_empty.compute("noise_levels") - sorting_result_empty.compute("waveforms", **job_kwargs) - sorting_result_empty.compute("templates") - sorting_result_empty.compute("spike_amplitudes", **job_kwargs) + sorting_analyzer_empty = create_sorting_analyzer(sorting_empty, sorting_analyzer.recording, format="memory") + sorting_analyzer_empty.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_analyzer_empty.compute("noise_levels") + sorting_analyzer_empty.compute("waveforms", **job_kwargs) + sorting_analyzer_empty.compute("templates") + sorting_analyzer_empty.compute("spike_amplitudes", **job_kwargs) metrics_empty = compute_quality_metrics( - sorting_result_empty, + sorting_analyzer_empty, metric_names=None, qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, @@ -299,9 +299,9 @@ def test_empty_units(sorting_result_simple): if __name__ == "__main__": - sorting_result = get_sorting_result() - print(sorting_result) + sorting_analyzer = get_sorting_analyzer() + print(sorting_analyzer) - test_compute_quality_metrics(sorting_result) - test_compute_quality_metrics_recordingless(sorting_result) - test_empty_units(sorting_result) + test_compute_quality_metrics(sorting_analyzer) + test_compute_quality_metrics_recordingless(sorting_analyzer) + test_empty_units(sorting_analyzer) diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 6bb59c6c4e..b7c94c6238 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -2,7 +2,7 @@ import numpy as np from pathlib import Path -from spikeinterface import NumpySorting, start_sorting_result, get_noise_levels, compute_sparsity +from spikeinterface import NumpySorting, create_sorting_analyzer, get_noise_levels, compute_sparsity from spikeinterface.sortingcomponents.matching import find_spikes_from_templates, matching_methods @@ -13,34 +13,34 @@ job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) -def get_sorting_result(): +def get_sorting_analyzer(): recording, sorting = make_dataset() - sorting_result = start_sorting_result(sorting, recording, sparse=False) - sorting_result.select_random_spikes() - sorting_result.compute("fast_templates", **job_kwargs) - sorting_result.compute("noise_levels") - return sorting_result + sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("fast_templates", **job_kwargs) + sorting_analyzer.compute("noise_levels") + return sorting_analyzer -@pytest.fixture(name="sorting_result", scope="module") -def sorting_result_fixture(): - return get_sorting_result() +@pytest.fixture(name="sorting_analyzer", scope="module") +def sorting_analyzer_fixture(): + return get_sorting_analyzer() @pytest.mark.parametrize("method", matching_methods.keys()) -def test_find_spikes_from_templates(method, sorting_result): - recording = sorting_result.recording +def test_find_spikes_from_templates(method, sorting_analyzer): + recording = sorting_analyzer.recording # waveform = waveform_extractor.get_waveforms(waveform_extractor.unit_ids[0]) # num_waveforms, _, _ = waveform.shape # assert num_waveforms != 0 - templates = sorting_result.get_extension("fast_templates").get_data(outputs="Templates") - sparsity = compute_sparsity(sorting_result, method="snr", threshold=0.5) + templates = sorting_analyzer.get_extension("fast_templates").get_data(outputs="Templates") + sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=0.5) templates = templates.to_sparse(sparsity) - noise_levels = sorting_result.get_extension("noise_levels").get_data() + noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() - # sorting_result + # sorting_analyzer method_kwargs_all = {"templates": templates, "noise_levels": noise_levels} method_kwargs = {} # method_kwargs["wobble"] = { @@ -65,15 +65,15 @@ def test_find_spikes_from_templates(method, sorting_result): # import matplotlib.pyplot as plt # import spikeinterface.full as si - # sorting_result.compute("waveforms") - # sorting_result.compute("templates") + # sorting_analyzer.compute("waveforms") + # sorting_analyzer.compute("templates") - # gt_sorting = sorting_result.sorting + # gt_sorting = sorting_analyzer.sorting # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) - # metrics = si.compute_quality_metrics(sorting_result, metric_names=["snr"]) + # metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) # fig, ax = plt.subplots() # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) @@ -83,11 +83,11 @@ def test_find_spikes_from_templates(method, sorting_result): if __name__ == "__main__": - sorting_result = get_sorting_result() + sorting_analyzer = get_sorting_analyzer() # method = "naive" # method = "tdc-peeler" # method = "circus" # method = "circus-omp-svd" method = "wobble" - test_find_spikes_from_templates(method, sorting_result) + test_find_spikes_from_templates(method, sorting_analyzer) diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 4e640ea044..0226d706d1 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -14,7 +14,7 @@ from spikeinterface.postprocessing import compute_principal_components from spikeinterface.core import BaseRecording from spikeinterface.core.sparsity import ChannelSparsity -from spikeinterface import NumpySorting, start_sorting_result +from spikeinterface import NumpySorting, create_sorting_analyzer from spikeinterface.core.job_tools import _shared_job_kwargs_doc from .waveform_utils import to_temporal_representation, from_temporal_representation @@ -139,12 +139,12 @@ def fit( # Creates a numpy sorting object where the spike times are the peak times and the unit ids are the peak channel sorting = NumpySorting.from_peaks(peaks, recording.sampling_frequency, recording.channel_ids) - # TODO alessio, herberto : the fitting is done with a SortingResult which is a postprocessing object, I think we should not do this for a component - sorting_result = start_sorting_result(sorting, recording, sparse=True) - sorting_result.select_random_spikes() - sorting_result.compute("waveforms", ms_before=ms_before, ms_after=ms_after) - sorting_result.compute("principal_components", n_components=n_components, mode="by_channel_global", whiten=whiten) - pca_model = sorting_result.get_extension("principal_components").get_pca_model() + # TODO alessio, herberto : the fitting is done with a SortingAnalyzer which is a postprocessing object, I think we should not do this for a component + sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True) + sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("waveforms", ms_before=ms_before, ms_after=ms_after) + sorting_analyzer.compute("principal_components", n_components=n_components, mode="by_channel_global", whiten=whiten) + pca_model = sorting_analyzer.get_extension("principal_components").get_pca_model() params = { "ms_before": ms_before, diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index f865542018..5bd7b9679d 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_some_colors -from ..core import SortingResult +from ..core import SortingAnalyzer class AllAmplitudesDistributionsWidget(BaseWidget): @@ -15,33 +15,33 @@ class AllAmplitudesDistributionsWidget(BaseWidget): Parameters ---------- - sorting_result: SortingResult - The SortingResult + sorting_analyzer: SortingAnalyzer + The SortingAnalyzer unit_ids: list List of unit ids, default None unit_colors: None or dict Dict of colors with key: unit, value: color, default None """ - def __init__(self, sorting_result: SortingResult, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs): + def __init__(self, sorting_analyzer: SortingAnalyzer, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs): - sorting_result = self.ensure_sorting_result(sorting_result) - self.check_extensions(sorting_result, "spike_amplitudes") + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "spike_amplitudes") - amplitudes = sorting_result.get_extension("spike_amplitudes").get_data() + amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() - num_segments = sorting_result.get_num_segments() + num_segments = sorting_analyzer.get_num_segments() if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids if unit_colors is None: - unit_colors = get_some_colors(sorting_result.unit_ids) + unit_colors = get_some_colors(sorting_analyzer.unit_ids) amplitudes_by_units = {} - spikes = sorting_result.sorting.to_spike_vector() + spikes = sorting_analyzer.sorting.to_spike_vector() for unit_id in unit_ids: - unit_index = sorting_result.sorting.id_to_index(unit_id) + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) spike_mask = spikes["unit_index"] == unit_index amplitudes_by_units[unit_id] = amplitudes[spike_mask] diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 1867cae7da..efbf6f3f32 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_some_colors -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer class AmplitudesWidget(BaseWidget): @@ -15,7 +15,7 @@ class AmplitudesWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult + sorting_analyzer : SortingAnalyzer The input waveform extractor unit_ids : list or None, default: None List of unit ids @@ -38,7 +38,7 @@ class AmplitudesWidget(BaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, unit_ids=None, unit_colors=None, segment_index=None, @@ -51,12 +51,12 @@ def __init__( **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - sorting = sorting_result.sorting - self.check_extensions(sorting_result, "spike_amplitudes") + sorting = sorting_analyzer.sorting + self.check_extensions(sorting_analyzer, "spike_amplitudes") - amplitudes = sorting_result.get_extension("spike_amplitudes").get_data(outputs="by_unit") + amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit") if unit_ids is None: unit_ids = sorting.unit_ids @@ -71,7 +71,7 @@ def __init__( else: segment_index = 0 amplitudes_segment = amplitudes[segment_index] - total_duration = sorting_result.get_num_samples(segment_index) / sorting_result.sampling_frequency + total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency spiketrains_segment = {} for i, unit_id in enumerate(sorting.unit_ids): @@ -101,7 +101,7 @@ def __init__( bins = 100 plot_data = dict( - sorting_result=sorting_result, + sorting_analyzer=sorting_analyzer, amplitudes=amplitudes_to_plot, unit_ids=unit_ids, unit_colors=unit_colors, @@ -189,7 +189,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - we = data_plot["sorting_result"] + we = data_plot["sorting_analyzer"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index eb715b1c3a..3bfacfc370 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -5,7 +5,7 @@ global default_backend_ default_backend_ = "matplotlib" -from ..core import SortingResult, BaseSorting +from ..core import SortingAnalyzer, BaseSorting from ..core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor @@ -106,39 +106,39 @@ def do_plot(self): func(self.data_plot, **self.backend_kwargs) @classmethod - def ensure_sorting_result(cls, input): - # internal help to accept both SortingResult or MockWaveformExtractor for a ploter - if isinstance(input, SortingResult): + def ensure_sorting_analyzer(cls, input): + # internal help to accept both SortingAnalyzer or MockWaveformExtractor for a ploter + if isinstance(input, SortingAnalyzer): return input elif isinstance(input, MockWaveformExtractor): - return input.sorting_result + return input.sorting_analyzer else: return input @classmethod def ensure_sorting(cls, input): - # internal help to accept both Sorting or SortingResult or MockWaveformExtractor for a ploter + # internal help to accept both Sorting or SortingAnalyzer or MockWaveformExtractor for a ploter if isinstance(input, BaseSorting): return input - elif isinstance(input, SortingResult): + elif isinstance(input, SortingAnalyzer): return input.sorting elif isinstance(input, MockWaveformExtractor): - return input.sorting_result.sorting + return input.sorting_analyzer.sorting else: return input @staticmethod - def check_extensions(sorting_result, extensions): + def check_extensions(sorting_analyzer, extensions): if isinstance(extensions, str): extensions = [extensions] error_msg = "" raise_error = False for extension in extensions: - if not sorting_result.has_extension(extension): + if not sorting_analyzer.has_extension(extension): raise_error = True error_msg += ( f"The {extension} waveform extension is required for this widget. " - f"Run the `sorting_result.compute('{extension}', ...)` to compute it.\n" + f"Run the `sorting_analyzer.compute('{extension}', ...)` to compute it.\n" ) if raise_error: raise Exception(error_msg) diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 087a97df9e..6eb565d56a 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -4,7 +4,7 @@ from typing import Union from .base import BaseWidget, to_attr -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer from ..core.basesorting import BaseSorting from ..postprocessing import compute_correlograms @@ -15,7 +15,7 @@ class CrossCorrelogramsWidget(BaseWidget): Parameters ---------- - sorting_result_or_sorting : SortingResult or BaseSorting + sorting_analyzer_or_sorting : SortingAnalyzer or BaseSorting The object to compute/get crosscorrelograms from unit_ids list or None, default: None List of unit ids @@ -23,10 +23,10 @@ class CrossCorrelogramsWidget(BaseWidget): For sortingview backend. Threshold for computing pair-wise cross-correlograms. If template similarity between two units is below this threshold, the cross-correlogram is not displayed window_ms : float, default: 100.0 - Window for CCGs in ms. If correlograms are already computed (e.g. with SortingResult), + Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer), this argument is ignored bin_ms : float, default: 1.0 - Bin size in ms. If correlograms are already computed (e.g. with SortingResult), + Bin size in ms. If correlograms are already computed (e.g. with SortingAnalyzer), this argument is ignored hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed @@ -36,7 +36,7 @@ class CrossCorrelogramsWidget(BaseWidget): def __init__( self, - sorting_result_or_sorting: Union[SortingResult, BaseSorting], + sorting_analyzer_or_sorting: Union[SortingAnalyzer, BaseSorting], unit_ids=None, min_similarity_for_correlograms=0.2, window_ms=100.0, @@ -46,21 +46,21 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result_or_sorting = self.ensure_sorting_result(sorting_result_or_sorting) + sorting_analyzer_or_sorting = self.ensure_sorting_analyzer(sorting_analyzer_or_sorting) if min_similarity_for_correlograms is None: min_similarity_for_correlograms = 0 similarity = None - if isinstance(sorting_result_or_sorting, SortingResult): - sorting = sorting_result_or_sorting.sorting - self.check_extensions(sorting_result_or_sorting, "correlograms") - ccc = sorting_result_or_sorting.get_extension("correlograms") + if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): + sorting = sorting_analyzer_or_sorting.sorting + self.check_extensions(sorting_analyzer_or_sorting, "correlograms") + ccc = sorting_analyzer_or_sorting.get_extension("correlograms") ccgs, bins = ccc.get_data() if min_similarity_for_correlograms > 0: - self.check_extensions(sorting_result_or_sorting, "template_similarity") - similarity = sorting_result_or_sorting.get_extension("template_similarity").get_data() + self.check_extensions(sorting_analyzer_or_sorting, "template_similarity") + similarity = sorting_analyzer_or_sorting.get_extension("template_similarity").get_data() else: - sorting = sorting_result_or_sorting + sorting = sorting_analyzer_or_sorting ccgs, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms) if unit_ids is None: diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 5a9b77dfcd..3f9ee549be 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -1,7 +1,7 @@ from __future__ import annotations from .metrics import MetricsBaseWidget -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer class QualityMetricsWidget(MetricsBaseWidget): @@ -10,7 +10,7 @@ class QualityMetricsWidget(MetricsBaseWidget): Parameters ---------- - sorting_result : SortingResult + sorting_analyzer : SortingAnalyzer The object to get quality metrics from unit_ids: list or None, default: None List of unit ids @@ -26,7 +26,7 @@ class QualityMetricsWidget(MetricsBaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, unit_ids=None, include_metrics=None, skip_metrics=None, @@ -35,11 +35,11 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) - self.check_extensions(sorting_result, "quality_metrics") - quality_metrics = sorting_result.get_extension("quality_metrics").get_data() + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "quality_metrics") + quality_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting MetricsBaseWidget.__init__( self, diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 1c01e99a69..78293757ec 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -11,7 +11,7 @@ from .unit_templates import UnitTemplatesWidget -from ..core import SortingResult +from ..core import SortingAnalyzer class SortingSummaryWidget(BaseWidget): @@ -20,13 +20,13 @@ class SortingSummaryWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult - The SortingResult object + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object unit_ids : list or None, default: None List of unit ids sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If SortingResult is already sparse, the argument is ignored + If SortingAnalyzer is already sparse, the argument is ignored max_amplitudes_per_unit : int or None, default: None Maximum number of spikes per unit for plotting amplitudes. If None, all spikes are plotted @@ -47,7 +47,7 @@ class SortingSummaryWidget(BaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, unit_ids=None, sparsity=None, max_amplitudes_per_unit=None, @@ -58,15 +58,15 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) - self.check_extensions(sorting_result, ["correlograms", "spike_amplitudes", "unit_locations", "similarity"]) - sorting = sorting_result.sorting + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "similarity"]) + sorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.get_unit_ids() plot_data = dict( - sorting_result=sorting_result, + sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, sparsity=sparsity, min_similarity_for_correlograms=min_similarity_for_correlograms, @@ -83,7 +83,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) - sorting_result = dp.sorting_result + sorting_analyzer = dp.sorting_analyzer unit_ids = dp.unit_ids sparsity = dp.sparsity min_similarity_for_correlograms = dp.min_similarity_for_correlograms @@ -91,7 +91,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = make_serializable(dp.unit_ids) v_spike_amplitudes = AmplitudesWidget( - sorting_result, + sorting_analyzer, unit_ids=unit_ids, max_spikes_per_unit=dp.max_amplitudes_per_unit, hide_unit_selector=True, @@ -100,7 +100,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): backend="sortingview", ).view v_average_waveforms = UnitTemplatesWidget( - sorting_result, + sorting_analyzer, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True, @@ -109,7 +109,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): backend="sortingview", ).view v_cross_correlograms = CrossCorrelogramsWidget( - sorting_result, + sorting_analyzer, unit_ids=unit_ids, min_similarity_for_correlograms=min_similarity_for_correlograms, hide_unit_selector=True, @@ -119,7 +119,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ).view v_unit_locations = UnitLocationsWidget( - sorting_result, + sorting_analyzer, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, @@ -128,7 +128,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ).view w = TemplateSimilarityWidget( - sorting_result, + sorting_analyzer, unit_ids=unit_ids, immediate_plot=False, generate_url=False, @@ -147,7 +147,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # unit ids v_units_table = generate_unit_table_view( - dp.sorting_result.sorting, dp.unit_table_properties, similarity_scores=similarity_scores + dp.sorting_analyzer.sorting, dp.unit_table_properties, similarity_scores=similarity_scores ) if dp.curation: diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index b1791f0912..94c9def630 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -4,7 +4,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer class SpikeLocationsWidget(BaseWidget): @@ -13,7 +13,7 @@ class SpikeLocationsWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult + sorting_analyzer : SortingAnalyzer The object to get spike locations from unit_ids : list or None, default: None List of unit ids @@ -40,7 +40,7 @@ class SpikeLocationsWidget(BaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, unit_ids=None, segment_index=None, max_spikes_per_unit=500, @@ -53,16 +53,16 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) - self.check_extensions(sorting_result, "spike_locations") + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "spike_locations") - spike_locations_by_units = sorting_result.get_extension("spike_locations").get_data(outputs="by_unit") + spike_locations_by_units = sorting_analyzer.get_extension("spike_locations").get_data(outputs="by_unit") - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting - channel_ids = sorting_result.channel_ids - channel_locations = sorting_result.get_channel_locations() - probegroup = sorting_result.get_probegroup() + channel_ids = sorting_analyzer.channel_ids + channel_locations = sorting_analyzer.get_channel_locations() + probegroup = sorting_analyzer.get_probegroup() if sorting.get_num_segments() > 1: assert segment_index is not None, "Specify segment index for multi-segment object" diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 7515bc5d64..d354a82086 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -7,7 +7,7 @@ from .traces import TracesWidget from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer from ..core.baserecording import BaseRecording from ..core.basesorting import BaseSorting from ..postprocessing import compute_unit_locations @@ -19,8 +19,8 @@ class SpikesOnTracesWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult - The SortingResult + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer channel_ids : list or None, default: None The channel ids to display unit_ids : list or None, default: None @@ -31,7 +31,7 @@ class SpikesOnTracesWidget(BaseWidget): List with start time and end time in seconds sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If SortingResult is already sparse, the argument is ignored + If SortingAnalyzer is already sparse, the argument is ignored unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values If None, then the get_unit_colors() is internally used. (matplotlib backend) @@ -62,7 +62,7 @@ class SpikesOnTracesWidget(BaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, segment_index=None, channel_ids=None, unit_ids=None, @@ -83,10 +83,10 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) - self.check_extensions(sorting_result, "unit_locations") + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "unit_locations") - sorting: BaseSorting = sorting_result.sorting + sorting: BaseSorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.get_unit_ids() @@ -96,22 +96,22 @@ def __init__( unit_colors = get_unit_colors(sorting) # sparsity is done on all the units even if unit_ids is a few ones because some backend need then all - if sorting_result.is_sparse(): - sparsity = sorting_result.sparsity + if sorting_analyzer.is_sparse(): + sparsity = sorting_analyzer.sparsity else: if sparsity is None: # in this case, we construct a sparsity dictionary only with the best channel - extremum_channel_ids = get_template_extremum_channel(sorting_result) + extremum_channel_ids = get_template_extremum_channel(sorting_analyzer) unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, - unit_ids=sorting_result.unit_ids, - channel_ids=sorting_result.channel_ids, + unit_ids=sorting_analyzer.unit_ids, + channel_ids=sorting_analyzer.channel_ids, ) else: assert isinstance(sparsity, ChannelSparsity) - unit_locations = sorting_result.get_extension("unit_locations").get_data(outputs="by_unit") + unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") options = dict( segment_index=segment_index, @@ -130,7 +130,7 @@ def __init__( ) plot_data = dict( - sorting_result=sorting_result, + sorting_analyzer=sorting_analyzer, options=options, unit_ids=unit_ids, sparsity=sparsity, @@ -148,9 +148,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.lines import Line2D dp = to_attr(data_plot) - sorting_result = dp.sorting_result - recording = sorting_result.recording - sorting = sorting_result.sorting + sorting_analyzer = dp.sorting_analyzer + recording = sorting_analyzer.recording + sorting = sorting_analyzer.sorting # first plot time series traces_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) @@ -248,7 +248,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() dp = to_attr(data_plot) - sorting_result = dp.sorting_result + sorting_analyzer = dp.sorting_analyzer ratios = [0.2, 0.8] @@ -260,7 +260,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # plot timeseries self._traces_widget = TracesWidget( - sorting_result.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts + sorting_analyzer.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts ) self.ax = self._traces_widget.ax self.axes = self._traces_widget.axes diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index c3b7d7f3e8..b80c863e75 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -1,7 +1,7 @@ from __future__ import annotations from .metrics import MetricsBaseWidget -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer class TemplateMetricsWidget(MetricsBaseWidget): @@ -10,7 +10,7 @@ class TemplateMetricsWidget(MetricsBaseWidget): Parameters ---------- - sorting_result : SortingResult + sorting_analyzer : SortingAnalyzer The object to get quality metrics from unit_ids : list or None, default: None List of unit ids @@ -26,7 +26,7 @@ class TemplateMetricsWidget(MetricsBaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, unit_ids=None, include_metrics=None, skip_metrics=None, @@ -35,11 +35,11 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) - self.check_extensions(sorting_result, "template_metrics") - template_metrics = sorting_result.get_extension("template_metrics").get_data() + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "template_metrics") + template_metrics = sorting_analyzer.get_extension("template_metrics").get_data() - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting MetricsBaseWidget.__init__( self, diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 39883094cf..b469d9901f 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -3,7 +3,7 @@ import numpy as np from .base import BaseWidget, to_attr -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer class TemplateSimilarityWidget(BaseWidget): @@ -12,7 +12,7 @@ class TemplateSimilarityWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult + sorting_analyzer : SortingAnalyzer The object to get template similarity from unit_ids : list or None, default: None List of unit ids default: None @@ -29,7 +29,7 @@ class TemplateSimilarityWidget(BaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, unit_ids=None, cmap="viridis", display_diagonal_values=False, @@ -38,13 +38,13 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) - self.check_extensions(sorting_result, "template_similarity") + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "template_similarity") - tsc = sorting_result.get_extension("template_similarity") + tsc = sorting_analyzer.get_extension("template_similarity") similarity = tsc.get_data().copy() - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids else: diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 1cd1ba477f..54ca074f16 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -14,7 +14,7 @@ from spikeinterface import ( compute_sparsity, generate_ground_truth_recording, - start_sorting_result, + create_sorting_analyzer, ) @@ -76,24 +76,24 @@ def setUpClass(cls): job_kwargs = dict(n_jobs=-1) # create dense - cls.sorting_result_dense = start_sorting_result(cls.sorting, cls.recording, format="memory", sparse=False) - cls.sorting_result_dense.select_random_spikes() - cls.sorting_result_dense.compute(extensions_to_compute, **job_kwargs) + cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) + cls.sorting_analyzer_dense.select_random_spikes() + cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) sw.set_default_plotter_backend("matplotlib") # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.sorting_result_dense, method="radius", radius_um=50) - cls.sparsity_strict = compute_sparsity(cls.sorting_result_dense, method="radius", radius_um=20) - cls.sparsity_large = compute_sparsity(cls.sorting_result_dense, method="radius", radius_um=80) - cls.sparsity_best = compute_sparsity(cls.sorting_result_dense, method="best_channels", num_channels=5) + cls.sparsity_radius = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=50) + cls.sparsity_strict = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=20) + cls.sparsity_large = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=80) + cls.sparsity_best = compute_sparsity(cls.sorting_analyzer_dense, method="best_channels", num_channels=5) # create sparse - cls.sorting_result_sparse = start_sorting_result( + cls.sorting_analyzer_sparse = create_sorting_analyzer( cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius ) - cls.sorting_result_sparse.select_random_spikes() - cls.sorting_result_sparse.compute(extensions_to_compute, **job_kwargs) + cls.sorting_analyzer_sparse.select_random_spikes() + cls.sorting_analyzer_sparse.compute(extensions_to_compute, **job_kwargs) cls.skip_backends = ["ipywidgets", "ephyviewer"] # cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] @@ -152,34 +152,34 @@ def test_plot_spikes_on_traces(self): possible_backends = list(sw.SpikesOnTracesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spikes_on_traces(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_spikes_on_traces(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_waveforms(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_waveforms(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_waveforms(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_waveforms( - self.sorting_result_dense, + self.sorting_analyzer_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.sorting_result_dense, + self.sorting_analyzer_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) # extra sparsity sw.plot_unit_waveforms( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, unit_ids=unit_ids, backend=backend, @@ -188,7 +188,7 @@ def test_plot_unit_waveforms(self): # test "larger" sparsity with self.assertRaises(AssertionError): sw.plot_unit_waveforms( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, sparsity=self.sparsity_large, unit_ids=unit_ids, backend=backend, @@ -201,11 +201,11 @@ def test_plot_unit_templates(self): if backend not in self.skip_backends: print(f"Testing backend {backend}") print("Dense") - sw.plot_unit_templates(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_templates(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] print("Dense + radius") sw.plot_unit_templates( - self.sorting_result_dense, + self.sorting_analyzer_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, @@ -213,7 +213,7 @@ def test_plot_unit_templates(self): ) print("Dense + best") sw.plot_unit_templates( - self.sorting_result_dense, + self.sorting_analyzer_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, @@ -222,7 +222,7 @@ def test_plot_unit_templates(self): # test different shadings print("Sparse") sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, unit_ids=unit_ids, templates_percentile_shading=None, backend=backend, @@ -230,7 +230,7 @@ def test_plot_unit_templates(self): ) print("Sparse2") sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, unit_ids=unit_ids, # templates_percentile_shading=None, scale=10, @@ -240,7 +240,7 @@ def test_plot_unit_templates(self): # test different shadings print("Sparse3") sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, unit_ids=unit_ids, backend=backend, templates_percentile_shading=None, @@ -249,7 +249,7 @@ def test_plot_unit_templates(self): ) print("Sparse4") sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, unit_ids=unit_ids, templates_percentile_shading=0.1, backend=backend, @@ -257,7 +257,7 @@ def test_plot_unit_templates(self): ) print("Extra sparsity") sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, unit_ids=unit_ids, templates_percentile_shading=[1, 10, 90, 99], @@ -267,7 +267,7 @@ def test_plot_unit_templates(self): # test "larger" sparsity with self.assertRaises(AssertionError): sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, sparsity=self.sparsity_large, unit_ids=unit_ids, backend=backend, @@ -275,7 +275,7 @@ def test_plot_unit_templates(self): ) if backend != "sortingview": sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, unit_ids=unit_ids, templates_percentile_shading=[1, 5, 25, 75, 95, 99], backend=backend, @@ -285,7 +285,7 @@ def test_plot_unit_templates(self): # sortingview doesn't support more than 2 shadings with self.assertRaises(AssertionError): sw.plot_unit_templates( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, unit_ids=unit_ids, templates_percentile_shading=[1, 5, 25, 75, 95, 99], backend=backend, @@ -300,16 +300,16 @@ def test_plot_unit_waveforms_density_map(self): # on dense sw.plot_unit_waveforms_density_map( - self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) # on sparse sw.plot_unit_waveforms_density_map( - self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) # externals parsity sw.plot_unit_waveforms_density_map( - self.sorting_result_dense, + self.sorting_analyzer_dense, sparsity=self.sparsity_radius, same_axis=False, unit_ids=unit_ids, @@ -319,7 +319,7 @@ def test_plot_unit_waveforms_density_map(self): # on sparse with same_axis sw.plot_unit_waveforms_density_map( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, sparsity=None, same_axis=True, unit_ids=unit_ids, @@ -362,12 +362,12 @@ def test_plot_crosscorrelograms(self): **self.backend_kwargs[backend], ) sw.plot_crosscorrelograms( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend], ) sw.plot_crosscorrelograms( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, min_similarity_for_correlograms=0.6, backend=backend, **self.backend_kwargs[backend], @@ -391,20 +391,20 @@ def test_plot_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_amplitudes(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.sorting_result_dense.unit_ids[:4] + sw.plot_amplitudes(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) + unit_ids = self.sorting_analyzer_dense.unit_ids[:4] sw.plot_amplitudes( - self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_amplitudes( - self.sorting_result_dense, + self.sorting_analyzer_dense, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend], ) sw.plot_amplitudes( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, unit_ids=unit_ids, plot_histograms=True, backend=backend, @@ -415,12 +415,12 @@ def test_plot_all_amplitudes_distributions(self): possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - unit_ids = self.sorting_result_dense.unit_ids[:4] + unit_ids = self.sorting_analyzer_dense.unit_ids[:4] sw.plot_all_amplitudes_distributions( - self.sorting_result_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_all_amplitudes_distributions( - self.sorting_result_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_locations(self): @@ -428,10 +428,10 @@ def test_plot_unit_locations(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_locations( - self.sorting_result_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_locations( - self.sorting_result_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) def test_plot_spike_locations(self): @@ -439,53 +439,53 @@ def test_plot_spike_locations(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_spike_locations( - self.sorting_result_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) sw.plot_spike_locations( - self.sorting_result_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + self.sorting_analyzer_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) def test_plot_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_template_similarity(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_quality_metrics(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_quality_metrics(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_metrics(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_template_metrics(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_depths(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_unit_depths(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): possible_backends = list(sw.UnitSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.sorting_result_dense, - self.sorting_result_dense.sorting.unit_ids[0], + self.sorting_analyzer_dense, + self.sorting_analyzer_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_summary( - self.sorting_result_sparse, - self.sorting_result_sparse.sorting.unit_ids[0], + self.sorting_analyzer_sparse, + self.sorting_analyzer_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend], ) @@ -494,10 +494,10 @@ def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.sorting_result_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_sorting_summary(self.sorting_result_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary( - self.sorting_result_sparse, + self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, backend=backend, **self.backend_kwargs[backend], @@ -531,7 +531,7 @@ def test_plot_unit_probe_map(self): possible_backends = list(sw.UnitProbeMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_probe_map(self.sorting_result_dense) + sw.plot_unit_probe_map(self.sorting_analyzer_dense) def test_plot_unit_presence(self): possible_backends = list(sw.UnitPresenceWidget.get_possible_backends()) diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index b99bb2a274..982b34f0c5 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -16,8 +16,8 @@ class UnitDepthsWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult - The SortingResult object + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values depth_axis : int, default: 1 @@ -26,26 +26,26 @@ class UnitDepthsWidget(BaseWidget): Sign of peak for amplitudes """ - def __init__(self, sorting_result, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs): + def __init__(self, sorting_analyzer, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs): - sorting_result = self.ensure_sorting_result(sorting_result) + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - unit_ids = sorting_result.sorting.unit_ids + unit_ids = sorting_analyzer.sorting.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting_result.sorting) + unit_colors = get_unit_colors(sorting_analyzer.sorting) colors = [unit_colors[unit_id] for unit_id in unit_ids] - self.check_extensions(sorting_result, "unit_locations") - ulc = sorting_result.get_extension("unit_locations") + self.check_extensions(sorting_analyzer, "unit_locations") + ulc = sorting_analyzer.get_extension("unit_locations") unit_locations = ulc.get_data(outputs="numpy") unit_depths = unit_locations[:, depth_axis] - unit_amplitudes = get_template_extremum_amplitude(sorting_result, peak_sign=peak_sign) + unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign) unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids]) - num_spikes = sorting_result.sorting.count_num_spikes_per_unit(outputs="array") + num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="array") plot_data = dict( unit_depths=unit_depths, diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index ec5660fdcc..3329c2183c 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -7,7 +7,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer class UnitLocationsWidget(BaseWidget): @@ -16,8 +16,8 @@ class UnitLocationsWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult - The SortingResult that must contains "unit_locations" extension + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer that must contains "unit_locations" extension unit_ids : list or None, default: None List of unit ids with_channel_ids : bool, default: False @@ -37,7 +37,7 @@ class UnitLocationsWidget(BaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, unit_ids=None, with_channel_ids=False, unit_colors=None, @@ -48,17 +48,17 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - self.check_extensions(sorting_result, "unit_locations") - ulc = sorting_result.get_extension("unit_locations") + self.check_extensions(sorting_analyzer, "unit_locations") + ulc = sorting_analyzer.get_extension("unit_locations") unit_locations = ulc.get_data(outputs="by_unit") - sorting = sorting_result.sorting + sorting = sorting_analyzer.sorting - channel_ids = sorting_result.channel_ids - channel_locations = sorting_result.get_channel_locations() - probegroup = sorting_result.get_probegroup() + channel_ids = sorting_analyzer.channel_ids + channel_locations = sorting_analyzer.get_channel_locations() + probegroup = sorting_analyzer.get_probegroup() if unit_colors is None: unit_colors = get_unit_colors(sorting) diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 895ef6709c..034a0bda49 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -8,7 +8,7 @@ from .base import BaseWidget, to_attr # from .utils import get_unit_colors -from ..core.sortingresult import SortingResult +from ..core.sortinganalyzer import SortingAnalyzer from ..core.template_tools import _get_dense_templates_array @@ -20,7 +20,7 @@ class UnitProbeMapWidget(BaseWidget): Parameters ---------- - sorting_result: SortingResult + sorting_analyzer: SortingAnalyzer unit_ids: list List of unit ids. channel_ids: list @@ -33,7 +33,7 @@ class UnitProbeMapWidget(BaseWidget): def __init__( self, - sorting_result, + sorting_analyzer, unit_ids=None, channel_ids=None, animated=None, @@ -42,17 +42,17 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids self.unit_ids = unit_ids if channel_ids is None: - channel_ids = sorting_result.channel_ids + channel_ids = sorting_analyzer.channel_ids self.channel_ids = channel_ids data_plot = dict( - sorting_result=sorting_result, + sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, channel_ids=channel_ids, animated=animated, @@ -76,13 +76,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - sorting_result = dp.sorting_result - probe = sorting_result.get_probe() + sorting_analyzer = dp.sorting_analyzer + probe = sorting_analyzer.get_probe() probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - templates = _get_dense_templates_array(sorting_result, return_scaled=True) - templates = templates[sorting_result.sorting.ids_to_indices(dp.unit_ids), :, :] + templates = _get_dense_templates_array(sorting_analyzer, return_scaled=True) + templates = templates[sorting_analyzer.sorting.ids_to_indices(dp.unit_ids), :, :] all_poly_contact = [] for i, unit_id in enumerate(dp.unit_ids): diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 09d0dfa2c9..ea6476784e 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -21,22 +21,22 @@ class UnitSummaryWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult - The SortingResult object + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object unit_id : int or str The unit id to plot the summary of unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values, sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. - If SortingResult is already sparse, the argument is ignored + If SortingAnalyzer is already sparse, the argument is ignored """ # possible_backends = {} def __init__( self, - sorting_result, + sorting_analyzer, unit_id, unit_colors=None, sparsity=None, @@ -45,13 +45,13 @@ def __init__( **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_colors is None: - unit_colors = get_unit_colors(sorting_result.sorting) + unit_colors = get_unit_colors(sorting_analyzer.sorting) plot_data = dict( - sorting_result=sorting_result, + sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, @@ -66,7 +66,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) unit_id = dp.unit_id - sorting_result = dp.sorting_result + sorting_analyzer = dp.sorting_analyzer unit_colors = dp.unit_colors sparsity = dp.sparsity @@ -83,17 +83,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 3 - if sorting_result.has_extension("correlograms") or sorting_result.has_extension("spike_amplitudes"): + if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): ncols += 1 - if sorting_result.has_extension("spike_amplitudes"): + if sorting_analyzer.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - if sorting_result.has_extension("unit_locations"): + if sorting_analyzer.has_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( - sorting_result, + sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, @@ -101,7 +101,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax=ax1, ) - unit_locations = sorting_result.get_extension("unit_locations").get_data(outputs="by_unit") + unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] ax1.set_xlim(x - 80, x + 80) @@ -112,7 +112,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax2 = fig.add_subplot(gs[:2, 1]) w = UnitWaveformsWidget( - sorting_result, + sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_templates=True, @@ -127,7 +127,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax3 = fig.add_subplot(gs[:2, 2]) UnitWaveformDensityMapWidget( - sorting_result, + sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, @@ -137,10 +137,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - if sorting_result.has_extension("correlograms"): + if sorting_analyzer.has_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) AutoCorrelogramsWidget( - sorting_result, + sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, backend="matplotlib", @@ -150,12 +150,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - if sorting_result.has_extension("spike_amplitudes"): + if sorting_analyzer.has_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6]) AmplitudesWidget( - sorting_result, + sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 65ce40edf0..db3ddbf6b9 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -5,7 +5,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core import ChannelSparsity, SortingResult +from ..core import ChannelSparsity, SortingAnalyzer from ..core.basesorting import BaseSorting from ..core.template_tools import _get_dense_templates_array @@ -16,8 +16,8 @@ class UnitWaveformsWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult - The SortingResult + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer channel_ids: list or None, default: None The channel ids to display unit_ids : list or None, default: None @@ -26,7 +26,7 @@ class UnitWaveformsWidget(BaseWidget): If True, templates are plotted over the waveforms sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If SortingResult is already sparse, the argument is ignored + If SortingAnalyzer is already sparse, the argument is ignored set_title : bool, default: True Create a plot title with the unit number if True plot_channels : bool, default: False @@ -77,7 +77,7 @@ class UnitWaveformsWidget(BaseWidget): def __init__( self, - sorting_result: SortingResult, + sorting_analyzer: SortingAnalyzer, channel_ids=None, unit_ids=None, plot_waveforms=True, @@ -105,26 +105,26 @@ def __init__( **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) - sorting: BaseSorting = sorting_result.sorting + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + sorting: BaseSorting = sorting_analyzer.sorting if unit_ids is None: unit_ids = sorting.unit_ids if channel_ids is None: - channel_ids = sorting_result.channel_ids + channel_ids = sorting_analyzer.channel_ids if unit_colors is None: unit_colors = get_unit_colors(sorting) - channel_locations = sorting_result.get_channel_locations()[sorting_result.channel_ids_to_indices(channel_ids)] + channel_locations = sorting_analyzer.get_channel_locations()[sorting_analyzer.channel_ids_to_indices(channel_ids)] extra_sparsity = False - if sorting_result.is_sparse(): + if sorting_analyzer.is_sparse(): if sparsity is None: - sparsity = sorting_result.sparsity + sparsity = sorting_analyzer.sparsity else: # assert provided sparsity is a subset of waveform sparsity - combined_mask = np.logical_or(sorting_result.sparsity.mask, sparsity.mask) - assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_result.sparsity.mask, 1) == 0), ( + combined_mask = np.logical_or(sorting_analyzer.sparsity.mask, sparsity.mask) + assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer.sparsity.mask, 1) == 0), ( "The provided 'sparsity' needs to include only the sparse channels " "used to extract waveforms (for example, by using a smaller 'radius_um')." ) @@ -132,34 +132,34 @@ def __init__( else: if sparsity is None: # in this case, we construct a dense sparsity - unit_id_to_channel_ids = {u: sorting_result.channel_ids for u in sorting_result.unit_ids} + unit_id_to_channel_ids = {u: sorting_analyzer.channel_ids for u in sorting_analyzer.unit_ids} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, - unit_ids=sorting_result.unit_ids, - channel_ids=sorting_result.channel_ids, + unit_ids=sorting_analyzer.unit_ids, + channel_ids=sorting_analyzer.channel_ids, ) else: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" # get templates - ext = sorting_result.get_extension("templates") + ext = sorting_analyzer.get_extension("templates") assert ext is not None, "plot_waveforms() need extension 'templates'" templates = ext.get_templates(unit_ids=unit_ids, operator="average") - templates_shading = self._get_template_shadings(sorting_result, unit_ids, templates_percentile_shading) + templates_shading = self._get_template_shadings(sorting_analyzer, unit_ids, templates_percentile_shading) xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( - sorting_result, templates, channel_locations, x_offset_units + sorting_analyzer, templates, channel_locations, x_offset_units ) wfs_by_ids = {} if plot_waveforms: - wf_ext = sorting_result.get_extension("waveforms") + wf_ext = sorting_analyzer.get_extension("waveforms") assert wf_ext is not None, "plot_waveforms() need extension 'waveforms'" for unit_id in unit_ids: unit_index = list(sorting.unit_ids).index(unit_id) if not extra_sparsity: - if sorting_result.is_sparse(): + if sorting_analyzer.is_sparse(): # wfs = we.get_waveforms(unit_id) wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) else: @@ -172,7 +172,7 @@ def __init__( # wfs = we.get_waveforms(unit_id) wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) # find additional slice to apply to sparse waveforms - (wfs_sparse_indices,) = np.nonzero(sorting_result.sparsity.mask[unit_index]) + (wfs_sparse_indices,) = np.nonzero(sorting_analyzer.sparsity.mask[unit_index]) (extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index]) (extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices)) # apply extra sparsity @@ -180,8 +180,8 @@ def __init__( wfs_by_ids[unit_id] = wfs plot_data = dict( - sorting_result=sorting_result, - sampling_frequency=sorting_result.sampling_frequency, + sorting_analyzer=sorting_analyzer, + sampling_frequency=sorting_analyzer.sampling_frequency, unit_ids=unit_ids, channel_ids=channel_ids, sparsity=sparsity, @@ -346,7 +346,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - self.sorting_result = data_plot["sorting_result"] + self.sorting_analyzer = data_plot["sorting_analyzer"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -415,8 +415,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) - def _get_template_shadings(self, sorting_result, unit_ids, templates_percentile_shading): - ext = sorting_result.get_extension("templates") + def _get_template_shadings(self, sorting_analyzer, unit_ids, templates_percentile_shading): + ext = sorting_analyzer.get_extension("templates") templates = ext.get_templates(unit_ids=unit_ids, operator="average") if templates_percentile_shading is None: @@ -450,8 +450,8 @@ def _update_plot(self, change): hide_axis = self.hide_axis_button.value do_shading = self.template_shading_button.value - wf_ext = self.sorting_result.get_extension("waveforms") - templates_ext = self.sorting_result.get_extension("templates") + wf_ext = self.sorting_analyzer.get_extension("waveforms") + templates_ext = self.sorting_analyzer.get_extension("templates") templates = templates_ext.get_templates(unit_ids=unit_ids, operator="average") # matplotlib next_data_plot dict update at each call @@ -459,7 +459,7 @@ def _update_plot(self, change): data_plot["unit_ids"] = unit_ids data_plot["templates"] = templates templates_shadings = self._get_template_shadings( - self.sorting_result, unit_ids, data_plot["templates_percentile_shading"] + self.sorting_analyzer, unit_ids, data_plot["templates_percentile_shading"] ) data_plot["templates_shading"] = templates_shadings data_plot["same_axis"] = same_axis @@ -493,7 +493,7 @@ def _update_plot(self, change): ax.axis("off") # update probe plot - channel_locations = self.sorting_result.get_channel_locations() + channel_locations = self.sorting_analyzer.get_channel_locations() self.ax_probe.plot( channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 ) @@ -520,7 +520,7 @@ def _update_plot(self, change): fig_probe.canvas.flush_events() -def get_waveforms_scales(sorting_result, templates, channel_locations, x_offset_units=False): +def get_waveforms_scales(sorting_analyzer, templates, channel_locations, x_offset_units=False): """ Return scales and x_vector for templates plotting """ @@ -546,7 +546,7 @@ def get_waveforms_scales(sorting_result, templates, channel_locations, x_offset_ y_offset = channel_locations[:, 1][None, :] - nbefore = sorting_result.get_extension("waveforms").nbefore + nbefore = sorting_analyzer.get_extension("waveforms").nbefore nsamples = templates.shape[1] xvect = delta_x * (np.arange(nsamples) - nbefore) / nsamples * 0.7 diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index ce0053e9af..41c77f59fa 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -14,15 +14,15 @@ class UnitWaveformDensityMapWidget(BaseWidget): Parameters ---------- - sorting_result : SortingResult - The SortingResult for calculating waveforms + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer for calculating waveforms channel_ids : list or None, default: None The channel ids to display unit_ids : list or None, default: None List of unit ids sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply - If SortingResult is already sparse, the argument is ignored + If SortingAnalyzer is already sparse, the argument is ignored use_max_channel : bool, default: False Use only the max channel peak_sign : "neg" | "pos" | "both", default: "neg" @@ -37,7 +37,7 @@ class UnitWaveformDensityMapWidget(BaseWidget): def __init__( self, - sorting_result, + sorting_analyzer, channel_ids=None, unit_ids=None, sparsity=None, @@ -48,39 +48,39 @@ def __init__( backend=None, **backend_kwargs, ): - sorting_result = self.ensure_sorting_result(sorting_result) + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if channel_ids is None: - channel_ids = sorting_result.channel_ids + channel_ids = sorting_analyzer.channel_ids if unit_ids is None: - unit_ids = sorting_result.unit_ids + unit_ids = sorting_analyzer.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting_result.sorting) + unit_colors = get_unit_colors(sorting_analyzer.sorting) if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" max_channels = get_template_extremum_channel( - sorting_result, mode="extremum", peak_sign=peak_sign, outputs="index" + sorting_analyzer, mode="extremum", peak_sign=peak_sign, outputs="index" ) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all - if sorting_result.is_sparse(): - assert sparsity is None, "UnitWaveformDensity SortingResult is already sparse" - used_sparsity = sorting_result.sparsity + if sorting_analyzer.is_sparse(): + assert sparsity is None, "UnitWaveformDensity SortingAnalyzer is already sparse" + used_sparsity = sorting_analyzer.sparsity elif sparsity is not None: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" used_sparsity = sparsity else: # in this case, we construct a dense sparsity - used_sparsity = ChannelSparsity.create_dense(sorting_result) + used_sparsity = ChannelSparsity.create_dense(sorting_analyzer) channel_inds = used_sparsity.unit_id_to_channel_indices # bins # templates = we.get_all_templates(unit_ids=unit_ids) - templates = sorting_result.get_extension("templates").get_templates(unit_ids=unit_ids) + templates = sorting_analyzer.get_extension("templates").get_templates(unit_ids=unit_ids) bin_min = np.min(templates) * 1.3 bin_max = np.max(templates) * 1.3 bin_size = (bin_max - bin_min) / 100 @@ -90,14 +90,14 @@ def __init__( if same_axis: all_hist2d = None # channel union across units - unit_inds = sorting_result.sorting.ids_to_indices(unit_ids) + unit_inds = sorting_analyzer.sorting.ids_to_indices(unit_ids) (shared_chan_inds,) = np.nonzero(np.sum(used_sparsity.mask[unit_inds, :], axis=0)) else: all_hist2d = {} - wf_ext = sorting_result.get_extension("waveforms") + wf_ext = sorting_analyzer.get_extension("waveforms") for i, unit_id in enumerate(unit_ids): - unit_index = sorting_result.sorting.id_to_index(unit_id) + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) chan_inds = channel_inds[unit_id] # this have already the sparsity @@ -147,7 +147,7 @@ def __init__( # plot median templates_flat = {} for i, unit_id in enumerate(unit_ids): - unit_index = sorting_result.sorting.id_to_index(unit_id) + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) chan_inds = channel_inds[unit_id] template = templates[i, :, chan_inds] template_flat = template.flatten() @@ -156,7 +156,7 @@ def __init__( plot_data = dict( unit_ids=unit_ids, unit_colors=unit_colors, - channel_ids=sorting_result.channel_ids, + channel_ids=sorting_analyzer.channel_ids, channel_inds=channel_inds, same_axis=same_axis, bin_min=bin_min, From b38149609ad2e11bd315ddabd697ef94add835ef Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 16 Feb 2024 14:24:56 +0100 Subject: [PATCH 099/192] oups --- src/spikeinterface/core/__init__.py | 4 ++-- src/spikeinterface/core/sortinganalyzer.py | 8 +++++-- .../tests/test_analyzer_extension_core.py | 4 +++- .../core/tests/test_sortinganalyzer.py | 4 +++- src/spikeinterface/exporters/report.py | 4 +++- .../exporters/tests/test_report.py | 6 ++++- .../postprocessing/correlograms.py | 1 - src/spikeinterface/postprocessing/isi.py | 1 - .../tests/common_extension_tests.py | 4 +++- .../tests/test_template_similarity.py | 9 +++++--- .../qualitymetrics/misc_metrics.py | 4 +++- .../tests/test_metrics_functions.py | 8 +++++-- .../sorters/internal/tridesclous2.py | 21 ++++++++++------- .../clustering/position_and_features.py | 23 +++++++++++-------- .../tests/test_template_matching.py | 17 ++++---------- .../test_waveform_thresholder.py | 16 +++++++++---- .../waveforms/temporal_pca.py | 4 +++- .../widgets/all_amplitudes_distributions.py | 4 +++- src/spikeinterface/widgets/rasters.py | 2 +- .../widgets/tests/test_widgets.py | 8 +++++-- src/spikeinterface/widgets/unit_depths.py | 4 +++- src/spikeinterface/widgets/unit_presence.py | 4 ++-- src/spikeinterface/widgets/unit_waveforms.py | 4 +++- 23 files changed, 104 insertions(+), 60 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index bdc29cd17a..d1f67412ec 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -112,8 +112,8 @@ # from .waveform_extractor import ( # WaveformExtractor, # BaseWaveformExtractorExtension, - # extract_waveforms, - # load_waveforms, +# extract_waveforms, +# load_waveforms, # precompute_sparsity, # ) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 928dffa780..d8d5afaee2 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -570,7 +570,9 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> if format == "memory": # This make a copy of actual SortingAnalyzer - new_sorting_analyzer = SortingAnalyzer.create_memory(sorting_provenance, recording, sparsity, self.rec_attributes) + new_sorting_analyzer = SortingAnalyzer.create_memory( + sorting_provenance, recording, sparsity, self.rec_attributes + ) elif format == "binary_folder": # create a new folder @@ -609,7 +611,9 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing for extension_name, extension in self.extensions.items(): - new_ext = new_sorting_analyzer.extensions[extension_name] = extension.copy(new_sorting_analyzer, unit_ids=unit_ids) + new_ext = new_sorting_analyzer.extensions[extension_name] = extension.copy( + new_sorting_analyzer, unit_ids=unit_ids + ) return new_sorting_analyzer diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index cb7450d561..482963ffe1 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -153,7 +153,9 @@ def test_ComputeFastTemplates(format, sparse): # compare ComputeTemplates with dense and ComputeFastTemplates: should give the same on "average" other_sorting_analyzer = get_sorting_analyzer(format=format, sparse=False) other_sorting_analyzer.select_random_spikes(max_spikes_per_unit=20, seed=2205) - other_sorting_analyzer.compute("waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) + other_sorting_analyzer.compute( + "waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs + ) other_sorting_analyzer.compute( "templates", operators=[ diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 1ae0d193e6..3cd1286afb 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -58,7 +58,9 @@ def test_SortingAnalyzer_zarr(): if folder.exists(): shutil.rmtree(folder) - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None) + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + ) sorting_analyzer = load_sorting_analyzer(folder, format="auto") _check_sorting_analyzers(sorting_analyzer, sorting) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index d375996945..c29c8aaf2b 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -101,7 +101,9 @@ def export_report( # unit list units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" - units["max_on_channel_id"] = pd.Series(get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="id")) + units["max_on_channel_id"] = pd.Series( + get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="id") + ) units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign="neg")) units.to_csv(output_folder / "unit list.csv", sep="\t") diff --git a/src/spikeinterface/exporters/tests/test_report.py b/src/spikeinterface/exporters/tests/test_report.py index c89a0d70c6..cd000bc077 100644 --- a/src/spikeinterface/exporters/tests/test_report.py +++ b/src/spikeinterface/exporters/tests/test_report.py @@ -5,7 +5,11 @@ from spikeinterface.exporters import export_report -from spikeinterface.exporters.tests.common import cache_folder, make_sorting_analyzer, sorting_analyzer_sparse_for_export +from spikeinterface.exporters.tests.common import ( + cache_folder, + make_sorting_analyzer, + sorting_analyzer_sparse_for_export, +) def test_export_report(sorting_analyzer_sparse_for_export): diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 3becdd1b16..12a9c8d42f 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -167,7 +167,6 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ return _compute_crosscorr_numba(spike_times1.astype(np.int64), spike_times2.astype(np.int64), window_size, bin_size) - def compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): """ Computes several cross-correlogram in one course from several clusters. diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 367db16533..9bb1a3a1ba 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -69,7 +69,6 @@ def _get_data(self): compute_isi_histograms = ComputeISIHistograms.function_factory() - def _compute_isi_histograms(sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): """ Computes the Inter-Spike Intervals histogram for all diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 5516fe592a..a24e962e56 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -54,7 +54,9 @@ def get_sorting_analyzer(recording, sorting, format="memory", sparsity=None, nam if folder and folder.exists(): shutil.rmtree(folder) - sorting_analyzer = create_sorting_analyzer(sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity) + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format=format, folder=folder, sparse=False, sparsity=sparsity + ) return sorting_analyzer diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 26a065dc29..b8fc608d2e 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,6 +1,10 @@ import unittest -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite, get_sorting_analyzer, get_dataset +from spikeinterface.postprocessing.tests.common_extension_tests import ( + ResultExtensionCommonTestSuite, + get_sorting_analyzer, + get_dataset, +) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity @@ -13,7 +17,7 @@ class SimilarityExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase) def test_check_equal_template_with_distribution_overlap(): - + recording, sorting = get_dataset() sorting_analyzer = get_sorting_analyzer(recording, sorting, sparsity=None) @@ -32,7 +36,6 @@ def test_check_equal_template_with_distribution_overlap(): check_equal_template_with_distribution_overlap(waveforms0, waveforms1) - if __name__ == "__main__": # test = SimilarityExtensionTest() # test.setUpClass() diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 9530f2bf9c..4ee4588f0c 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -702,7 +702,9 @@ def compute_amplitude_cv_metrics( amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: firing_rate = num_spikes[unit_id] / total_duration - temporal_bin_size_samples = int((average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency) + temporal_bin_size_samples = int( + (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency + ) amp_spreads = [] # bins and amplitude means are computed for each segment diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 96920c08e5..3f3cec54fe 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -209,7 +209,9 @@ def test_calculate_firing_range(sorting_analyzer_simple): print(firing_ranges) with pytest.warns(UserWarning) as w: - firing_ranges_nan = compute_firing_ranges(sorting_analyzer, bin_size_s=sorting_analyzer.get_total_duration() + 1) + firing_ranges_nan = compute_firing_ranges( + sorting_analyzer, bin_size_s=sorting_analyzer.get_total_duration() + 1 + ) assert np.all([np.isnan(f) for f in firing_ranges_nan.values()]) @@ -338,7 +340,9 @@ def test_synchrony_metrics(sorting_analyzer_simple): sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory") - previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer, synchrony_sizes=synchrony_sizes) + previous_synchrony_metrics = compute_synchrony_metrics( + previous_sorting_analyzer, synchrony_sizes=synchrony_sizes + ) current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes) print(current_synchrony_metrics) # check that all values increased diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index de4e2d44ec..8a9dfc1cef 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -300,23 +300,28 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) # sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler") - - nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.) - nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.) - templates_array = estimate_templates_average(recording, sorting_pre_peeler.to_spike_vector(), sorting_pre_peeler.unit_ids, - nbefore, nafter, return_scaled=False, **job_kwargs) + nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) + nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0) + templates_array = estimate_templates_average( + recording, + sorting_pre_peeler.to_spike_vector(), + sorting_pre_peeler.unit_ids, + nbefore, + nafter, + return_scaled=False, + **job_kwargs, + ) templates_dense = Templates( templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, - probe=recording.get_probe() + probe=recording.get_probe(), ) # TODO : try other methods for sparsity # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.) - sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=1.) + sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=1.0) templates = templates_dense.to_sparse(sparsity) - # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum") # print(snrs) diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index f317706838..c4a999ae92 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -168,19 +168,22 @@ def main_function(cls, recording, peaks, params): tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - - nbefore = int(params["ms_before"] * fs / 1000.) - nafter = int(params["ms_after"] * fs / 1000.) - templates_array = estimate_templates_average(recording, sorting.to_spike_vector(), sorting.unit_ids, - nbefore, nafter, return_scaled=False, **params["job_kwargs"]) + + nbefore = int(params["ms_before"] * fs / 1000.0) + nafter = int(params["ms_after"] * fs / 1000.0) + templates_array = estimate_templates_average( + recording, + sorting.to_spike_vector(), + sorting.unit_ids, + nbefore, + nafter, + return_scaled=False, + **params["job_kwargs"], + ) templates = Templates( - templates_array=templates_array, - sampling_frequency=fs, - nbefore=nbefore, - probe=recording.get_probe() + templates_array=templates_array, sampling_frequency=fs, nbefore=nbefore, probe=recording.get_probe() ) - labels, peak_labels = remove_duplicates_via_matching( templates, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"] ) diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index b7c94c6238..a73ce93b4c 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -9,10 +9,9 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset - - job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) + def get_sorting_analyzer(): recording, sorting = make_dataset() sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=False) @@ -53,14 +52,10 @@ def test_find_spikes_from_templates(method, sorting_analyzer): method_kwargs_ = method_kwargs.get(method, {}) method_kwargs_.update(method_kwargs_all) - spikes = find_spikes_from_templates( - recording, method=method, method_kwargs=method_kwargs_, **job_kwargs - ) - - + spikes = find_spikes_from_templates(recording, method=method, method_kwargs=method_kwargs_, **job_kwargs) # DEBUG = True - + # if DEBUG: # import matplotlib.pyplot as plt # import spikeinterface.full as si @@ -68,11 +63,10 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # sorting_analyzer.compute("waveforms") # sorting_analyzer.compute("templates") - # gt_sorting = sorting_analyzer.sorting # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) - + # metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) # fig, ax = plt.subplots() @@ -85,9 +79,8 @@ def test_find_spikes_from_templates(method, sorting_analyzer): if __name__ == "__main__": sorting_analyzer = get_sorting_analyzer() # method = "naive" - # method = "tdc-peeler" + # method = "tdc-peeler" # method = "circus" # method = "circus-omp-svd" method = "wobble" test_find_spikes_from_templates(method, sorting_analyzer) - diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py index fdbabc7584..4f55030283 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py @@ -20,7 +20,9 @@ def extract_dense_waveforms_node(generated_recording): ) -def test_waveform_thresholder_ptp(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_ptp( + extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs +): recording = generated_recording peaks = detected_peaks @@ -39,7 +41,9 @@ def test_waveform_thresholder_ptp(extract_dense_waveforms_node, generated_record assert np.all(data[data != 0] > 3) -def test_waveform_thresholder_mean(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_mean( + extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs +): recording = generated_recording peaks = detected_peaks @@ -56,7 +60,9 @@ def test_waveform_thresholder_mean(extract_dense_waveforms_node, generated_recor assert np.all(tresholded_waveforms.mean(axis=1) >= 0) -def test_waveform_thresholder_energy(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_energy( + extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs +): recording = generated_recording peaks = detected_peaks @@ -75,7 +81,9 @@ def test_waveform_thresholder_energy(extract_dense_waveforms_node, generated_rec assert np.all(data[data != 0] > 3) -def test_waveform_thresholder_operator(extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs): +def test_waveform_thresholder_operator( + extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs +): recording = generated_recording peaks = detected_peaks diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 0226d706d1..fb9d1010f8 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -143,7 +143,9 @@ def fit( sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True) sorting_analyzer.select_random_spikes() sorting_analyzer.compute("waveforms", ms_before=ms_before, ms_after=ms_after) - sorting_analyzer.compute("principal_components", n_components=n_components, mode="by_channel_global", whiten=whiten) + sorting_analyzer.compute( + "principal_components", n_components=n_components, mode="by_channel_global", whiten=whiten + ) pca_model = sorting_analyzer.get_extension("principal_components").get_pca_model() params = { diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index 5bd7b9679d..59a69640da 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -23,7 +23,9 @@ class AllAmplitudesDistributionsWidget(BaseWidget): Dict of colors with key: unit, value: color, default None """ - def __init__(self, sorting_analyzer: SortingAnalyzer, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs): + def __init__( + self, sorting_analyzer: SortingAnalyzer, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs + ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) self.check_extensions(sorting_analyzer, "spike_amplitudes") diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index a460a8e179..957eaadcc9 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -28,7 +28,7 @@ def __init__( self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs ): sorting = self.ensure_sorting(sorting) - + if segment_index is None: if sorting.get_num_segments() != 1: raise ValueError("You must provide segment_index=...") diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 54ca074f16..141f73e881 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -449,8 +449,12 @@ def test_plot_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_template_similarity(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity( + self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend] + ) + sw.plot_template_similarity( + self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend] + ) def test_plot_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 982b34f0c5..c5fe3e05e8 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -26,7 +26,9 @@ class UnitDepthsWidget(BaseWidget): Sign of peak for amplitudes """ - def __init__(self, sorting_analyzer, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs): + def __init__( + self, sorting_analyzer, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs + ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) diff --git a/src/spikeinterface/widgets/unit_presence.py b/src/spikeinterface/widgets/unit_presence.py index 69f673b0db..746868b89d 100644 --- a/src/spikeinterface/widgets/unit_presence.py +++ b/src/spikeinterface/widgets/unit_presence.py @@ -32,9 +32,9 @@ def __init__( smooth_sigma=4.5, backend=None, **backend_kwargs, - ): + ): sorting = self.ensure_sorting(sorting) - + if segment_index is None: nseg = sorting.get_num_segments() if nseg != 1: diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index db3ddbf6b9..ab415ae2f0 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -115,7 +115,9 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(sorting) - channel_locations = sorting_analyzer.get_channel_locations()[sorting_analyzer.channel_ids_to_indices(channel_ids)] + channel_locations = sorting_analyzer.get_channel_locations()[ + sorting_analyzer.channel_ids_to_indices(channel_ids) + ] extra_sparsity = False if sorting_analyzer.is_sparse(): From b9dc753941ad8ff41817cfc43e989aac8aa715ed Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 16 Feb 2024 15:50:39 +0100 Subject: [PATCH 100/192] oups --- src/spikeinterface/postprocessing/principal_component.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 155d072b26..af41f95d87 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -211,8 +211,6 @@ def get_some_projections(self, channel_ids=None, unit_ids=None): unit_indices = sorting.ids_to_indices(unit_ids) selected_inds = np.flatnonzero(np.isin(some_spikes["unit_index"], unit_indices)) - print(selected_inds.size, unit_indices, some_spikes["unit_index"].size) - print(np.min(selected_inds), np.max(selected_inds)) spike_unit_indices = some_spikes["unit_index"][selected_inds] if sparsity is None: From d78594652956409a8134279c5e2d3832fab57791 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 16 Feb 2024 10:02:47 -0500 Subject: [PATCH 101/192] sorteranalyzer typo fixes --- src/spikeinterface/core/sortinganalyzer.py | 76 +++++++++++----------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index d8d5afaee2..7006d2ee84 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -41,9 +41,9 @@ def create_sorting_analyzer( Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. This object will handle a list of ResultExtension for all the post processing steps like: waveforms, - templates, unit locations, spike locations, quality mertics ... + templates, unit locations, spike locations, quality metrics ... - This object will be also use used for ploting purpose. + This object will be also use used for plotting purpose. Parameters @@ -59,12 +59,12 @@ def create_sorting_analyzer( The "folder" argument must be specified in case of mode "folder". If "memory" is used, the waveforms are stored in RAM. Use this option carefully! sparse: bool, default: True - If True, then a sparsity mask is computed usingthe `estimate_sparsity()` function is run using + If True, then a sparsity mask is computed using the `estimate_sparsity()` function using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the sparsity will be propagated to all ResultExtention that handle sparsity (like wavforms, pca, ...) You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs) sparsity: ChannelSparsity or None, default: None - The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Default None. + The sparsity used to compute waveforms. If this is given, `sparse` is ignored. Returns ------- @@ -141,17 +141,17 @@ class SortingAnalyzer: Class to make a pair of Recording-Sorting which will be used used for all post postprocessing, visualization and quality metric computation. - This internaly maintain a list of computed ResultExtention (waveform, pca, unit position, spike poisition, ...). + This internally maintains a list of computed ResultExtention (waveform, pca, unit position, spike position, ...). This can live in memory and/or can be be persistent to disk in 2 internal formats (folder/json/npz or zarr). A SortingAnalyzer can be transfer to another format using `save_as()` This handle unit sparsity that can be propagated to ResultExtention. - This handle spike sampling that can be propagated to ResultExtention : work only on a subset of spikes. + This handle spike sampling that can be propagated to ResultExtention : works at only on a subset of spikes. - This internally save a copy of the Sorting and extract main recording attributes (without traces) so - the SortingAnalyzer object can be reload even if references to the original sorting and/or to the original recording + This internally saves a copy of the Sorting and extracts main recording attributes (without traces) so + the SortingAnalyzer object can be reloaded even if references to the original sorting and/or to the original recording are lost. SortingAnalyzer() should not never be used directly for creating: use instead create_sorting_analyzer(sorting, resording, ...) @@ -429,7 +429,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) else: warnings.warn( - "SortingAnalyzer with zarr : the Recording is not json serializable, the recording link will be lost for futur load" + "SortingAnalyzer with zarr : the Recording is not json serializable, the recording link will be lost for future load" ) # sorting provenance @@ -483,7 +483,7 @@ def load_from_zarr(cls, folder, recording=None): import zarr folder = Path(folder) - assert folder.is_dir(), f"This folder does not exists {folder}" + assert folder.is_dir(), f"This folder does not exist {folder}" zarr_root = zarr.open(folder, mode="r") @@ -587,7 +587,7 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) new_sorting_analyzer.folder = folder else: - raise ValueError("SortingAnalyzer.save: wrong format") + raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") # propagate random_spikes_indices is already done if self.random_spikes_indices is not None: @@ -620,11 +620,11 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": """ Save SortingAnalyzer object into another format. - Uselfull for memory to zarr or memory to binray. + Uselful for memory to zarr or memory to binary. Note that the recording provenance or sorting provenance can be lost. - Mainly propagate the copied sorting and recording property. + Mainly propagates the copied sorting and recording properties. Parameters ---------- @@ -638,7 +638,7 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ This method is equivalent to `save_as()`but with a subset of units. - Filters units by creating a new waveform extractor object in a new folder. + Filters units by creating a new sorting analyzer object in a new folder. Extensions are also updated to filter the selected unit ids. @@ -653,7 +653,7 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz Returns ------- we : SortingAnalyzer - The newly create waveform extractor with the selected units + The newly create sorting_analyzer with the selected units """ # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! return self._save_or_select(format=format, folder=folder, unit_ids=unit_ids) @@ -780,14 +780,14 @@ def get_dtype(self): def compute(self, input, save=True, **kwargs): """ - Compute one extension or several extension. - Internally calling compute_one_extension() or compute_several_extensions() depending th input type. + Compute one extension or several extensiosn. + Internally calls compute_one_extension() or compute_several_extensions() depending on the input type. Parameters ---------- input: str or dict - If the input is a string then compute one extension with compute_one_extension(extension_name=input, ...) - If the input is a dict then compute several extension with compute_several_extensions(extensions=input) + If the input is a string then computes one extension with compute_one_extension(extension_name=input, ...) + If the input is a dict then compute several extensions with compute_several_extensions(extensions=input) """ if isinstance(input, str): return self.compute_one_extension(extension_name=input, save=save, **kwargs) @@ -808,10 +808,10 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): save: bool, default True It the extension can be saved then it is saved. If not then the extension will only live in memory as long as the object is deleted. - save=False is convinient to try some parameters without changing an already saved extension. + save=False is convenient to try some parameters without changing an already saved extension. **kwargs: - All other kwargs are transimited to extension.set_params() or job_kwargs + All other kwargs are transmitted to extension.set_params() or job_kwargs Returns ------- @@ -840,14 +840,14 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): # check dependencies if extension_class.need_recording: - assert self.has_recording(), f"Extension {extension_name} need the recording" + assert self.has_recording(), f"Extension {extension_name} requires the recording" for dependency_name in extension_class.depend_on: if "|" in dependency_name: # at least one extension must be done : usefull for "templates|fast_templates" for instance ok = any(self.get_extension(name) is not None for name in dependency_name.split("|")) else: ok = self.get_extension(dependency_name) is not None - assert ok, f"Extension {extension_name} need {dependency_name} to be computed first" + assert ok, f"Extension {extension_name} requires {dependency_name} to be computed first" extension_instance = extension_class(self) extension_instance.set_params(save=save, **params) @@ -864,11 +864,11 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): Parameters ---------- extensions: dict - Key are extension_name and values are params. + Keys are extension_names and values are params. save: bool, default True It the extension can be saved then it is saved. If not then the extension will only live in memory as long as the object is deleted. - save=False is convinient to try some parameters without changing an already saved extension. + save=False is convenient to try some parameters without changing an already saved extension. Returns ------- @@ -971,7 +971,7 @@ def get_extension(self, extension_name: str): Get a ResultExtension. If not loaded then load is automatic. - Return None if the extension is not computed yet (this avoid the use of has_extension() and then get it) + Return None if the extension is not computed yet (this avoids the use of has_extension() and then get it) """ if extension_name in self.extensions: @@ -986,7 +986,7 @@ def get_extension(self, extension_name: str): def load_extension(self, extension_name: str): """ - Load an extensionet from folder or zarr into the `ResultSorting.extensions` dict. + Load an extension from a folder or zarr into the `ResultSorting.extensions` dict. Parameters ---------- @@ -1001,7 +1001,7 @@ def load_extension(self, extension_name: str): """ assert ( self.format != "memory" - ), "SortingAnalyzer.load_extension() do not work for format='memory' use SortingAnalyzer.get_extension()instead" + ), "SortingAnalyzer.load_extension() does not work for format='memory' use SortingAnalyzer.get_extension() instead" extension_class = get_extension_class(extension_name) @@ -1015,7 +1015,7 @@ def load_extension(self, extension_name: str): def load_all_saved_extension(self): """ - Load all saved extension in memory. + Load all saved extensions in memory. """ for extension_name in self.get_saved_extension_names(): self.load_extension(extension_name) @@ -1073,7 +1073,7 @@ def _save_random_spikes_indices(self): def get_selected_indices_in_spike_train(self, unit_id, segment_index): # usefull for Waveforms extractor backwars compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain - assert self.random_spikes_indices is not None, "random spikes selection is not computeds" + assert self.random_spikes_indices is not None, "random spikes selection is not computed" unit_index = self.sorting.id_to_index(unit_id) spikes = self.sorting.to_spike_vector() spike_indices_in_seg = np.flatnonzero( @@ -1133,7 +1133,7 @@ def get_extension_class(extension_name: str): extensions_dict = {ext.extension_name: ext for ext in _possible_extensions} assert ( extension_name in extensions_dict - ), f"Extension '{extension_name}' is not registered, please import related module before" + ), f"Extension '{extension_name}' is not registered, please import related module before use" ext_class = extensions_dict[extension_name] return ext_class @@ -1141,7 +1141,7 @@ def get_extension_class(extension_name: str): class ResultExtension: """ This the base class to extend the SortingAnalyzer. - It can handle persistency to disk any computations related + It can handle persistency to disk for any computations related to: For instance: * waveforms @@ -1149,8 +1149,8 @@ class ResultExtension: * spike amplitudes * quality metrics - Possible extension can be register on the fly at import time with register_result_extension() mechanism. - It also enables any custum computation on top on SortingAnalyzer to be implemented by the user. + Possible extension can be registered on-the-fly at import time with register_result_extension() mechanism. + It also enables any custom computation on top of the SortingAnalyzer to be implemented by the user. An extension needs to inherit from this class and implement some attributes and abstract methods: * extension_name @@ -1169,7 +1169,7 @@ class ResultExtension: The subclass must also hanle an attribute `data` which is a dict contain the results after the `run()`. All ResultExtension will have a function associate for instance (this use the function_factory): - comptute_unit_location(sorting_analyzer, ...) will be equivalent to sorting_analyzer.compute("unit_location", ...) + compute_unit_location(sorting_analyzer, ...) will be equivalent to sorting_analyzer.compute("unit_location", ...) """ @@ -1234,7 +1234,7 @@ def __call__(self, sorting_analyzer, load_if_exists=None, *args, **kwargs): sorting_analyzer = sorting_analyzer.sorting_analyzer if not isinstance(sorting_analyzer, SortingAnalyzer): - raise ValueError(f"compute_{self.extension_name}() need a SortingAnalyzer instance") + raise ValueError(f"compute_{self.extension_name}() needs a SortingAnalyzer instance") if load_if_exists is not None: # backward compatibility with "load_if_exists" @@ -1387,7 +1387,7 @@ def _save_data(self, **kwargs): return if self.sorting_analyzer.is_read_only(): - raise ValueError(f"The SortingAnalyzer is read only save extension {self.extension_name} is not possible") + raise ValueError(f"The SortingAnalyzer is read-only saving extension {self.extension_name} is not possible") if self.format == "binary_folder": import pandas as pd @@ -1452,7 +1452,7 @@ def _save_data(self, **kwargs): def _reset_extension_folder(self): """ - Delete the extension in folder (binary or zarr) and create an empty one. + Delete the extension in a folder (binary or zarr) and create an empty one. """ if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() From 01ec8627bf4d3fb59384a8a54d81e07ba3ce66e9 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 16 Feb 2024 10:06:24 -0500 Subject: [PATCH 102/192] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 7006d2ee84..5eaa165850 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -148,7 +148,7 @@ class SortingAnalyzer: This handle unit sparsity that can be propagated to ResultExtention. - This handle spike sampling that can be propagated to ResultExtention : works at only on a subset of spikes. + This handle spike sampling that can be propagated to ResultExtention : works on only a subset of spikes. This internally saves a copy of the Sorting and extracts main recording attributes (without traces) so the SortingAnalyzer object can be reloaded even if references to the original sorting and/or to the original recording From 2438ac79abef422469ef2fe71754b5cf3b853aa2 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:41:49 -0500 Subject: [PATCH 103/192] WIP: switching some docs over --- doc/modules/core.rst | 211 ++++++++++++++++++--------------- doc/modules/postprocessing.rst | 80 ++++++------- doc/modules/qualitymetrics.rst | 18 +-- 3 files changed, 160 insertions(+), 149 deletions(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 17c7bd69e4..73262b4ba7 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -9,7 +9,7 @@ The :py:mod:`spikeinterface.core` module provides the basic classes and tools of Several Base classes are implemented here and inherited throughout the SI code-base. The core classes are: :py:class:`~spikeinterface.core.BaseRecording` (for raw data), :py:class:`~spikeinterface.core.BaseSorting` (for spike-sorted data), and -:py:class:`~spikeinterface.core.WaveformExtractor` (for waveform extraction and postprocessing). +:py:class:`~spikeinterface.core.SortingAnalyzer` (for postprocessing, quality metrics, and waveform extraction). There are additional classes to allow to retrieve events (:py:class:`~spikeinterface.core.BaseEvent`) and to handle unsorted waveform cutouts, or *snippets*, which are recorded by some acquisition systems @@ -163,105 +163,12 @@ Internally, any sorting object can construct 2 internal caches: time, like for extracting amplitudes from a recording. -WaveformExtractor ------------------ - -The :py:class:`~spikeinterface.core.WaveformExtractor` class is the core object to combine a -:py:class:`~spikeinterface.core.BaseRecording` and a :py:class:`~spikeinterface.core.BaseSorting` object. -Waveforms are very important for additional analyses, and the basis of several postprocessing and quality metrics -computations. - -The :py:class:`~spikeinterface.core.WaveformExtractor` allows us to: - -* extract waveforms -* sub-sample spikes for waveform extraction -* compute templates (i.e. average extracellular waveforms) with different modes -* save waveforms in a folder (in numpy / `Zarr `_) for easy retrieval -* save sparse waveforms or *sparsify* dense waveforms -* select units and associated waveforms - -In the default format (:code:`mode='folder'`) waveforms are saved to a folder structure with waveforms as -:code:`.npy` files. -In addition, waveforms can also be extracted in-memory for fast computations (:code:`mode='memory'`). -Note that this mode can quickly fill up your RAM... Use it wisely! -Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be saved also in :code:`zarr` format. - -.. code-block:: python - - # extract dense waveforms on 500 spikes per unit - we = extract_waveforms(recording=recording, - sorting=sorting, - sparse=False, - folder="waveforms", - max_spikes_per_unit=500 - overwrite=True) - # same, but with parallel processing! (1s chunks processed by 8 jobs) - job_kwargs = dict(n_jobs=8, chunk_duration="1s") - we = extract_waveforms(recording=recording, - sorting=sorting, - sparse=False, - folder="waveforms_parallel", - max_spikes_per_unit=500, - overwrite=True, - **job_kwargs) - # same, but in-memory - we_mem = extract_waveforms(recording=recording, - sorting=sorting, - sparse=False, - folder=None, - mode="memory", - max_spikes_per_unit=500, - **job_kwargs) - - # load pre-computed waveforms - we_loaded = load_waveforms(folder="waveforms") - - # retrieve waveforms and templates for a unit - waveforms0 = we.get_waveforms(unit_id=unit0) - template0 = we.get_template(unit_id=unit0) - - # compute template standard deviations (average is computed by default) - # (this can also be done within the 'extract_waveforms') - we.precompute_templates(modes=("std",)) - - # retrieve all template means and standard deviations - template_means = we.get_all_templates(mode="average") - template_stds = we.get_all_templates(mode="std") - - # save to Zarr - we_zarr = we.save(folder="waveforms_zarr", format="zarr") - - # extract sparse waveforms (see Sparsity section) - # this will use 50 spikes per unit to estimate the sparsity within a 40um radius from that unit - we_sparse = extract_waveforms(recording=recording, - sorting=sorting, - folder="waveforms_sparse", - max_spikes_per_unit=500, - method="radius", - radius_um=40, - num_spikes_for_sparsity=50) +SortingAnalyzer +--------------- +The :py:class:`~spikeinterface.core.SortingAnalyzer` is the class which connects a :code:`Recording` and a :code:`Sorting`. -**IMPORTANT:** to load a waveform extractor object from disk, it needs to be able to reload the associated -:code:`sorting` object (the :code:`recording` is optional, using :code:`with_recording=False`). -In order to make a waveform folder portable (e.g. copied to another location or machine), one can do: - -.. code-block:: python - - # create a "processed" folder - processed_folder = Path("processed") - - # save the sorting object in the "processed" folder - sorting = sorting.save(folder=processed_folder / "sorting") - # extract waveforms using relative paths - we = extract_waveforms(recording=recording, - sorting=sorting, - folder=processed_folder / "waveforms", - use_relative_path=True) - # the "processed" folder is now portable, and the waveform extractor can be reloaded - # from a different location/machine (without loading the recording) - we_loaded = si.load_waveforms(folder=processed_folder / "waveforms", - with_recording=False) +**To be filled in** Event @@ -783,3 +690,111 @@ various formats: # SpikeGLX format local_folder_path = download_dataset(remote_path='/spikeglx/multi_trigger_multi_gate') rec = read_spikeglx(local_folder_path) + + + +LEGACY objects +-------------- + +WaveformExtractor +^^^^^^^^^^^^^^^^^ + +This is now a legacy object that can still be accessed through the :py:class:`MockWaveformExtractor`. It is kept +for backward compatibility. + +The :py:class:`~spikeinterface.core.WaveformExtractor` class is the core object to combine a +:py:class:`~spikeinterface.core.BaseRecording` and a :py:class:`~spikeinterface.core.BaseSorting` object. +Waveforms are very important for additional analyses, and the basis of several postprocessing and quality metrics +computations. + +The :py:class:`~spikeinterface.core.WaveformExtractor` allows us to: + +* extract waveforms +* sub-sample spikes for waveform extraction +* compute templates (i.e. average extracellular waveforms) with different modes +* save waveforms in a folder (in numpy / `Zarr `_) for easy retrieval +* save sparse waveforms or *sparsify* dense waveforms +* select units and associated waveforms + +In the default format (:code:`mode='folder'`) waveforms are saved to a folder structure with waveforms as +:code:`.npy` files. +In addition, waveforms can also be extracted in-memory for fast computations (:code:`mode='memory'`). +Note that this mode can quickly fill up your RAM... Use it wisely! +Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be saved also in :code:`zarr` format. + +.. code-block:: python + + # extract dense waveforms on 500 spikes per unit + we = extract_waveforms(recording=recording, + sorting=sorting, + sparse=False, + folder="waveforms", + max_spikes_per_unit=500 + overwrite=True) + # same, but with parallel processing! (1s chunks processed by 8 jobs) + job_kwargs = dict(n_jobs=8, chunk_duration="1s") + we = extract_waveforms(recording=recording, + sorting=sorting, + sparse=False, + folder="waveforms_parallel", + max_spikes_per_unit=500, + overwrite=True, + **job_kwargs) + # same, but in-memory + we_mem = extract_waveforms(recording=recording, + sorting=sorting, + sparse=False, + folder=None, + mode="memory", + max_spikes_per_unit=500, + **job_kwargs) + + # load pre-computed waveforms + we_loaded = load_waveforms(folder="waveforms") + + # retrieve waveforms and templates for a unit + waveforms0 = we.get_waveforms(unit_id=unit0) + template0 = we.get_template(unit_id=unit0) + + # compute template standard deviations (average is computed by default) + # (this can also be done within the 'extract_waveforms') + we.precompute_templates(modes=("std",)) + + # retrieve all template means and standard deviations + template_means = we.get_all_templates(mode="average") + template_stds = we.get_all_templates(mode="std") + + # save to Zarr + we_zarr = we.save(folder="waveforms_zarr", format="zarr") + + # extract sparse waveforms (see Sparsity section) + # this will use 50 spikes per unit to estimate the sparsity within a 40um radius from that unit + we_sparse = extract_waveforms(recording=recording, + sorting=sorting, + folder="waveforms_sparse", + max_spikes_per_unit=500, + method="radius", + radius_um=40, + num_spikes_for_sparsity=50) + + +**IMPORTANT:** to load a waveform extractor object from disk, it needs to be able to reload the associated +:code:`sorting` object (the :code:`recording` is optional, using :code:`with_recording=False`). +In order to make a waveform folder portable (e.g. copied to another location or machine), one can do: + +.. code-block:: python + + # create a "processed" folder + processed_folder = Path("processed") + + # save the sorting object in the "processed" folder + sorting = sorting.save(folder=processed_folder / "sorting") + # extract waveforms using relative paths + we = extract_waveforms(recording=recording, + sorting=sorting, + folder=processed_folder / "waveforms", + use_relative_path=True) + # the "processed" folder is now portable, and the waveform extractor can be reloaded + # from a different location/machine (without loading the recording) + we_loaded = si.load_waveforms(folder=processed_folder / "waveforms", + with_recording=False) diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 4e8dd88be5..a3f1959224 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -5,30 +5,30 @@ Postprocessing module After spike sorting, we can use the :py:mod:`~spikeinterface.postprocessing` module to further post-process the spike sorting output. Most of the post-processing functions require a -:py:class:`~spikeinterface.core.WaveformExtractor` as input. +:py:class:`~spikeinterface.core.SortingAnalyzer` as input. .. _waveform_extensions: -WaveformExtractor extensions +ResultExtensions ---------------------------- There are several postprocessing tools available, and all -of them are implemented as a :py:class:`~spikeinterface.core.BaseWaveformExtractorExtension`. All computations on top -of a :code:`WaveformExtractor` will be saved along side the :code:`WaveformExtractor` itself (sub folder, zarr path or sub dict). +of them are implemented as a :py:class:`~spikeinterface.core.ResultExtension`. All computations on top +of a :code:`SortingAnalyzer` will be saved along side the :code:`SortingAnalyzer` itself (sub folder, zarr path or sub dict). This workflow is convenient for retrieval of time-consuming computations (such as pca or spike amplitudes) when reloading a -:code:`WaveformExtractor`. +:code:`SortingAnalyzer`. -:py:class:`~spikeinterface.core.BaseWaveformExtractorExtension` objects are tightly connected to the -parent :code:`WaveformExtractor` object, so that operations done on the :code:`WaveformExtractor`, such as saving, +:py:class:`~spikeinterface.core.ResultExtension` objects are tightly connected to the +parent :code:`SortingAnalyzer` object, so that operations done on the :code:`SortingAnalyzer`, such as saving, loading, or selecting units, will be automatically applied to all extensions. -To check what extensions are available for a :code:`WaveformExtractor` named :code:`we`, you can use: +To check what extensions are available for a :code:`SortingAnalyzer` named :code:`sorting_analyzer`, you can use: .. code-block:: python import spikeinterface as si - available_extension_names = we.get_available_extension_names() + available_extension_names = sorting_analyzer.get_load_extension_names() print(available_extension_names) .. code-block:: bash @@ -40,7 +40,7 @@ To load the extension object you can run: .. code-block:: python - ext = we.load_extension("spike_amplitudes") + ext = sorting_analyzer.get_extension("spike_amplitudes") ext_data = ext.get_data() Here :code:`ext` is the extension object (in this case the :code:`SpikeAmplitudeCalculator`), and :code:`ext_data` will @@ -52,13 +52,9 @@ We can also delete an extension: .. code-block:: python - we.delete_extension("spike_amplitudes") + sorting_analyzer.delete_extension("spike_amplitudes") - -Finally, the waveform extensions can be loaded rather than recalculated by using the :code:`load_if_exists` argument in -any post-processing function. - Available postprocessing extensions ----------------------------------- @@ -66,7 +62,7 @@ noise_levels ^^^^^^^^^^^^ This extension computes the noise level of each channel using the median absolute deviation. -As an extension, this expects the :code:`WaveformExtractor` as input and the computed values are persistent on disk. +As an extension, this expects the :code:`Recording` as input and the computed values are persistent on disk. The :py:func:`~spikeinterface.core.get_noise_levels(recording)` computes the same values, but starting from a recording and without saving the data as an extension. @@ -74,10 +70,9 @@ and without saving the data as an extension. .. code-block:: python - noise = compute_noise_level(waveform_extractor=we) + noise = compute_noise_level(recording=recording) -For more information, see :py:func:`~spikeinterface.postprocessing.compute_noise_levels` @@ -95,9 +90,9 @@ For dense waveforms, sparsity can also be passed as an argument. .. code-block:: python - pc = compute_principal_components(waveform_extractor=we, - n_components=3, - mode="by_channel_local") + sorting_analyzer.compute(input="principal_compoents", + n_components=3, + mode="by_channel_local") For more information, see :py:func:`~spikeinterface.postprocessing.compute_principal_components` @@ -112,7 +107,7 @@ and is not well suited for high-density probes. .. code-block:: python - similarity = compute_template_similarity(waveform_extractor=we, method='cosine_similarity') + sorting_analyzer.compute(input="template_similarity", method='cosine_similarity') For more information, see :py:func:`~spikeinterface.postprocessing.compute_template_similarity` @@ -130,9 +125,9 @@ each spike. .. code-block:: python - amplitudes = computer_spike_amplitudes(waveform_extractor=we, - peak_sign="neg", - outputs="concatenated") + sorting_analyzer.compute(input="spike_amplitudes", + peak_sign="neg", + outputs="concatenated") For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes` @@ -150,15 +145,15 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), .. code-block:: python - spike_locations = compute_spike_locations(waveform_extractor=we, - ms_before=0.5, - ms_after=0.5, - spike_retriever_kwargs=dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg" + sorting_analyzer.compute(input="spike_locations", + ms_before=0.5, + ms_after=0.5, + spike_retriever_kwargs=dict( + channel_from_template=True, + radius_um=50, + peak_sign="neg" ), - method="center_of_mass") + method="center_of_mass") For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_locations` @@ -175,8 +170,7 @@ based on individual waveforms, it calculates at the unit level using templates. .. code-block:: python - unit_locations = compute_unit_locations(waveform_extractor=we, - method="monopolar_triangulation") + sorting_analyzer.compute(input="unit_locations", method="monopolar_triangulation") For more information, see :py:func:`~spikeinterface.postprocessing.compute_unit_locations` @@ -219,10 +213,10 @@ with shape (num_units, num_units, num_bins) with all correlograms for each pair .. code-block:: python - ccgs, bins = compute_correlograms(waveform_or_sorting_extractor=we, - window_ms=50.0, - bin_ms=1.0, - method="auto") + sorting_analyer.compute(input="correlograms", + window_ms=50.0, + bin_ms=1.0, + method="auto") For more information, see :py:func:`~spikeinterface.postprocessing.compute_correlograms` @@ -236,10 +230,10 @@ This extension computes the histograms of inter-spike-intervals. The computed ou .. code-block:: python - isi_histogram, bins = compute_isi_histograms(waveform_or_sorting_extractor=we, - window_ms=50.0, - bin_ms=1.0, - method="auto") + sorting_analyer.compute_isi_histograms(input="isi_histograms" + window_ms=50.0, + bin_ms=1.0, + method="auto") For more information, see :py:func:`~spikeinterface.postprocessing.compute_isi_histograms` diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 962de2dfd8..42db0e645f 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -48,17 +48,19 @@ This code snippet shows how to compute quality metrics (with or without principa .. code-block:: python - we = si.load_waveforms(folder='waveforms') # start from a waveform extractor + sorting_analyzer = si.load_sorting_analyzer(folder='waveforms') # start from a sorting_analyzer - # without PC - metrics = compute_quality_metrics(waveform_extractor=we, metric_names=['snr']) + # without PC (depends on "waveforms", "templates", and "noise_levels") + sorting_analyzer.compute(input="quality_metrics", metric_names=['snr'], skip_pc_metrics=False) + metrics = sorting_analyzer.get_extension(extension_name="quality_metrics") assert 'snr' in metrics.columns - # with PCs - from spikeinterface.postprocessing import compute_principal_components - pca = compute_principal_components(waveform_extractor=we, n_components=5, mode='by_channel_local') - metrics = compute_quality_metrics(waveform_extractor=we) - assert 'isolation_distance' in metrics.columns + # with PCs (depends on "pca" in addition to the above metrics) + + sorting_analyzer.compute(input={"pca": dict(n_components=5, mode="by_channel_local"), + "quality_metrics": dict(skip_pc_metrics=False)}) + + For more information about quality metrics, check out this excellent `documentation `_ From 1ee81b7f42dea37c1edda0db6878833066431703 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:45:32 -0500 Subject: [PATCH 104/192] fix a couple typos --- doc/modules/postprocessing.rst | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index a3f1959224..1b0162aa36 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -64,10 +64,6 @@ noise_levels This extension computes the noise level of each channel using the median absolute deviation. As an extension, this expects the :code:`Recording` as input and the computed values are persistent on disk. -The :py:func:`~spikeinterface.core.get_noise_levels(recording)` computes the same values, but starting from a recording -and without saving the data as an extension. - - .. code-block:: python noise = compute_noise_level(recording=recording) @@ -90,7 +86,7 @@ For dense waveforms, sparsity can also be passed as an argument. .. code-block:: python - sorting_analyzer.compute(input="principal_compoents", + sorting_analyzer.compute(input="principal_components", n_components=3, mode="by_channel_local") @@ -230,10 +226,10 @@ This extension computes the histograms of inter-spike-intervals. The computed ou .. code-block:: python - sorting_analyer.compute_isi_histograms(input="isi_histograms" - window_ms=50.0, - bin_ms=1.0, - method="auto") + sorting_analyer.compute(input="isi_histograms" + window_ms=50.0, + bin_ms=1.0, + method="auto") For more information, see :py:func:`~spikeinterface.postprocessing.compute_isi_histograms` From 09a489a0dd460e37507fb0510971dfc8f3452330 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sat, 17 Feb 2024 08:29:21 +0100 Subject: [PATCH 105/192] default for correlograms --- src/spikeinterface/postprocessing/correlograms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 12a9c8d42f..bf6f8585d9 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -20,9 +20,9 @@ class ComputeCorrelograms(ResultExtension): ---------- sorting_analyzer: SortingAnalyzer A SortingAnalyzer object - window_ms : float, default: 100.0 + window_ms : float, default: 50.0 The window in ms - bin_ms : float, default: 5 + bin_ms : float, default: 1.0 The bin size in ms method : "auto" | "numpy" | "numba", default: "auto" If "auto" and numba is installed, numba is used, otherwise numpy is used @@ -55,7 +55,7 @@ class ComputeCorrelograms(ResultExtension): def __init__(self, sorting_analyzer): ResultExtension.__init__(self, sorting_analyzer) - def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str = "auto"): + def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) return params From 6f6691141c96bd560f12e2056f099d4c970072a0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sat, 17 Feb 2024 08:47:51 +0100 Subject: [PATCH 106/192] isi default --- src/spikeinterface/postprocessing/isi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 9bb1a3a1ba..c2cf9cfd7d 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -43,7 +43,7 @@ class ComputeISIHistograms(ResultExtension): def __init__(self, sorting_analyzer): ResultExtension.__init__(self, sorting_analyzer) - def _set_params(self, window_ms: float = 100.0, bin_ms: float = 5.0, method: str = "auto"): + def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) return params From decaa2cc0810b536633a8f8d2834c1371c0ac4a9 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Sat, 17 Feb 2024 08:18:14 -0500 Subject: [PATCH 107/192] Update miscmetrics docs to say sorting_analyzer --- doc/modules/qualitymetrics/amplitude_cutoff.rst | 5 +++-- doc/modules/qualitymetrics/amplitude_cv.rst | 8 ++++---- doc/modules/qualitymetrics/amplitude_median.rst | 4 ++-- doc/modules/qualitymetrics/drift.rst | 6 +++--- doc/modules/qualitymetrics/firing_range.rst | 4 ++-- doc/modules/qualitymetrics/firing_rate.rst | 4 ++-- doc/modules/qualitymetrics/isi_violations.rst | 4 ++-- doc/modules/qualitymetrics/presence_ratio.rst | 4 ++-- doc/modules/qualitymetrics/sd_ratio.rst | 3 ++- doc/modules/qualitymetrics/sliding_rp_violations.rst | 4 ++-- doc/modules/qualitymetrics/snr.rst | 4 ++-- doc/modules/qualitymetrics/synchrony.rst | 4 ++-- 12 files changed, 28 insertions(+), 26 deletions(-) diff --git a/doc/modules/qualitymetrics/amplitude_cutoff.rst b/doc/modules/qualitymetrics/amplitude_cutoff.rst index a1e4d85d01..207ca115fd 100644 --- a/doc/modules/qualitymetrics/amplitude_cutoff.rst +++ b/doc/modules/qualitymetrics/amplitude_cutoff.rst @@ -23,9 +23,10 @@ Example code import spikeinterface.qualitymetrics as sqm - # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` + # Combine sorting and recording into a sorting_analyzer + # It is also recommended to run sorting_analyzer.compute(input="spike_amplitudes") # in order to use amplitudes from all spikes - fraction_missing = sqm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg") + fraction_missing = sqm.compute_amplitude_cutoffs(sorting_analyzer=sorting_analyzer, peak_sign="neg") # fraction_missing is a dict containing the unit IDs as keys, # and their estimated fraction of missing spikes as values. diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/qualitymetrics/amplitude_cv.rst index 81d3b4f12d..2ad51aab2a 100644 --- a/doc/modules/qualitymetrics/amplitude_cv.rst +++ b/doc/modules/qualitymetrics/amplitude_cv.rst @@ -34,10 +34,10 @@ Example code import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. - # It is required to run `compute_spike_amplitudes(wvf_extractor)` or - # `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN) - amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(waveform_extractor=wvf_extractor) + # Combine a sorting and recording into a sorting_analyzer + # It is required to run sorting_analyzer.compute(input="spike_amplitudes") or + # sorting_analyzer.compute(input="amplitude_scalings") (if missing, values will be NaN) + amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(sorting_analyzer=sorting_analyzer) # amplitude_cv_median and amplitude_cv_range are dicts containing the unit ids as keys, # and their amplitude_cv metrics as values. diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/qualitymetrics/amplitude_median.rst index c77a57b033..1e4eec2e40 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/qualitymetrics/amplitude_median.rst @@ -22,9 +22,9 @@ Example code import spikeinterface.qualitymetrics as sqm - # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` + # It is also recommended to run sorting_analyzer.compute(input="spike_amplitudes") # in order to use amplitude values from all spikes. - amplitude_medians = sqm.compute_amplitude_medians(waveform_extractor=wvf_extractor) + amplitude_medians = sqm.compute_amplitude_medians(sorting_analyzer) # amplitude_medians is a dict containing the unit IDs as keys, # and their estimated amplitude medians as values. diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/qualitymetrics/drift.rst index dad2aafe7c..8f95f74695 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/qualitymetrics/drift.rst @@ -42,10 +42,10 @@ Example code import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. - # It is required to run `compute_spike_locations(wvf_extractor) first` + # Combine sorting and recording into sorting_analyzer + # It is required to run sorting_analyzer.compute(input="spike_locations") first # (if missing, values will be NaN) - drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(waveform_extractor=wvf_extractor, peak_sign="neg") + drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(sorting_analyzer=sorting_analyzer peak_sign="neg") # drift_ptps, drift_stds, and drift_mads are each a dict containing the unit IDs as keys, # and their metrics as values. diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/qualitymetrics/firing_range.rst index 1cbd903c7a..d059f4eac6 100644 --- a/doc/modules/qualitymetrics/firing_range.rst +++ b/doc/modules/qualitymetrics/firing_range.rst @@ -23,8 +23,8 @@ Example code import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. - firing_range = sqm.compute_firing_ranges(waveform_extractor=wvf_extractor) + # Combine a sorting and recording into a sorting_analyzer + firing_range = sqm.compute_firing_ranges(sorting_analyzer=sorting_analyzer) # firing_range is a dict containing the unit IDs as keys, # and their firing firing_range as values (in Hz). diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/qualitymetrics/firing_rate.rst index ef8cb3d8f4..953901dd38 100644 --- a/doc/modules/qualitymetrics/firing_rate.rst +++ b/doc/modules/qualitymetrics/firing_rate.rst @@ -39,8 +39,8 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. - firing_rate = sqm.compute_firing_rates(waveform_extractor=wvf_extractor) + # Combine a sorting and recording into a sorting_analyzer + firing_rate = sqm.compute_firing_rates(sorting_analyzer) # firing_rate is a dict containing the unit IDs as keys, # and their firing rates across segments as values (in Hz). diff --git a/doc/modules/qualitymetrics/isi_violations.rst b/doc/modules/qualitymetrics/isi_violations.rst index 725d9b0fd6..e30a2334d5 100644 --- a/doc/modules/qualitymetrics/isi_violations.rst +++ b/doc/modules/qualitymetrics/isi_violations.rst @@ -79,9 +79,9 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. + # Combine sorting and recording into sorting_analyzer - isi_violations_ratio, isi_violations_count = sqm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0) + isi_violations_ratio, isi_violations_count = sqm.compute_isi_violations(sorting_analyzer=sorting_analyzer, isi_threshold_ms=1.0) References ---------- diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/qualitymetrics/presence_ratio.rst index ad0766d37c..e925c6e325 100644 --- a/doc/modules/qualitymetrics/presence_ratio.rst +++ b/doc/modules/qualitymetrics/presence_ratio.rst @@ -25,9 +25,9 @@ Example code import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. + # Combine sorting and recording into a sorting_analyzer - presence_ratio = sqm.compute_presence_ratios(waveform_extractor=wvf_extractor) + presence_ratio = sqm.compute_presence_ratios(sorting_analyzer=sorting_analyzer) # presence_ratio is a dict containing the unit IDs as keys # and their presence ratio (between 0 and 1) as values. diff --git a/doc/modules/qualitymetrics/sd_ratio.rst b/doc/modules/qualitymetrics/sd_ratio.rst index 0ee3a3fa12..260a2ec38e 100644 --- a/doc/modules/qualitymetrics/sd_ratio.rst +++ b/doc/modules/qualitymetrics/sd_ratio.rst @@ -28,7 +28,8 @@ Example code import spikeinterface.qualitymetrics as sqm - sd_ratio = sqm.compute_sd_ratio(wvf_extractor, censored_period_ms=4.0) + # In this case we need to combine our sorting and recording into a sorting_analyzer + sd_ratio = sqm.compute_sd_ratio(sorting_analyzer=sorting_analyzer censored_period_ms=4.0) References diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/qualitymetrics/sliding_rp_violations.rst index fd53d7da3b..1913062cd9 100644 --- a/doc/modules/qualitymetrics/sliding_rp_violations.rst +++ b/doc/modules/qualitymetrics/sliding_rp_violations.rst @@ -29,9 +29,9 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. + # Combine sorting and recording into a sorting_analyzer - contamination = sqm.compute_sliding_rp_violations(waveform_extractor=wvf_extractor, bin_size_ms=0.25) + contamination = sqm.compute_sliding_rp_violations(sorting_analyzer=sorting_analyzer, bin_size_ms=0.25) References ---------- diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/qualitymetrics/snr.rst index 7f27a5078a..e640ec026f 100644 --- a/doc/modules/qualitymetrics/snr.rst +++ b/doc/modules/qualitymetrics/snr.rst @@ -43,8 +43,8 @@ With SpikeInterface: import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. - SNRs = sqm.compute_snrs(waveform_extractor=wvf_extractor) + # Combining sorting and recording into a sorting_analzyer + SNRs = sqm.compute_snrs(sorting_analzyer=sorting_analzyer) # SNRs is a dict containing the unit IDs as keys and their SNRs as values. Links to original implementations diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index d1a3c70a97..41c92dd99e 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -28,8 +28,8 @@ Example code .. code-block:: python import spikeinterface.qualitymetrics as sqm - # Make recording, sorting and wvf_extractor object for your data. - synchrony = sqm.compute_synchrony_metrics(waveform_extractor=wvf_extractor, synchrony_sizes=(2, 4, 8)) + # Combine a sorting and recording into a sorting_analyzer + synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8)) # synchrony is a tuple of dicts with the synchrony metrics for each unit From f053eb08c6795a1efda2290b33a2303d3b68e5e8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 20 Feb 2024 12:15:01 +0100 Subject: [PATCH 108/192] various fix --- src/spikeinterface/comparison/groundtruthstudy.py | 6 ++++-- src/spikeinterface/postprocessing/isi.py | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index fef06f4fe4..112e859dad 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -299,9 +299,11 @@ def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): # the waveforms depend on the dataset key folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] - sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binray_folder", folder=folder) + sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) sorting_analyzer.select_random_spikes(**select_params) - sorting_analyzer.compute("fast_templates", **job_kwargs) + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("noise_levels") def get_waveform_extractor(self, case_key=None, dataset_key=None): if case_key is not None: diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index c2cf9cfd7d..2b39a376e0 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -103,7 +103,7 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs + bins = np.arange(0, window_size + bin_size, bin_size)# * 1e3 / fs ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) # TODO: There might be a better way than a double for loop? @@ -113,7 +113,7 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float ISI = np.histogram(np.diff(spike_train), bins=bins)[0] ISIs[i] += ISI - return ISIs, bins + return ISIs, bins * 1e3 / fs def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float = 1.0): @@ -137,7 +137,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs + bins = np.arange(0, window_size + bin_size, bin_size)# * 1e3 / fs spikes = sorting.to_spike_vector(concatenated=False) ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) @@ -153,13 +153,13 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float bins, ) - return ISIs, bins + return ISIs, bins * 1e3 / fs if HAVE_NUMBA: @numba.jit( - (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.float64[::1]), + (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.int64[::1]), nopython=True, nogil=True, cache=True, From f55a03b4c15af91e372c0e4a3b9115adb55845fa Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 20 Feb 2024 15:11:13 +0100 Subject: [PATCH 109/192] Initial benchmark components refactor --- .../benchmark/benchmark_tools.py | 246 +++++++++++- .../tests/test_benchmark_matching.py | 358 +++++++++--------- 2 files changed, 423 insertions(+), 181 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 2697556290..5fda48073f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -5,9 +5,251 @@ import json import numpy as np +import os -from spikeinterface.core import load_waveforms, NpzSortingExtractor from spikeinterface.core.core_tools import check_json +from spikeinterface import load_extractor, split_job_kwargs, create_sorting_analyzer, load_sorting_analyzer + +import pickle + +_key_separator = "_-°°-_" + +class BenchmarkStudy: + """ + Manage a list of Benchmark + """ + benchmark_class = None + def __init__(self, study_folder): + self.folder = Path(study_folder) + self.datasets = {} + self.cases = {} + self.results = {} + self.scan_folder() + + @classmethod + def create(cls, study_folder, datasets={}, cases={}, levels=None): + # check that cases keys are homogeneous + key0 = list(cases.keys())[0] + if isinstance(key0, str): + assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(key0, tuple): + assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" + num_levels = len(key0) + assert all( + len(key) == num_levels for key in cases.keys() + ), "Keys for cases are not homogeneous, tuple negth differ" + if levels is None: + levels = [f"level{i}" for i in range(num_levels)] + else: + levels = list(levels) + assert len(levels) == num_levels + else: + raise ValueError("Keys for cases must str or tuple") + + study_folder = Path(study_folder) + study_folder.mkdir(exist_ok=False, parents=True) + + (study_folder / "datasets").mkdir() + (study_folder / "datasets" / "recordings").mkdir() + (study_folder / "datasets" / "gt_sortings").mkdir() + (study_folder / "run_logs").mkdir() + (study_folder / "metrics").mkdir() + (study_folder / "results").mkdir() + + for key, (rec, gt_sorting) in datasets.items(): + assert "/" not in key, "'/' cannot be in the key name!" + assert "\\" not in key, "'\\' cannot be in the key name!" + + # recordings are pickled + rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") + + # sortings are pickled + saved as NumpyFolderSorting + gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") + gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") + + info = {} + info["levels"] = levels + (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") + + # cases is dumped to a pickle file, json is not possible because of the tuple key + (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) + + return cls(study_folder) + + def scan_folder(self): + if not (self.folder / "datasets").exists(): + raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") + + with open(self.folder / "info.json", "r") as f: + self.info = json.load(f) + + self.levels = self.info["levels"] + + for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): + key = rec_file.stem + rec = load_extractor(rec_file) + gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) + self.datasets[key] = (rec, gt_sorting) + + with open(self.folder / "cases.pickle", "rb") as f: + self.cases = pickle.load(f) + + self.results = {} + for key in self.cases: + result_folder = self.folder / "results" / self.key_to_str(key) + if result_folder.exists(): + self.results[key] = self.benchmark_class.load_folder(result_folder) + else: + self.results[key] = None + + def __repr__(self): + t = f"{self.__class__.__name__} {self.folder.stem} \n" + t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" + t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" + num_computed = sum([1 for result in self.results.values() if result is not None]) + t += f" computed: {num_computed}\n" + return t + + def key_to_str(self, key): + if isinstance(key, str): + return key + elif isinstance(key, tuple): + return _key_separator.join(key) + else: + raise ValueError("Keys for cases must str or tuple") + + def remove_result(self, key): + result_folder = self.folder / "results" / self.key_to_str(key) + log_file = self.folder / "run_logs" / f"{self.key_to_str(key)}.json" + + if result_folder.exists(): + shutil.rmtree(result_folder) + for f in (log_file, ): + if f.exists(): + f.unlink() + self.results[key] = None + + def run(self, case_keys=None, keep=True, verbose=False): + if case_keys is None: + case_keys = self.cases.keys() + + job_keys = [] + for key in case_keys: + + result_folder = self.folder / "results" / self.key_to_str(key) + + if keep and result_folder.exists(): + continue + elif not keep and result_folder.exists(): + self.remove_result(key) + job_keys.append(key) + + self._run(job_keys) + + # save log + # TODO + + def _run(self, job_keys): + raise NotImplemented + + + def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): + if case_keys is None: + case_keys = self.cases.keys() + + select_params, job_kwargs = split_job_kwargs(kwargs) + + base_folder = self.folder / "sorting_analyzer" + base_folder.mkdir(exist_ok=True) + + dataset_keys = [self.cases[key]["dataset"] for key in case_keys] + dataset_keys = set(dataset_keys) + for dataset_key in dataset_keys: + # the waveforms depend on the dataset key + folder = base_folder / self.key_to_str(dataset_key) + recording, gt_sorting = self.datasets[dataset_key] + sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) + sorting_analyzer.select_random_spikes(**select_params) + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("noise_levels") + + def get_sorting_analyzer(self, case_key=None, dataset_key=None): + if case_key is not None: + dataset_key = self.cases[case_key]["dataset"] + + folder = self.folder / "sorting_analyzer" / self.key_to_str(dataset_key) + sorting_analyzer = load_sorting_analyzer(folder) + return sorting_analyzer + + def get_templates(self, key, operator="average"): + sorting_analyzer = self.get_sorting_analyzer(case_key=key) + templates = sorting_analyzer.get_extenson("templates").get_data(operator=operator) + return templates + + def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): + if case_keys is None: + case_keys = self.cases.keys() + + done = [] + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + if dataset_key in done: + # some case can share the same waveform extractor + continue + done.append(dataset_key) + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if filename.exists(): + if force: + os.remove(filename) + else: + continue + sorting_analyzer = self.get_sorting_analyzer(key) + qm_ext = sorting_analyzer.compute("quality_metrics", metric_names=metric_names) + metrics = qm_ext.get_data() + metrics.to_csv(filename, sep="\t", index=True) + + def get_metrics(self, key): + import pandas as pd + + dataset_key = self.cases[key]["dataset"] + + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if not filename.exists(): + return + metrics = pd.read_csv(filename, sep="\t", index_col=0) + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + metrics.index = gt_sorting.unit_ids + return metrics + + def get_units_snr(self, key): + """ """ + return self.get_metrics(key)["snr"] + + +class Benchmark: + """ + """ + @classmethod + def load_folder(cls): + raise NotImplementedError + + def save_to_folder(self, folder): + raise NotImplementedError + + def run(self): + # run method and metrics!! + raise NotImplementedError + + + + + def _simpleaxis(ax): @@ -17,7 +259,7 @@ def _simpleaxis(ax): ax.get_yaxis().tick_left() -class BenchmarkBase: +class BenchmarkBaseOld: _array_names = () _waveform_names = () _sorting_names = () diff --git a/src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py index ad944c921c..6e7f1c0d8a 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py @@ -1,179 +1,179 @@ -import pytest -import numpy as np -import pandas as pd -import shutil -import os -from pathlib import Path - -import spikeinterface.core as sc -import spikeinterface.extractors as se -import spikeinterface.preprocessing as spre -from spikeinterface.sortingcomponents.benchmark import benchmark_matching - - -@pytest.fixture(scope="session") -def benchmark_and_kwargs(tmp_path_factory): - recording, sorting = se.toy_example(duration=1, num_channels=2, num_units=2, num_segments=1, firing_rate=10, seed=0) - recording = spre.common_reference(recording, dtype="float32") - we_path = tmp_path_factory.mktemp("waveforms") - sort_path = tmp_path_factory.mktemp("sortings") / ("sorting.npz") - se.NpzSortingExtractor.write_sorting(sorting, sort_path) - sorting = se.NpzSortingExtractor(sort_path) - we = sc.extract_waveforms(recording, sorting, we_path, overwrite=True) - templates = we.get_all_templates() - noise_levels = sc.get_noise_levels(recording, return_scaled=False) - methods_kwargs = { - "tridesclous": dict(waveform_extractor=we, noise_levels=noise_levels), - "wobble": dict(templates=templates, nbefore=we.nbefore, nafter=we.nafter, parameters={"approx_rank": 2}), - } - methods = list(methods_kwargs.keys()) - benchmark = benchmark_matching.BenchmarkMatching(recording, sorting, we, methods, methods_kwargs) - return benchmark, methods_kwargs - - -@pytest.mark.parametrize( - "parameters, parameter_name", - [ - ([1, 10, 100], "num_spikes"), - ([0, 0.5, 1], "fraction_misclassed"), - ([0, 0.5, 1], "fraction_missing"), - ], -) -def test_run_matching_vary_parameter(benchmark_and_kwargs, parameters, parameter_name): - # Arrange - benchmark, methods_kwargs = benchmark_and_kwargs - num_replicates = 2 - - # Act - with benchmark as bmk: - matching_df = bmk.run_matching_vary_parameter(parameters, parameter_name, num_replicates=num_replicates) - - # Assert - assert matching_df.shape[0] == len(parameters) * num_replicates * len(methods_kwargs) - assert matching_df.shape[1] == 6 - - -@pytest.mark.parametrize( - "parameter_name, num_replicates", - [ - ("invalid_parameter_name", 1), - ("num_spikes", -1), - ("num_spikes", 0.5), - ], -) -def test_run_matching_vary_parameter_invalid_inputs(benchmark_and_kwargs, parameter_name, num_replicates): - parameters = [1, 2] - benchmark, methods_kwargs = benchmark_and_kwargs - with benchmark as bmk: - with pytest.raises(ValueError): - bmk.run_matching_vary_parameter(parameters, parameter_name, num_replicates=num_replicates) - - -@pytest.mark.parametrize( - "fraction_misclassed, min_similarity", - [ - (-1, -1), - (2, -1), - (0, 2), - ], -) -def test_run_matching_misclassed_invalid_inputs(benchmark_and_kwargs, fraction_misclassed, min_similarity): - benchmark, methods_kwargs = benchmark_and_kwargs - with benchmark as bmk: - with pytest.raises(ValueError): - bmk.run_matching_misclassed(fraction_misclassed, min_similarity=min_similarity) - - -@pytest.mark.parametrize( - "fraction_missing, snr_threshold", - [ - (-1, 0), - (2, 0), - (0, -1), - ], -) -def test_run_matching_missing_units_invalid_inputs(benchmark_and_kwargs, fraction_missing, snr_threshold): - benchmark, methods_kwargs = benchmark_and_kwargs - with benchmark as bmk: - with pytest.raises(ValueError): - bmk.run_matching_missing_units(fraction_missing, snr_threshold=snr_threshold) - - -def test_compare_all_sortings(benchmark_and_kwargs): - # Arrange - benchmark, methods_kwargs = benchmark_and_kwargs - parameter_name = "num_spikes" - num_replicates = 2 - num_spikes = [1, 10, 100] - rng = np.random.default_rng(0) - sortings, gt_sortings, parameter_values, parameter_names, iter_nums, methods = [], [], [], [], [], [] - for replicate in range(num_replicates): - for spike_num in num_spikes: - for method in list(methods_kwargs.keys()): - len_spike_train = 100 - spike_time_inds = rng.choice(benchmark.recording.get_num_frames(), len_spike_train, replace=False) - unit_ids = rng.choice(benchmark.gt_sorting.get_unit_ids(), len_spike_train, replace=True) - sort_index = np.argsort(spike_time_inds) - spike_time_inds = spike_time_inds[sort_index] - unit_ids = unit_ids[sort_index] - sorting = sc.NumpySorting.from_times_labels( - spike_time_inds, unit_ids, benchmark.recording.sampling_frequency - ) - spike_time_inds = rng.choice(benchmark.recording.get_num_frames(), len_spike_train, replace=False) - unit_ids = rng.choice(benchmark.gt_sorting.get_unit_ids(), len_spike_train, replace=True) - sort_index = np.argsort(spike_time_inds) - spike_time_inds = spike_time_inds[sort_index] - unit_ids = unit_ids[sort_index] - gt_sorting = sc.NumpySorting.from_times_labels( - spike_time_inds, unit_ids, benchmark.recording.sampling_frequency - ) - sortings.append(sorting) - gt_sortings.append(gt_sorting) - parameter_values.append(spike_num) - parameter_names.append(parameter_name) - iter_nums.append(replicate) - methods.append(method) - matching_df = pd.DataFrame( - { - "sorting": sortings, - "gt_sorting": gt_sortings, - "parameter_value": parameter_values, - "parameter_name": parameter_names, - "iter_num": iter_nums, - "method": methods, - } - ) - comparison_from_df = matching_df.copy() - comparison_from_self = matching_df.copy() - comparison_collision = matching_df.copy() - - # Act - benchmark.compare_all_sortings(comparison_from_df, ground_truth="from_df") - benchmark.compare_all_sortings(comparison_from_self, ground_truth="from_self") - benchmark.compare_all_sortings(comparison_collision, collision=True) - - # Assert - for comparison in [comparison_from_df, comparison_from_self, comparison_collision]: - assert comparison.shape[0] == len(num_spikes) * num_replicates * len(methods_kwargs) - assert comparison.shape[1] == 7 - for comp, sorting in zip(comparison["comparison"], comparison["sorting"]): - comp.sorting2 == sorting - for comp, gt_sorting in zip(comparison_from_df["comparison"], comparison["gt_sorting"]): - comp.sorting1 == gt_sorting - for comp in comparison_from_self["comparison"]: - comp.sorting1 == benchmark.gt_sorting - - -def test_compare_all_sortings_invalid_inputs(benchmark_and_kwargs): - benchmark, methods_kwargs = benchmark_and_kwargs - with pytest.raises(ValueError): - benchmark.compare_all_sortings(pd.DataFrame(), ground_truth="invalid") - - -if __name__ == "__main__": - test_run_matching_vary_parameter(benchmark_and_kwargs) - test_run_matching_vary_parameter_invalid_inputs(benchmark_and_kwargs) - test_run_matching_misclassed_invalid_inputs(benchmark_and_kwargs) - test_run_matching_missing_units_invalid_inputs(benchmark_and_kwargs) - test_compare_all_sortings(benchmark_and_kwargs) - test_compare_all_sortings_invalid_inputs(benchmark_and_kwargs) +# import pytest +# import numpy as np +# import pandas as pd +# import shutil +# import os +# from pathlib import Path + +# import spikeinterface.core as sc +# import spikeinterface.extractors as se +# import spikeinterface.preprocessing as spre +# from spikeinterface.sortingcomponents.benchmark import benchmark_matching + + +# @pytest.fixture(scope="session") +# def benchmark_and_kwargs(tmp_path_factory): +# recording, sorting = se.toy_example(duration=1, num_channels=2, num_units=2, num_segments=1, firing_rate=10, seed=0) +# recording = spre.common_reference(recording, dtype="float32") +# we_path = tmp_path_factory.mktemp("waveforms") +# sort_path = tmp_path_factory.mktemp("sortings") / ("sorting.npz") +# se.NpzSortingExtractor.write_sorting(sorting, sort_path) +# sorting = se.NpzSortingExtractor(sort_path) +# we = sc.extract_waveforms(recording, sorting, we_path, overwrite=True) +# templates = we.get_all_templates() +# noise_levels = sc.get_noise_levels(recording, return_scaled=False) +# methods_kwargs = { +# "tridesclous": dict(waveform_extractor=we, noise_levels=noise_levels), +# "wobble": dict(templates=templates, nbefore=we.nbefore, nafter=we.nafter, parameters={"approx_rank": 2}), +# } +# methods = list(methods_kwargs.keys()) +# benchmark = benchmark_matching.BenchmarkMatching(recording, sorting, we, methods, methods_kwargs) +# return benchmark, methods_kwargs + + +# @pytest.mark.parametrize( +# "parameters, parameter_name", +# [ +# ([1, 10, 100], "num_spikes"), +# ([0, 0.5, 1], "fraction_misclassed"), +# ([0, 0.5, 1], "fraction_missing"), +# ], +# ) +# def test_run_matching_vary_parameter(benchmark_and_kwargs, parameters, parameter_name): +# # Arrange +# benchmark, methods_kwargs = benchmark_and_kwargs +# num_replicates = 2 + +# # Act +# with benchmark as bmk: +# matching_df = bmk.run_matching_vary_parameter(parameters, parameter_name, num_replicates=num_replicates) + +# # Assert +# assert matching_df.shape[0] == len(parameters) * num_replicates * len(methods_kwargs) +# assert matching_df.shape[1] == 6 + + +# @pytest.mark.parametrize( +# "parameter_name, num_replicates", +# [ +# ("invalid_parameter_name", 1), +# ("num_spikes", -1), +# ("num_spikes", 0.5), +# ], +# ) +# def test_run_matching_vary_parameter_invalid_inputs(benchmark_and_kwargs, parameter_name, num_replicates): +# parameters = [1, 2] +# benchmark, methods_kwargs = benchmark_and_kwargs +# with benchmark as bmk: +# with pytest.raises(ValueError): +# bmk.run_matching_vary_parameter(parameters, parameter_name, num_replicates=num_replicates) + + +# @pytest.mark.parametrize( +# "fraction_misclassed, min_similarity", +# [ +# (-1, -1), +# (2, -1), +# (0, 2), +# ], +# ) +# def test_run_matching_misclassed_invalid_inputs(benchmark_and_kwargs, fraction_misclassed, min_similarity): +# benchmark, methods_kwargs = benchmark_and_kwargs +# with benchmark as bmk: +# with pytest.raises(ValueError): +# bmk.run_matching_misclassed(fraction_misclassed, min_similarity=min_similarity) + + +# @pytest.mark.parametrize( +# "fraction_missing, snr_threshold", +# [ +# (-1, 0), +# (2, 0), +# (0, -1), +# ], +# ) +# def test_run_matching_missing_units_invalid_inputs(benchmark_and_kwargs, fraction_missing, snr_threshold): +# benchmark, methods_kwargs = benchmark_and_kwargs +# with benchmark as bmk: +# with pytest.raises(ValueError): +# bmk.run_matching_missing_units(fraction_missing, snr_threshold=snr_threshold) + + +# def test_compare_all_sortings(benchmark_and_kwargs): +# # Arrange +# benchmark, methods_kwargs = benchmark_and_kwargs +# parameter_name = "num_spikes" +# num_replicates = 2 +# num_spikes = [1, 10, 100] +# rng = np.random.default_rng(0) +# sortings, gt_sortings, parameter_values, parameter_names, iter_nums, methods = [], [], [], [], [], [] +# for replicate in range(num_replicates): +# for spike_num in num_spikes: +# for method in list(methods_kwargs.keys()): +# len_spike_train = 100 +# spike_time_inds = rng.choice(benchmark.recording.get_num_frames(), len_spike_train, replace=False) +# unit_ids = rng.choice(benchmark.gt_sorting.get_unit_ids(), len_spike_train, replace=True) +# sort_index = np.argsort(spike_time_inds) +# spike_time_inds = spike_time_inds[sort_index] +# unit_ids = unit_ids[sort_index] +# sorting = sc.NumpySorting.from_times_labels( +# spike_time_inds, unit_ids, benchmark.recording.sampling_frequency +# ) +# spike_time_inds = rng.choice(benchmark.recording.get_num_frames(), len_spike_train, replace=False) +# unit_ids = rng.choice(benchmark.gt_sorting.get_unit_ids(), len_spike_train, replace=True) +# sort_index = np.argsort(spike_time_inds) +# spike_time_inds = spike_time_inds[sort_index] +# unit_ids = unit_ids[sort_index] +# gt_sorting = sc.NumpySorting.from_times_labels( +# spike_time_inds, unit_ids, benchmark.recording.sampling_frequency +# ) +# sortings.append(sorting) +# gt_sortings.append(gt_sorting) +# parameter_values.append(spike_num) +# parameter_names.append(parameter_name) +# iter_nums.append(replicate) +# methods.append(method) +# matching_df = pd.DataFrame( +# { +# "sorting": sortings, +# "gt_sorting": gt_sortings, +# "parameter_value": parameter_values, +# "parameter_name": parameter_names, +# "iter_num": iter_nums, +# "method": methods, +# } +# ) +# comparison_from_df = matching_df.copy() +# comparison_from_self = matching_df.copy() +# comparison_collision = matching_df.copy() + +# # Act +# benchmark.compare_all_sortings(comparison_from_df, ground_truth="from_df") +# benchmark.compare_all_sortings(comparison_from_self, ground_truth="from_self") +# benchmark.compare_all_sortings(comparison_collision, collision=True) + +# # Assert +# for comparison in [comparison_from_df, comparison_from_self, comparison_collision]: +# assert comparison.shape[0] == len(num_spikes) * num_replicates * len(methods_kwargs) +# assert comparison.shape[1] == 7 +# for comp, sorting in zip(comparison["comparison"], comparison["sorting"]): +# comp.sorting2 == sorting +# for comp, gt_sorting in zip(comparison_from_df["comparison"], comparison["gt_sorting"]): +# comp.sorting1 == gt_sorting +# for comp in comparison_from_self["comparison"]: +# comp.sorting1 == benchmark.gt_sorting + + +# def test_compare_all_sortings_invalid_inputs(benchmark_and_kwargs): +# benchmark, methods_kwargs = benchmark_and_kwargs +# with pytest.raises(ValueError): +# benchmark.compare_all_sortings(pd.DataFrame(), ground_truth="invalid") + + +# if __name__ == "__main__": +# test_run_matching_vary_parameter(benchmark_and_kwargs) +# test_run_matching_vary_parameter_invalid_inputs(benchmark_and_kwargs) +# test_run_matching_misclassed_invalid_inputs(benchmark_and_kwargs) +# test_run_matching_missing_units_invalid_inputs(benchmark_and_kwargs) +# test_compare_all_sortings(benchmark_and_kwargs) +# test_compare_all_sortings_invalid_inputs(benchmark_and_kwargs) From 85c2fbd332c156ac1f7cfb9b5b83b690c492e165 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 20 Feb 2024 16:53:06 +0100 Subject: [PATCH 110/192] WIP --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/sorting_tools.py | 3 +- .../benchmark/benchmark_matching.py | 109 ++++++++++++++++-- .../benchmark/benchmark_tools.py | 4 +- 4 files changed, 105 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index d1f67412ec..f5a712f247 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -101,7 +101,7 @@ get_chunk_with_margin, order_channels_by_depth, ) -from .sorting_tools import spike_vector_to_spike_trains +from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_average from .snippets_tools import snippets_from_sorting diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 2b4af70ebf..99c6a6e75a 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -140,7 +140,7 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units): # TODO later : implement other method like "maximum_rate", "by_percent", ... def random_spikes_selection( sorting: BaseSorting, - num_samples: int, + num_samples: int | None = None, method: str = "uniform", max_spikes_per_unit: int = 500, margin_size: int | None = None, @@ -189,6 +189,7 @@ def random_spikes_selection( raise ValueError(f"random_spikes_selection wrong method {method}, currently only 'uniform' can be used.") if margin_size is not None: + assert num_samples is not None margin_size = int(margin_size) keep = np.ones(selected_unit_indices.size, dtype=bool) # left margin diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index caab8f0659..879dbd281f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -1,11 +1,11 @@ from __future__ import annotations -from spikeinterface.core import extract_waveforms from spikeinterface.preprocessing import bandpass_filter, common_reference from spikeinterface.postprocessing import compute_template_similarity from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import NumpySorting from spikeinterface.qualitymetrics import compute_quality_metrics +from spikeinterface import load_extractor from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth from spikeinterface.widgets import ( plot_agreement_matrix, @@ -15,6 +15,7 @@ import time import os +import pickle from pathlib import Path import string, random import pylab as plt @@ -24,16 +25,106 @@ import shutil import copy from tqdm.auto import tqdm +from .benchmark_tools import BenchmarkStudy, Benchmark +from spikeinterface.core.basesorting import minimum_spike_dtype -def running_in_notebook(): - try: - shell = get_ipython().__class__.__name__ - notebook_shells = {"ZMQInteractiveShell", "TerminalInteractiveShell"} - # if a shell is missing from this set just check get_ipython().__class__.__name__ and add it to the set - return shell in notebook_shells - except NameError: - return False +class MatchingBenchmark(Benchmark): + + def __init__(self, recording, gt_sorting, params): + self.recording = recording + self.gt_sorting = gt_sorting + self.method = params['method'] + self.templates = params["method_kwargs"]['templates'] + self.method_kwargs = params['method_kwargs'] + + def run(self, **job_kwargs): + spikes = find_spikes_from_templates( + self.recording, + method=self.method, + method_kwargs=self.method_kwargs, + **job_kwargs + ) + unit_ids = self.templates.unit_ids + sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) + sorting["sample_index"] = spikes["sample_index"] + sorting["unit_index"] = spikes["cluster_index"] + sorting["segment_index"] = spikes["segment_index"] + sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) + result = {'sorting' : sorting} + + ## Add metrics + + comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + result['gt_comparison'] = comp + + ## To do add collisions + #comparison = CollisionGTComparison(gt_sorting, sorting, exhaustive_gt=self.exhaustive_gt, **kwargs) + return result + + def save_to_folder(self, folder, result): + result['sorting'].save(folder = folder / "sorting", format="numpy_folder") + + comparison_file = folder / "gt_comparison.pickle" + with open(comparison_file, mode="wb") as f: + pickle.dump(result['gt_comparison'], f) + + @classmethod + def load_folder(cls, folder): + result = {} + result['sorting'] = load_extractor(folder / "sorting") + with open(folder / "gt_comparison.pickle", "rb") as f: + result['gt_comparison'] = pickle.load(f) + return result + +class MatchingStudy(BenchmarkStudy): + + benchmark_class = MatchingBenchmark + + def _run(self, keys, **job_kwargs): + for key in keys: + + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + benchmark = MatchingBenchmark(recording, gt_sorting, params) + result = benchmark.run() + self.results[key] = result + benchmark.save_to_folder(self.folder / "results" / self.key_to_str(key), result) + + + def plot_agreements(self, keys=None, figsize=(15,15)): + if keys is None: + keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=1, nrows=len(keys), figsize=figsize) + + for count, key in enumerate(keys): + ax = axs[count] + ax.set_title(self.cases[key]['label']) + plot_agreement_matrix(self.results[key]['gt_comparison'], ax=ax) + + def plot_performances_vs_snr(self, keys=None, figsize=(15,15)): + if keys is None: + keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + + for count, k in enumerate(("accuracy", "recall", "precision")): + + ax = axs[count] + for key in keys: + label = self.cases[key]["label"] + + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension('quality_metrics').get_data() + x = metrics["snr"].values + + y = self.results[key]['gt_comparison'].get_performance()[k].values + + ax.scatter(x, y, marker=".", label=label) + if count == 2: + ax.legend() class BenchmarkMatching: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 5fda48073f..f47aa1433c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -236,10 +236,10 @@ class Benchmark: """ """ @classmethod - def load_folder(cls): + def load_folder(cls, folder): raise NotImplementedError - def save_to_folder(self, folder): + def save_to_folder(self, folder, result): raise NotImplementedError def run(self): From 2a41cd138778b4551fce66769c926d596f19347e Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 20 Feb 2024 17:48:50 +0100 Subject: [PATCH 111/192] WIP --- .../benchmark/benchmark_matching.py | 835 ++++-------------- 1 file changed, 163 insertions(+), 672 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 879dbd281f..9e131879c3 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -57,14 +57,13 @@ def run(self, **job_kwargs): comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) result['gt_comparison'] = comp - - ## To do add collisions - #comparison = CollisionGTComparison(gt_sorting, sorting, exhaustive_gt=self.exhaustive_gt, **kwargs) + result['templates'] = self.templates + result['gt_collision'] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) return result def save_to_folder(self, folder, result): result['sorting'].save(folder = folder / "sorting", format="numpy_folder") - + result['templates'].to_zarr(folder / "templates") comparison_file = folder / "gt_comparison.pickle" with open(comparison_file, mode="wb") as f: pickle.dump(result['gt_comparison'], f) @@ -73,6 +72,7 @@ def save_to_folder(self, folder, result): def load_folder(cls, folder): result = {} result['sorting'] = load_extractor(folder / "sorting") + result['templates'] = Templates.from_zarr(folder / "templates") with open(folder / "gt_comparison.pickle", "rb") as f: result['gt_comparison'] = pickle.load(f) return result @@ -93,700 +93,191 @@ def _run(self, keys, **job_kwargs): benchmark.save_to_folder(self.folder / "results" / self.key_to_str(key), result) - def plot_agreements(self, keys=None, figsize=(15,15)): - if keys is None: - keys = list(self.cases.keys()) + def plot_agreements(self, case_keys=None, figsize=(15,15)): + if case_keys is None: + case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=1, nrows=len(keys), figsize=figsize) + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) - for count, key in enumerate(keys): + for count, key in enumerate(case_keys): ax = axs[count] ax.set_title(self.cases[key]['label']) plot_agreement_matrix(self.results[key]['gt_comparison'], ax=ax) - def plot_performances_vs_snr(self, keys=None, figsize=(15,15)): - if keys is None: - keys = list(self.cases.keys()) + def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): + if case_keys is None: + case_keys = list(self.cases.keys()) fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) for count, k in enumerate(("accuracy", "recall", "precision")): ax = axs[count] - for key in keys: + for key in case_keys: label = self.cases[key]["label"] analyzer = self.get_sorting_analyzer(key) metrics = analyzer.get_extension('quality_metrics').get_data() x = metrics["snr"].values - y = self.results[key]['gt_comparison'].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label) + if count == 2: ax.legend() + def plot_collisions(self, case_keys=None, figsize=(15,15)): + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) -class BenchmarkMatching: - """Benchmark a set of template matching methods on a given recording and ground-truth sorting.""" - - def __init__( - self, - recording, - gt_sorting, - waveform_extractor, - methods, - methods_kwargs=None, - exhaustive_gt=True, - tmp_folder=None, - template_mode="median", - **job_kwargs, - ): - self.methods = methods - if methods_kwargs is None: - methods_kwargs = {method: {} for method in methods} - self.methods_kwargs = methods_kwargs - self.recording = recording - self.gt_sorting = gt_sorting - self.job_kwargs = job_kwargs - self.exhaustive_gt = exhaustive_gt - self.sampling_rate = self.recording.get_sampling_frequency() - - if tmp_folder is None: - tmp_folder = os.path.join(".", "".join(random.choices(string.ascii_uppercase + string.digits, k=8))) - self.tmp_folder = Path(tmp_folder) - self.sort_folders = [] - - self.we = waveform_extractor - for method in self.methods: - self.methods_kwargs[method]["waveform_extractor"] = self.we - self.templates = self.we.get_all_templates(mode=template_mode) - self.metrics = compute_quality_metrics(self.we, metric_names=["snr"], load_if_exists=True) - self.similarity = compute_template_similarity(self.we) - self.parameter_name2matching_fn = dict( - num_spikes=self.run_matching_num_spikes, - fraction_misclassed=self.run_matching_misclassed, - fraction_missing=self.run_matching_missing_units, - ) - - def __enter__(self): - self.tmp_folder.mkdir(exist_ok=True) - self.sort_folders = [] - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.tmp_folder.exists(): - shutil.rmtree(self.tmp_folder) - for sort_folder in self.sort_folders: - if sort_folder.exists(): - shutil.rmtree(sort_folder) - - def run_matching(self, methods_kwargs, unit_ids): - """Run template matching on the recording with settings in methods_kwargs. - - Parameters - ---------- - methods_kwargs: dict - A dictionary of method_kwargs for each method. - unit_ids: array-like - The unit ids to use for the output sorting. - - Returns - ------- - sortings: dict - A dictionary that maps method --> NumpySorting. - runtimes: dict - A dictionary that maps method --> runtime. - """ - sortings, runtimes = {}, {} - for method in self.methods: - t0 = time.time() - spikes = find_spikes_from_templates( - self.recording, method=method, method_kwargs=methods_kwargs[method], **self.job_kwargs - ) - runtimes[method] = time.time() - t0 - sorting = NumpySorting.from_times_labels( - spikes["sample_index"], unit_ids[spikes["cluster_index"]], self.sampling_rate - ) - sortings[method] = sorting - return sortings, runtimes - - def run_matching_num_spikes(self, spike_num, seed=0, we_kwargs=None, template_mode="median"): - """Run template matching with a given number of spikes per unit. - - Parameters - ---------- - spike_num: int - The maximum number of spikes per unit - seed: int, default: 0 - Random seed - we_kwargs: dict - A dictionary of keyword arguments for the WaveformExtractor - template_mode: "mean" | "median" | "std", default: "median" - The mode to use to extract templates from the WaveformExtractor - - Returns - ------- - - sortings: dict - A dictionary that maps method --> NumpySorting. - gt_sorting: NumpySorting - The ground-truth sorting used for template matching (= self.gt_sorting). - """ - if we_kwargs is None: - we_kwargs = {} - we_kwargs.update( - dict(max_spikes_per_unit=spike_num, seed=seed, overwrite=True, load_if_exists=False, **self.job_kwargs) - ) - - # Generate New Waveform Extractor with New Spike Numbers - we = extract_waveforms(self.recording, self.gt_sorting, self.tmp_folder, **we_kwargs) - methods_kwargs = self.update_methods_kwargs(we, template_mode) - - sortings, _ = self.run_matching(methods_kwargs, we.unit_ids) - shutil.rmtree(self.tmp_folder) - return sortings, self.gt_sorting - - def update_methods_kwargs(self, we, template_mode="median"): - """Update the methods_kwargs dictionary with the new WaveformExtractor. - - Parameters - ---------- - we: WaveformExtractor - The new WaveformExtractor. - template_mode: "mean" | "median" | "std", default: "median" - The mode to use to extract templates from the WaveformExtractor - - Returns - ------- - methods_kwargs: dict - A dictionary of method_kwargs for each method. - """ - templates = we.get_all_templates(we.unit_ids, mode=template_mode) - methods_kwargs = copy.deepcopy(self.methods_kwargs) - for method in self.methods: - method_kwargs = methods_kwargs[method] - if method == "wobble": - method_kwargs.update(dict(templates=templates, nbefore=we.nbefore, nafter=we.nafter)) - else: - method_kwargs["waveform_extractor"] = we - return methods_kwargs - - def run_matching_misclassed( - self, fraction_misclassed, min_similarity=-1, seed=0, we_kwargs=None, template_mode="median" - ): - """Run template matching with a given fraction of misclassified spikes. - - Parameters - ---------- - fraction_misclassed: float - The fraction of misclassified spikes. - min_similarity: float, default: -1 - The minimum cosine similarity between templates to be considered similar - seed: int, default: 0 - Random seed - we_kwargs: dict - A dictionary of keyword arguments for the WaveformExtractor - template_mode: "mean" | "median" | "std", default: "median" - The mode to use to extract templates from the WaveformExtractor - - Returns - ------- - sortings: dict - A dictionary that maps method --> NumpySorting. - gt_sorting: NumpySorting - The ground-truth sorting used for template matching (with misclassified spike trains). - """ - try: - assert 0 <= fraction_misclassed <= 1 - except AssertionError: - raise ValueError("'fraction_misclassed' must be between 0 and 1.") - try: - assert -1 <= min_similarity <= 1 - except AssertionError: - raise ValueError("'min_similarity' must be between -1 and 1.") - if we_kwargs is None: - we_kwargs = {} - we_kwargs.update(dict(seed=seed, overwrite=True, load_if_exists=False, **self.job_kwargs)) - rng = np.random.default_rng(seed) - - # Randomly misclass spike trains - spike_time_indices, labels = [], [] - for unit_index, unit_id in enumerate(self.we.unit_ids): - unit_spike_train = self.gt_sorting.get_unit_spike_train(unit_id=unit_id) - unit_similarity = self.similarity[unit_index, :] - unit_similarity[unit_index] = min_similarity - 1 # skip self - similar_unit_ids = self.we.unit_ids[unit_similarity >= min_similarity] - at_least_one_similar_unit = len(similar_unit_ids) - num_spikes = int(len(unit_spike_train) * fraction_misclassed) - unit_misclass_idx = rng.choice(np.arange(len(unit_spike_train)), size=num_spikes, replace=False) - unit_labels = np.repeat(unit_id, len(unit_spike_train)) - if at_least_one_similar_unit: - unit_labels[unit_misclass_idx] = rng.choice(similar_unit_ids, size=num_spikes) - spike_time_indices.extend(list(unit_spike_train)) - labels.extend(list(unit_labels)) - spike_time_indices = np.array(spike_time_indices) - labels = np.array(labels) - sort_idx = np.argsort(spike_time_indices) - spike_time_indices = spike_time_indices[sort_idx] - labels = labels[sort_idx] - gt_sorting = NumpySorting.from_times_labels( - spike_time_indices, labels, self.sampling_rate, unit_ids=self.we.unit_ids - ) - sort_folder = Path(self.tmp_folder.stem + f"_sorting{len(self.sort_folders)}") - gt_sorting = gt_sorting.save(folder=sort_folder) - self.sort_folders.append(sort_folder) - - # Generate New Waveform Extractor with Misclassed Spike Trains - we = extract_waveforms(self.recording, gt_sorting, self.tmp_folder, **we_kwargs) - methods_kwargs = self.update_methods_kwargs(we, template_mode) - - sortings, _ = self.run_matching(methods_kwargs, we.unit_ids) - shutil.rmtree(self.tmp_folder) - return sortings, gt_sorting - - def run_matching_missing_units( - self, fraction_missing, snr_threshold=0, seed=0, we_kwargs=None, template_mode="median" - ): - """Run template matching with a given fraction of missing units. - - Parameters - ---------- - fraction_missing: float - The fraction of missing units. - snr_threshold: float, default: 0 - The SNR threshold below which units are considered missing - seed: int, default: 0 - Random seed - we_kwargs: dict - A dictionary of keyword arguments for the WaveformExtractor. - template_mode: "mean" | "median" | "std", default: "median" - The mode to use to extract templates from the WaveformExtractor - - Returns - ------- - sortings: dict - A dictionary that maps method --> NumpySorting. - gt_sorting: NumpySorting - The ground-truth sorting used for template matching (with missing units). - """ - try: - assert 0 <= fraction_missing <= 1 - except AssertionError: - raise ValueError("'fraction_missing' must be between 0 and 1.") - try: - assert snr_threshold >= 0 - except AssertionError: - raise ValueError("'snr_threshold' must be greater than or equal to 0.") - if we_kwargs is None: - we_kwargs = {} - we_kwargs.update(dict(seed=seed, overwrite=True, load_if_exists=False, **self.job_kwargs)) - rng = np.random.default_rng(seed) - - # Omit fraction_missing of units with lowest SNR - metrics = self.metrics.sort_values("snr") - missing_units = np.array(metrics.index[metrics.snr < snr_threshold]) - num_missing = int(len(missing_units) * fraction_missing) - missing_units = rng.choice(missing_units, size=num_missing, replace=False) - present_units = np.setdiff1d(self.we.unit_ids, missing_units) - # spike_time_indices, spike_cluster_ids = [], [] - # for unit in present_units: - # spike_train = self.gt_sorting.get_unit_spike_train(unit) - # for time_index in spike_train: - # spike_time_indices.append(time_index) - # spike_cluster_ids.append(unit) - # spike_time_indices = np.array(spike_time_indices) - # spike_cluster_ids = np.array(spike_cluster_ids) - # gt_sorting = NumpySorting.from_times_labels(spike_time_indices, spike_cluster_ids, self.sampling_rate, - # unit_ids=present_units) - gt_sorting = self.gt_sorting.select_units(present_units) - sort_folder = Path(self.tmp_folder.stem + f"_sorting{len(self.sort_folders)}") - gt_sorting = gt_sorting.save(folder=sort_folder) - self.sort_folders.append(sort_folder) - - # Generate New Waveform Extractor with Missing Units - we = extract_waveforms(self.recording, gt_sorting, self.tmp_folder, **we_kwargs) - methods_kwargs = self.update_methods_kwargs(we, template_mode) - - sortings, _ = self.run_matching(methods_kwargs, we.unit_ids) - shutil.rmtree(self.tmp_folder) - return sortings, gt_sorting - - def run_matching_vary_parameter( - self, - parameters, - parameter_name, - num_replicates=1, - we_kwargs=None, - template_mode="median", - progress_bars=[], - **kwargs, - ): - """Run template matching varying the values of a given parameter. - - Parameters - ---------- - parameters: array-like - The values of the parameter to vary. - parameter_name: "num_spikes", "fraction_misclassed", "fraction_missing" - The name of the parameter to vary. - num_replicates: int, default: 1 - The number of replicates to run for each parameter value - we_kwargs: dict - A dictionary of keyword arguments for the WaveformExtractor - template_mode: "mean" | "median" | "std", default: "median" - The mode to use to extract templates from the WaveformExtractor - **kwargs - Keyword arguments for the run_matching method - - Returns - ------- - matching_df : pandas.DataFrame - A dataframe of NumpySortings for each method/parameter_value/iteration combination. - """ - try: - run_matching_fn = self.parameter_name2matching_fn[parameter_name] - except KeyError: - raise ValueError(f"Parameter name must be one of {list(self.parameter_name2matching_fn.keys())}") - try: - progress_bar = self.job_kwargs["progress_bar"] - except KeyError: - progress_bar = False - try: - assert isinstance(num_replicates, int) - assert num_replicates > 0 - except AssertionError: - raise ValueError("num_replicates must be a positive integer") - - sortings, gt_sortings, parameter_values, parameter_names, iter_nums, methods = [], [], [], [], [], [] - if progress_bar: - parameters = tqdm(parameters, desc=f"Vary Parameter ({parameter_name})") - for parameter in parameters: - if progress_bar and num_replicates > 1: - replicates = tqdm(range(1, num_replicates + 1), desc=f"Replicating for Variability") - else: - replicates = range(1, num_replicates + 1) - for i in replicates: - sorting_per_method, gt_sorting = run_matching_fn( - parameter, seed=i, we_kwargs=we_kwargs, template_mode=template_mode, **kwargs - ) - for method in self.methods: - sortings.append(sorting_per_method[method]) - gt_sortings.append(gt_sorting) - parameter_values.append(parameter) - parameter_names.append(parameter_name) - iter_nums.append(i) - methods.append(method) - if running_in_notebook(): - from IPython.display import clear_output - - clear_output(wait=True) - for bar in progress_bars: - display(bar.container) - display(parameters.container) - if num_replicates > 1: - display(replicates.container) - matching_df = pd.DataFrame( - { - "sorting": sortings, - "gt_sorting": gt_sortings, - "parameter_value": parameter_values, - "parameter_name": parameter_names, - "iter_num": iter_nums, - "method": methods, - } - ) - return matching_df - - def compare_sortings(self, gt_sorting, sorting, collision=False, **kwargs): - """Compare a sorting to a ground-truth sorting. - - Parameters - ---------- - gt_sorting: SortingExtractor - The ground-truth sorting extractor. - sorting: SortingExtractor - The sorting extractor to compare to the ground-truth. - collision: bool - If True, use the CollisionGTComparison class. If False, use the compare_sorter_to_ground_truth function. - **kwargs - Keyword arguments for the comparison function. - - Returns - ------- - comparison: GroundTruthComparison - The comparison object. - """ - if collision: - return CollisionGTComparison(gt_sorting, sorting, exhaustive_gt=self.exhaustive_gt, **kwargs) - else: - return compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=self.exhaustive_gt, **kwargs) - - def compare_all_sortings(self, matching_df, collision=False, ground_truth="from_self", **kwargs): - """Compare all sortings in a matching dataframe to their ground-truth sortings. - - Parameters - ---------- - matching_df: pandas.DataFrame - A dataframe of NumpySortings for each method/parameter_value/iteration combination. - collision: bool - If True, use the CollisionGTComparison class. If False, use the compare_sorter_to_ground_truth function. - ground_truth: "from_self" | "from_df", default: "from_self" - If "from_self", use the ground-truth sorting stored in the BenchmarkMatching object. If "from_df", use the - ground-truth sorting stored in the matching_df. - **kwargs - Keyword arguments for the comparison function. - - Notes - ----- - This function adds a new column to the matching_df called "comparison" that contains the GroundTruthComparison - object for each row. - """ - if ground_truth == "from_self": - comparison_fn = lambda row: self.compare_sortings( - self.gt_sorting, row["sorting"], collision=collision, **kwargs - ) - elif ground_truth == "from_df": - comparison_fn = lambda row: self.compare_sortings( - row["gt_sorting"], row["sorting"], collision=collision, **kwargs - ) - else: - raise ValueError("'ground_truth' must be either 'from_self' or 'from_df'") - matching_df["comparison"] = matching_df.apply(comparison_fn, axis=1) - - def plot(self, comp, title=None): - fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(10, 10)) - ax = axs[0, 0] - ax.set_title(title) - plot_agreement_matrix(comp, ax=ax) - ax.set_title(title) - - ax = axs[1, 0] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - for k in ("accuracy", "recall", "precision"): - x = comp.get_performance()[k] - y = self.metrics["snr"] - ax.scatter(x, y, markersize=10, marker=".", label=k) - ax.legend() - - ax = axs[0, 1] - if self.exhaustive_gt: + for count, key in enumerate(case_keys): + templates_array = self.results[key]['templates'].templates_array plot_comparison_collision_by_similarity( - comp, self.templates, ax=ax, show_legend=True, mode="lines", good_only=False + self.results[key]['gt_collision'], templates_array, ax=axs[count], + show_legend=True, mode="lines", good_only=False ) - return fig, axs - - -def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine"): - fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10)) - - benchmark.we.sorting.get_unit_spike_train(unit_id) - template = benchmark.we.get_template(unit_id) - a = template.reshape(template.size, 1).T - count = 0 - colors = ["r", "b"] - for label in ["TP", "FN"]: - seg_num = 0 # TODO: make compatible with multiple segments - idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) - idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.isin(idx_2, idx_1))[0] - intersection = np.random.permutation(intersection)[:nb_spikes] - if len(intersection) == 0: - print(f"No {label}s found for unit {unit_id}") - continue - ### Should be able to give a subset of waveforms only... - ax = axs[count, 0] - plot_unit_waveforms( - benchmark.we, - unit_ids=[unit_id], - axes=[ax], - unit_selected_waveforms={unit_id: intersection}, - unit_colors={unit_id: colors[count]}, - ) - ax.set_title(label) - - wfs = benchmark.we.get_waveforms(unit_id) - wfs = wfs[intersection, :, :] - - import sklearn - - nb_spikes = len(wfs) - b = wfs.reshape(nb_spikes, -1) - distances = sklearn.metrics.pairwise_distances(a, b, metric).flatten() - ax = axs[count, 1] - ax.set_title(label) - ax.hist(distances, color=colors[count]) - ax.set_ylabel("# waveforms") - ax.set_xlabel(metric) - - count += 1 - return fig, axs - -def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cosine"): - templates = benchmark.templates - nb_units = len(benchmark.we.unit_ids) - colors = ["r", "b"] - - results = {"TP": {"mean": [], "std": []}, "FN": {"mean": [], "std": []}} - - for i in range(nb_units): - unit_id = benchmark.we.unit_ids[i] - idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - wfs = benchmark.we.get_waveforms(unit_id) - template = benchmark.we.get_template(unit_id) - a = template.reshape(template.size, 1).T - - for label in ["TP", "FN"]: - idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.isin(idx_2, idx_1))[0] - intersection = np.random.permutation(intersection)[:nb_spikes] - wfs_sliced = wfs[intersection, :, :] - - import sklearn - - all_spikes = len(wfs_sliced) - if all_spikes > 0: - b = wfs_sliced.reshape(all_spikes, -1) - if metric == "cosine": - distances = sklearn.metrics.pairwise.cosine_similarity(a, b).flatten() - else: - distances = sklearn.metrics.pairwise_distances(a, b, metric).flatten() - results[label]["mean"] += [np.nanmean(distances)] - results[label]["std"] += [np.nanstd(distances)] - else: - results[label]["mean"] += [0] - results[label]["std"] += [0] - - fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(15, 5)) - for count, label in enumerate(["TP", "FN"]): - ax = axs[count] - idx = np.argsort(benchmark.metrics.snr) - means = np.array(results[label]["mean"])[idx] - stds = np.array(results[label]["std"])[idx] - ax.errorbar(benchmark.metrics.snr[idx], means, yerr=stds, c=colors[count]) - ax.set_title(label) - ax.set_xlabel("snr") - ax.set_ylabel(metric) - return fig, axs - - -def plot_comparison_matching( - benchmark, - comp_per_method, - performance_names=["accuracy", "recall", "precision"], - colors=["g", "b", "r"], - ylim=(-0.1, 1.1), -): - num_methods = len(benchmark.methods) - fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) - for i, method1 in enumerate(benchmark.methods): - for j, method2 in enumerate(benchmark.methods): - if len(axs.shape) > 1: - ax = axs[i, j] - else: - ax = axs[j] - comp1, comp2 = comp_per_method[method1], comp_per_method[method2] - if i <= j: - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - if j == i: - ax.set_ylabel(f"{method1}") + # def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine"): + # fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10)) + + # benchmark.we.sorting.get_unit_spike_train(unit_id) + # template = benchmark.we.get_template(unit_id) + # a = template.reshape(template.size, 1).T + # count = 0 + # colors = ["r", "b"] + # for label in ["TP", "FN"]: + # seg_num = 0 # TODO: make compatible with multiple segments + # idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) + # idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] + # intersection = np.where(np.isin(idx_2, idx_1))[0] + # intersection = np.random.permutation(intersection)[:nb_spikes] + # if len(intersection) == 0: + # print(f"No {label}s found for unit {unit_id}") + # continue + # ### Should be able to give a subset of waveforms only... + # ax = axs[count, 0] + # plot_unit_waveforms( + # benchmark.we, + # unit_ids=[unit_id], + # axes=[ax], + # unit_selected_waveforms={unit_id: intersection}, + # unit_colors={unit_id: colors[count]}, + # ) + # ax.set_title(label) + + # wfs = benchmark.we.get_waveforms(unit_id) + # wfs = wfs[intersection, :, :] + + # import sklearn + + # nb_spikes = len(wfs) + # b = wfs.reshape(nb_spikes, -1) + # distances = sklearn.metrics.pairwise_distances(a, b, metric).flatten() + # ax = axs[count, 1] + # ax.set_title(label) + # ax.hist(distances, color=colors[count]) + # ax.set_ylabel("# waveforms") + # ax.set_xlabel(metric) + + # count += 1 + # return fig, axs + + + # def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cosine"): + # templates = benchmark.templates + # nb_units = len(benchmark.we.unit_ids) + # colors = ["r", "b"] + + # results = {"TP": {"mean": [], "std": []}, "FN": {"mean": [], "std": []}} + + # for i in range(nb_units): + # unit_id = benchmark.we.unit_ids[i] + # idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] + # wfs = benchmark.we.get_waveforms(unit_id) + # template = benchmark.we.get_template(unit_id) + # a = template.reshape(template.size, 1).T + + # for label in ["TP", "FN"]: + # idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] + # intersection = np.where(np.isin(idx_2, idx_1))[0] + # intersection = np.random.permutation(intersection)[:nb_spikes] + # wfs_sliced = wfs[intersection, :, :] + + # import sklearn + + # all_spikes = len(wfs_sliced) + # if all_spikes > 0: + # b = wfs_sliced.reshape(all_spikes, -1) + # if metric == "cosine": + # distances = sklearn.metrics.pairwise.cosine_similarity(a, b).flatten() + # else: + # distances = sklearn.metrics.pairwise_distances(a, b, metric).flatten() + # results[label]["mean"] += [np.nanmean(distances)] + # results[label]["std"] += [np.nanstd(distances)] + # else: + # results[label]["mean"] += [0] + # results[label]["std"] += [0] + + # fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(15, 5)) + # for count, label in enumerate(["TP", "FN"]): + # ax = axs[count] + # idx = np.argsort(benchmark.metrics.snr) + # means = np.array(results[label]["mean"])[idx] + # stds = np.array(results[label]["std"])[idx] + # ax.errorbar(benchmark.metrics.snr[idx], means, yerr=stds, c=colors[count]) + # ax.set_title(label) + # ax.set_xlabel("snr") + # ax.set_ylabel(metric) + # return fig, axs + + def plot_comparison_matching( + benchmark, + comp_per_method, + performance_names=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), + ): + num_methods = len(benchmark.methods) + fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) + for i, method1 in enumerate(benchmark.methods): + for j, method2 in enumerate(benchmark.methods): + if len(axs.shape) > 1: + ax = axs[i, j] else: - ax.set_yticks([]) - if i == j: - ax.set_xlabel(f"{method2}") + ax = axs[j] + comp1, comp2 = comp_per_method[method1], comp_per_method[method2] + if i <= j: + for performance, color in zip(performance_names, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.plot(perf2, perf1, ".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + if j == i: + ax.set_ylabel(f"{method1}") + else: + ax.set_yticks([]) + if i == j: + ax.set_xlabel(f"{method2}") + else: + ax.set_xticks([]) + if i == num_methods - 1 and j == num_methods - 1: + patches = [] + for color, name in zip(colors, performance_names): + patches.append(mpatches.Patch(color=color, label=name)) + ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) else: + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) - else: - ax.spines["bottom"].set_visible(False) - ax.spines["left"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks([]) - ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) - return fig, axs - - -def compute_rejection_rate(comp, method="by_unit"): - missing_unit_ids = set(comp.unit1_ids) - set(comp.unit2_ids) - performance = comp.get_performance() - rejection_rates = np.zeros(len(missing_unit_ids)) - for i, missing_unit_id in enumerate(missing_unit_ids): - rejection_rates[i] = performance.miss_rate[performance.index == missing_unit_id] - if method == "by_unit": - return rejection_rates - elif method == "pooled_with_average": - return np.mean(rejection_rates) - else: - raise ValueError(f'method must be "by_unit" or "pooled_with_average" but got {method}') - - -def plot_vary_parameter( - matching_df, performance_metric="accuracy", method_colors=None, parameter_transform=lambda x: x -): - parameter_names = matching_df.parameter_name.unique() - methods = matching_df.method.unique() - if method_colors is None: - method_colors = {method: f"C{i}" for i, method in enumerate(methods)} - figs, axs = [], [] - for parameter_name in parameter_names: - df_parameter = matching_df[matching_df.parameter_name == parameter_name] - parameters = df_parameter.parameter_value.unique() - method_means = {method: [] for method in methods} - method_stds = {method: [] for method in methods} - for parameter in parameters: - for method in methods: - method_param_mask = np.logical_and( - df_parameter.method == method, df_parameter.parameter_value == parameter - ) - comps = df_parameter.comparison[method_param_mask] - performance_metrics = [] - for comp in comps: - try: - perf_metric = comp.get_performance(method="pooled_with_average")[performance_metric] - except KeyError: # benchmarking-specific metric - assert performance_metric == "rejection_rate", f"{performance_metric} is not a valid metric" - perf_metric = compute_rejection_rate(comp, method="pooled_with_average") - performance_metrics.append(perf_metric) - # Average / STD over replicates - method_means[method].append(np.mean(performance_metrics)) - method_stds[method].append(np.std(performance_metrics)) - - parameters_transformed = parameter_transform(parameters) - fig, ax = plt.subplots() - for method in methods: - mean, std = method_means[method], method_stds[method] - ax.errorbar( - parameters_transformed, mean, std, color=method_colors[method], marker="o", markersize=5, label=method - ) - if parameter_name == "num_spikes": - xlabel = "Number of Spikes" - elif parameter_name == "fraction_misclassed": - xlabel = "Fraction of Spikes Misclassified" - elif parameter_name == "fraction_missing": - xlabel = "Fraction of Low SNR Units Missing" - ax.set_xticks(parameters_transformed, parameters) - ax.set_xlabel(xlabel) - ax.set_ylabel(f"Average Unit {performance_metric}") - ax.legend() - figs.append(fig) - axs.append(ax) - return figs, axs + ax.set_yticks([]) + plt.tight_layout(h_pad=0, w_pad=0) + return fig, axs \ No newline at end of file From 7e6bc1e7d9510f51967afac076a26d5e6ec54b9f Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 21 Feb 2024 09:29:39 +0100 Subject: [PATCH 112/192] WIP --- .../benchmark/benchmark_matching.py | 187 +++++------------- .../benchmark/benchmark_tools.py | 50 +++-- 2 files changed, 86 insertions(+), 151 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 9e131879c3..962730f4fc 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -3,6 +3,7 @@ from spikeinterface.preprocessing import bandpass_filter, common_reference from spikeinterface.postprocessing import compute_template_similarity from spikeinterface.sortingcomponents.matching import find_spikes_from_templates +from spikeinterface.core.template import Templates from spikeinterface.core import NumpySorting from spikeinterface.qualitymetrics import compute_quality_metrics from spikeinterface import load_extractor @@ -37,7 +38,8 @@ def __init__(self, recording, gt_sorting, params): self.method = params['method'] self.templates = params["method_kwargs"]['templates'] self.method_kwargs = params['method_kwargs'] - + self.result = {} + def run(self, **job_kwargs): spikes = find_spikes_from_templates( self.recording, @@ -51,48 +53,53 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - result = {'sorting' : sorting} + self.result = {'sorting' : sorting} + self.result['templates'] = self.templates - ## Add metrics - + def compute_result(self, **result_params): + sorting = self.result['sorting'] comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) - result['gt_comparison'] = comp - result['templates'] = self.templates - result['gt_collision'] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) - return result + self.result['gt_comparison'] = comp + self.result['gt_collision'] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) + + def save_run(self, folder): + self.result['sorting'].save(folder = folder / "sorting", format="numpy_folder") + self.result['templates'].to_zarr(folder / "templates") - def save_to_folder(self, folder, result): - result['sorting'].save(folder = folder / "sorting", format="numpy_folder") - result['templates'].to_zarr(folder / "templates") + def save_result(self, folder): comparison_file = folder / "gt_comparison.pickle" with open(comparison_file, mode="wb") as f: - pickle.dump(result['gt_comparison'], f) - + pickle.dump(self.result['gt_comparison'], f) + + collision_file = folder / "gt_collision.pickle" + with open(collision_file, mode="wb") as f: + pickle.dump(self.result['gt_collision'], f) + @classmethod def load_folder(cls, folder): result = {} result['sorting'] = load_extractor(folder / "sorting") result['templates'] = Templates.from_zarr(folder / "templates") - with open(folder / "gt_comparison.pickle", "rb") as f: - result['gt_comparison'] = pickle.load(f) + if (folder / "gt_comparison.pickle").exists(): + with open(folder / "gt_comparison.pickle", "rb") as f: + result['gt_comparison'] = pickle.load(f) + if (folder / "gt_collision.pickle").exists(): + with open(folder / "gt_collision.pickle", "rb") as f: + result['gt_collision'] = pickle.load(f) return result + class MatchingStudy(BenchmarkStudy): benchmark_class = MatchingBenchmark - def _run(self, keys, **job_kwargs): - for key in keys: - - dataset_key = self.cases[key]["dataset"] - recording, gt_sorting = self.datasets[dataset_key] - params = self.cases[key]["params"] - benchmark = MatchingBenchmark(recording, gt_sorting, params) - result = benchmark.run() - self.results[key] = result - benchmark.save_to_folder(self.folder / "results" / self.key_to_str(key), result) - - + def create_benchmark(self,key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + benchmark = MatchingBenchmark(recording, gt_sorting, params) + return benchmark + def plot_agreements(self, case_keys=None, figsize=(15,15)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -102,7 +109,7 @@ def plot_agreements(self, case_keys=None, figsize=(15,15)): for count, key in enumerate(case_keys): ax = axs[count] ax.set_title(self.cases[key]['label']) - plot_agreement_matrix(self.results[key]['gt_comparison'], ax=ax) + plot_agreement_matrix(self.get_result(key)['gt_comparison'], ax=ax) def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): if case_keys is None: @@ -119,7 +126,7 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): analyzer = self.get_sorting_analyzer(key) metrics = analyzer.get_extension('quality_metrics').get_data() x = metrics["snr"].values - y = self.results[key]['gt_comparison'].get_performance()[k].values + y = self.get_result(key)['gt_comparison'].get_performance()[k].values ax.scatter(x, y, marker=".", label=label) if count == 2: @@ -132,121 +139,32 @@ def plot_collisions(self, case_keys=None, figsize=(15,15)): fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) for count, key in enumerate(case_keys): - templates_array = self.results[key]['templates'].templates_array + templates_array = self.get_result(key)['templates'].templates_array plot_comparison_collision_by_similarity( - self.results[key]['gt_collision'], templates_array, ax=axs[count], + self.get_result(key)['gt_collision'], templates_array, ax=axs[count], show_legend=True, mode="lines", good_only=False ) - # def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine"): - # fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10)) - - # benchmark.we.sorting.get_unit_spike_train(unit_id) - # template = benchmark.we.get_template(unit_id) - # a = template.reshape(template.size, 1).T - # count = 0 - # colors = ["r", "b"] - # for label in ["TP", "FN"]: - # seg_num = 0 # TODO: make compatible with multiple segments - # idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) - # idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - # intersection = np.where(np.isin(idx_2, idx_1))[0] - # intersection = np.random.permutation(intersection)[:nb_spikes] - # if len(intersection) == 0: - # print(f"No {label}s found for unit {unit_id}") - # continue - # ### Should be able to give a subset of waveforms only... - # ax = axs[count, 0] - # plot_unit_waveforms( - # benchmark.we, - # unit_ids=[unit_id], - # axes=[ax], - # unit_selected_waveforms={unit_id: intersection}, - # unit_colors={unit_id: colors[count]}, - # ) - # ax.set_title(label) - - # wfs = benchmark.we.get_waveforms(unit_id) - # wfs = wfs[intersection, :, :] - - # import sklearn - - # nb_spikes = len(wfs) - # b = wfs.reshape(nb_spikes, -1) - # distances = sklearn.metrics.pairwise_distances(a, b, metric).flatten() - # ax = axs[count, 1] - # ax.set_title(label) - # ax.hist(distances, color=colors[count]) - # ax.set_ylabel("# waveforms") - # ax.set_xlabel(metric) - - # count += 1 - # return fig, axs - - - # def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cosine"): - # templates = benchmark.templates - # nb_units = len(benchmark.we.unit_ids) - # colors = ["r", "b"] - - # results = {"TP": {"mean": [], "std": []}, "FN": {"mean": [], "std": []}} - - # for i in range(nb_units): - # unit_id = benchmark.we.unit_ids[i] - # idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - # wfs = benchmark.we.get_waveforms(unit_id) - # template = benchmark.we.get_template(unit_id) - # a = template.reshape(template.size, 1).T - - # for label in ["TP", "FN"]: - # idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - # intersection = np.where(np.isin(idx_2, idx_1))[0] - # intersection = np.random.permutation(intersection)[:nb_spikes] - # wfs_sliced = wfs[intersection, :, :] - - # import sklearn - - # all_spikes = len(wfs_sliced) - # if all_spikes > 0: - # b = wfs_sliced.reshape(all_spikes, -1) - # if metric == "cosine": - # distances = sklearn.metrics.pairwise.cosine_similarity(a, b).flatten() - # else: - # distances = sklearn.metrics.pairwise_distances(a, b, metric).flatten() - # results[label]["mean"] += [np.nanmean(distances)] - # results[label]["std"] += [np.nanstd(distances)] - # else: - # results[label]["mean"] += [0] - # results[label]["std"] += [0] - - # fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(15, 5)) - # for count, label in enumerate(["TP", "FN"]): - # ax = axs[count] - # idx = np.argsort(benchmark.metrics.snr) - # means = np.array(results[label]["mean"])[idx] - # stds = np.array(results[label]["std"])[idx] - # ax.errorbar(benchmark.metrics.snr[idx], means, yerr=stds, c=colors[count]) - # ax.set_title(label) - # ax.set_xlabel("snr") - # ax.set_ylabel(metric) - # return fig, axs - - def plot_comparison_matching( - benchmark, - comp_per_method, + def plot_comparison_matching(self, case_keys=None, performance_names=["accuracy", "recall", "precision"], colors=["g", "b", "r"], ylim=(-0.1, 1.1), + figsize=(15,15) ): - num_methods = len(benchmark.methods) + + if case_keys is None: + case_keys = list(self.cases.keys()) + + num_methods = len(case_keys) fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) - for i, method1 in enumerate(benchmark.methods): - for j, method2 in enumerate(benchmark.methods): + for i, key1 in enumerate(case_keys): + for j, key2 in enumerate(case_keys): if len(axs.shape) > 1: ax = axs[i, j] else: ax = axs[j] - comp1, comp2 = comp_per_method[method1], comp_per_method[method2] + comp1 = self.get_result(key1)['gt_comparison'] + comp2 = self.get_result(key2)['gt_comparison'] if i <= j: for performance, color in zip(performance_names, colors): perf1 = comp1.get_performance()[performance] @@ -260,11 +178,11 @@ def plot_comparison_matching( ax.set_aspect("equal") if j == i: - ax.set_ylabel(f"{method1}") + ax.set_ylabel(f"{key1}") else: ax.set_yticks([]) if i == j: - ax.set_xlabel(f"{method2}") + ax.set_xlabel(f"{key2}") else: ax.set_xticks([]) if i == num_methods - 1 and j == num_methods - 1: @@ -279,5 +197,4 @@ def plot_comparison_matching( ax.spines["right"].set_visible(False) ax.set_xticks([]) ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) - return fig, axs \ No newline at end of file + plt.tight_layout(h_pad=0, w_pad=0) \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index f47aa1433c..b58b831b17 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -23,7 +23,7 @@ def __init__(self, study_folder): self.folder = Path(study_folder) self.datasets = {} self.cases = {} - self.results = {} + self.benchmarks = {} self.scan_folder() @classmethod @@ -80,6 +80,9 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): return cls(study_folder) + def create_benchmark(self): + raise NotImplementedError + def scan_folder(self): if not (self.folder / "datasets").exists(): raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") @@ -98,19 +101,22 @@ def scan_folder(self): with open(self.folder / "cases.pickle", "rb") as f: self.cases = pickle.load(f) - self.results = {} + self.benchmarks = {} for key in self.cases: result_folder = self.folder / "results" / self.key_to_str(key) if result_folder.exists(): - self.results[key] = self.benchmark_class.load_folder(result_folder) + result = self.benchmark_class.load_folder(result_folder) + benchmark = self.create_benchmark(key) + benchmark.result.update(result) + self.benchmarks[key] = benchmark else: - self.results[key] = None + self.benchmarks[key] = None def __repr__(self): t = f"{self.__class__.__name__} {self.folder.stem} \n" t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" - num_computed = sum([1 for result in self.results.values() if result is not None]) + num_computed = sum([1 for benchmark in self.benchmarks.values() if benchmark is not None]) t += f" computed: {num_computed}\n" return t @@ -122,7 +128,7 @@ def key_to_str(self, key): else: raise ValueError("Keys for cases must str or tuple") - def remove_result(self, key): + def remove_benchmark(self, key): result_folder = self.folder / "results" / self.key_to_str(key) log_file = self.folder / "run_logs" / f"{self.key_to_str(key)}.json" @@ -131,9 +137,9 @@ def remove_result(self, key): for f in (log_file, ): if f.exists(): f.unlink() - self.results[key] = None + self.benchmarks[key] = None - def run(self, case_keys=None, keep=True, verbose=False): + def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): if case_keys is None: case_keys = self.cases.keys() @@ -145,17 +151,25 @@ def run(self, case_keys=None, keep=True, verbose=False): if keep and result_folder.exists(): continue elif not keep and result_folder.exists(): - self.remove_result(key) + self.remove_benchmark(key) job_keys.append(key) - self._run(job_keys) - - # save log - # TODO - - def _run(self, job_keys): - raise NotImplemented + for key in job_keys: + benchmark = self.create_benchmark(key) + benchmark.run() + self.benchmarks[key] = benchmark + benchmark.save_run(self.folder / "results" / self.key_to_str(key)) + + def compute_results(self, case_keys=None, verbose=False, **result_params): + if case_keys is None: + case_keys = self.cases.keys() + job_keys = [] + for key in case_keys: + benchmark = self.benchmarks[key] + assert benchmark is not None + benchmark.compute_result(**result_params) + benchmark.save_result(self.folder / "results" / self.key_to_str(key)) def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): if case_keys is None: @@ -230,6 +244,10 @@ def get_metrics(self, key): def get_units_snr(self, key): """ """ return self.get_metrics(key)["snr"] + + def get_result(self, key): + return self.benchmarks[key].result + class Benchmark: From 35829cb3be7950097febe7cde3fe63dd0ce83e08 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 21 Feb 2024 12:00:30 +0100 Subject: [PATCH 113/192] WIP --- .../benchmark/benchmark_peak_localization.py | 1062 ++++++++++------- 1 file changed, 611 insertions(+), 451 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 832661bd7e..4f0d8fa52e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -14,490 +14,650 @@ plot_unit_waveforms, ) from spikeinterface.postprocessing import compute_spike_locations -from spikeinterface.postprocessing.unit_localization import compute_center_of_mass, compute_monopolar_triangulation +from spikeinterface.postprocessing.unit_localization import compute_center_of_mass, compute_monopolar_triangulation, compute_grid_convolution from spikeinterface.core import get_noise_levels -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import BenchmarkBase, _simpleaxis import time import string, random import pylab as plt +from spikeinterface.core.template import Templates import os import numpy as np +import pickle +from .benchmark_tools import BenchmarkStudy, Benchmark +from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -class BenchmarkPeakLocalization: - def __init__(self, recording, gt_sorting, gt_positions, job_kwargs={}, tmp_folder=None, verbose=True, title=None): - self.verbose = verbose +class PeakLocalizationBenchmark(Benchmark): + + def __init__(self, recording, gt_sorting, gt_positions, params): self.recording = recording self.gt_sorting = gt_sorting - self.job_kwargs = job_kwargs - self.sampling_rate = self.recording.get_sampling_frequency() - self.title = title - self.waveforms = None - - self.tmp_folder = tmp_folder - if self.tmp_folder is None: - self.tmp_folder = os.path.join(".", "".join(random.choices(string.ascii_uppercase + string.digits, k=8))) - self.gt_positions = gt_positions - - def __del__(self): - import shutil - - shutil.rmtree(self.tmp_folder) - - def run(self, method, method_kwargs={}): - if self.waveforms is None: - self.waveforms = extract_waveforms( - self.recording, - self.gt_sorting, - self.tmp_folder, - ms_before=2.5, - ms_after=2.5, - max_spikes_per_unit=500, - return_scaled=False, - **self.job_kwargs, - ) - - t_start = time.time() - if self.title is None: - self.title = method - - unit_params = method_kwargs.copy() - - for key in ["ms_after", "ms_before"]: - if key in unit_params: - unit_params.pop(key) - - if method == "center_of_mass": - self.template_positions = compute_center_of_mass(self.waveforms, **unit_params) - elif method == "monopolar_triangulation": - self.template_positions = compute_monopolar_triangulation(self.waveforms, **unit_params) - - self.spike_positions = compute_spike_locations( - self.waveforms, method=method, method_kwargs=method_kwargs, **self.job_kwargs, outputs="by_unit" - ) - - self.raw_templates_results = {} - - for unit_ind, unit_id in enumerate(self.waveforms.sorting.unit_ids): - data = self.spike_positions[0][unit_id] - self.raw_templates_results[unit_id] = np.sqrt( + self.method = params['method'] + self.method_kwargs = params['method_kwargs'] + self.result = {} + + def run(self, **job_kwargs): + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates') + templates = ext.get_data(outputs='Templates') + ext = sorting_analyzer.compute("spike_locations", method=self.method, **self.method_kwargs) + spikes_locations = ext.get_data(outputs="by_unit") + self.result = {'spikes_locations' : spikes_locations} + self.result['templates'] = templates + + def compute_result(self, **result_params): + errors = {} + + for unit_ind, unit_id in enumerate(self.gt_sorting.unit_ids): + data = self.result['spikes_locations'][0][unit_id] + errors[unit_id] = np.sqrt( (data["x"] - self.gt_positions[unit_ind, 0]) ** 2 + (data["y"] - self.gt_positions[unit_ind, 1]) ** 2 ) - self.medians_over_templates = np.array( - [np.median(self.raw_templates_results[unit_id]) for unit_id in self.waveforms.sorting.unit_ids] + self.result['medians_over_templates'] = np.array( + [np.median(errors[unit_id]) for unit_id in self.gt_sorting.unit_ids] ) - self.mads_over_templates = np.array( + self.result['mads_over_templates'] = np.array( [ - np.median(np.abs(self.raw_templates_results[unit_id] - np.median(self.raw_templates_results[unit_id]))) - for unit_id in self.waveforms.sorting.unit_ids + np.median(np.abs(errors[unit_id] - np.median(errors[unit_id]))) + for unit_id in self.gt_sorting.unit_ids ] ) + self.result['errors'] = errors + + def save_run(self, folder): + self.result['templates'].to_zarr(folder / "templates") + locations_file = folder / "spikes_locations.pickle" + with open(locations_file, mode="wb") as f: + pickle.dump(self.result['spikes_locations'], f) + + def save_result(self, folder): + errors_file = folder / "errors.pickle" + with open(errors_file, mode="wb") as f: + pickle.dump(self.result['errors'], f) + np.save(folder / "medians_over_templates", self.result['medians_over_templates']) + np.save(folder / "mads_over_templates", self.result['mads_over_templates']) + + @classmethod + def load_folder(cls, folder): + result = {} + result['templates'] = Templates.from_zarr(folder / "templates") + if (folder / "errors.pickle").exists(): + with open(folder / "errors.pickle", "rb") as f: + result['errors'] = pickle.load(f) + if (folder / "spikes_locations.pickle").exists(): + with open(folder / "spikes_locations.pickle", "rb") as f: + result['spikes_locations'] = pickle.load(f) + if (folder / "medians_over_templates.npy").exists(): + result["medians_over_templates"] = np.load(folder / "medians_over_templates.npy") + result["mads_over_templates"] = np.load(folder / "mads_over_templates.npy") + return result + + +class PeakLocalizationStudy(BenchmarkStudy): + + benchmark_class = PeakLocalizationBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + gt_positions = self.cases[key]["gt_positions"] + params = self.cases[key]["params"] + benchmark = PeakLocalizationBenchmark(recording, gt_sorting, gt_positions, params) + return benchmark + + def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): + + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) + + + for count, key in enumerate(case_keys): + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension('quality_metrics').get_data() + snrs = metrics["snr"].values + result = self.get_result(key) + norms = np.linalg.norm(result['templates'].templates_array, axis=(1, 2)) + distances_to_center = np.linalg.norm(self.benchmarks[key].gt_positions[:, :2], axis=1) + zdx = np.argsort(distances_to_center) + idx = np.argsort(norms) + + from scipy.signal import savgol_filter + wdx = np.argsort(snrs) + + data = result["medians_over_templates"] + + axs[0].plot( + snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]['label'] + ) + ymin = savgol_filter((data - result["mads_over_templates"])[wdx], smoothing_factor, 3) + ymax = savgol_filter((data + result["mads_over_templates"])[wdx], smoothing_factor, 3) + + axs[0].fill_between(snrs[wdx], ymin, ymax, alpha=0.5) + axs[0].set_xlabel("snr") + axs[0].set_ylabel("error (um)") + # ymin, ymax = ax.get_ylim() + # ax.set_ylim(0, ymax) + # ax.set_yscale('log') + + axs[1].plot( + distances_to_center[zdx], + savgol_filter(data[zdx], smoothing_factor, 3), + lw=2, + label=self.cases[key]['label'], + ) + ymin = savgol_filter((data - result["mads_over_templates"])[zdx], smoothing_factor, 3) + ymax = savgol_filter((data + result["mads_over_templates"])[zdx], smoothing_factor, 3) - def plot_template_errors(self, show_probe=True): - import spikeinterface.full as si - import pylab as plt - - si.plot_probe_map(self.recording) - plt.scatter(self.gt_positions[:, 0], self.gt_positions[:, 1], c=np.arange(len(self.gt_positions)), cmap="jet") - plt.scatter( - self.template_positions[:, 0], - self.template_positions[:, 1], - c=np.arange(len(self.template_positions)), - cmap="jet", - marker="v", - ) - - -def plot_comparison_positions(benchmarks, mode="average"): - norms = np.linalg.norm(benchmarks[0].waveforms.get_all_templates(mode=mode), axis=(1, 2)) - distances_to_center = np.linalg.norm(benchmarks[0].gt_positions[:, :2], axis=1) - zdx = np.argsort(distances_to_center) - idx = np.argsort(norms) - - snrs_tmp = compute_snrs(benchmarks[0].waveforms) - snrs = np.zeros(len(snrs_tmp)) - for k, v in snrs_tmp.items(): - snrs[int(k[1:])] = v - - wdx = np.argsort(snrs) - - plt.rc("font", size=13) - plt.rc("xtick", labelsize=12) - plt.rc("ytick", labelsize=12) - - fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(15, 10)) - ax = axs[0, 0] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - # ax.set_title(title) + axs[1].fill_between(distances_to_center[zdx], ymin, ymax, alpha=0.5) + axs[1].set_xlabel("distance to center (um)") - from scipy.signal import savgol_filter + x_means = [] + x_stds = [] + for count, key in enumerate(case_keys): + result = self.get_result(key)['medians_over_templates'] + x_means += [result.mean()] + x_stds += [result.std()] - smoothing_factor = 31 + y_means = [] + y_stds = [] + for count, key in enumerate(case_keys): + result = self.get_result(key)['mads_over_templates'] + y_means += [result.mean()] + y_stds += [result.std()] - for bench in benchmarks: - errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) - ax.plot(snrs[wdx], savgol_filter(errors[wdx], smoothing_factor, 3), label=bench.title) + colors = [f"C{i}" for i in range(len(x_means))] + axs[2].errorbar(x_means, y_means, xerr=x_stds, yerr=y_stds, fmt=".", c="0.5", alpha=0.5) + axs[2].scatter(x_means, y_means, c=colors, s=200) - ax.legend() - # ax.set_xlabel('norm') - ax.set_ylabel("error (um)") - ymin, ymax = ax.get_ylim() - ax.set_ylim(0, ymax) + axs[2].set_ylabel("error mads (um)") + axs[2].set_xlabel("error medians (um)") + ymin, ymax = axs[2].get_ylim() + axs[2].set_ylim(0, 12) - ax = axs[0, 1] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - for bench in benchmarks: - errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) - ax.plot(distances_to_center[zdx], savgol_filter(errors[zdx], smoothing_factor, 3), label=bench.title) - # ax.set_xlabel('distance to center (um)') - # ax.set_yticks([]) +class UnitLocalizationBenchmark(Benchmark): - ax = axs[0, 2] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - # ax.set_title(title) + def __init__(self, recording, gt_sorting, gt_positions, params): + self.recording = recording + self.gt_sorting = gt_sorting + self.gt_positions = gt_positions + self.method = params['method'] + self.method_kwargs = params['method_kwargs'] + self.result = {} + + def run(self, **job_kwargs): + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + params = {} + for key in ["ms_before", "ms_after"]: + if key in self.method_kwargs: + params[key] = self.method_kwargs[key] + sorting_analyzer.compute('waveforms', **params, **job_kwargs) + ext = sorting_analyzer.compute('templates') + templates = ext.get_data(outputs='Templates') + + if self.method == "center_of_mass": + unit_locations = compute_center_of_mass(sorting_analyzer, **self.method_kwargs) + elif self.method == "monopolar_triangulation": + unit_locations = compute_monopolar_triangulation(sorting_analyzer, **self.method_kwargs) + elif self.method == "grid_convolution": + unit_locations = compute_grid_convolution(sorting_analyzer, **self.method_kwargs) + + if (unit_locations.shape[1] == 2): + unit_locations = np.hstack((unit_locations, np.zeros((len(unit_locations), 1)))) + + self.result = {'unit_locations' : unit_locations} + self.result['templates'] = templates + + def compute_result(self, **result_params): + errors = np.linalg.norm(self.gt_positions - self.result['unit_locations'], axis=1) + self.result['errors'] = errors + + def save_run(self, folder): + self.result['templates'].to_zarr(folder / "templates") + np.save(folder / "unit_locations", self.result['unit_locations']) + + def save_result(self, folder): + np.save(folder / "errors", self.result['errors']) + + @classmethod + def load_folder(cls, folder): + result = {} + result['templates'] = Templates.from_zarr(folder / "templates") + result["unit_locations"] = np.load(folder / "unit_locations.npy") + if (folder / "errors.npy").exists(): + result["errors"] = np.load(folder / "errors.npy") + return result + + +class UnitLocalizationStudy(BenchmarkStudy): + + benchmark_class = UnitLocalizationBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + gt_positions = self.cases[key]["gt_positions"] + params = self.cases[key]["params"] + benchmark = UnitLocalizationBenchmark(recording, gt_sorting, gt_positions, params) + return benchmark + + def plot_template_errors(self, case_keys=None): + + if case_keys is None: + case_keys = list(self.cases.keys()) + fig, axs = plt.subplots(ncols=1, nrows=1, figsize=(15, 5)) + from spikeinterface.widgets import plot_probe_map + #plot_probe_map(self.benchmarks[case_keys[0]].recording, ax=axs) + axs.scatter(self.gt_positions[:, 0], self.gt_positions[:, 1], c=np.arange(len(self.gt_positions)), cmap="jet") + + for count, key in enumerate(case_keys): + result = self.get_result(key) + axs.scatter( + result['unit_locations'][:, 0], + result['unit_locations'][:, 1], + c=f'C{count}', + marker="v", + label=self.cases[key]['label'] + ) + axs.legend() - for count, bench in enumerate(benchmarks): - errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) - ax.bar([count], np.mean(errors), yerr=np.std(errors)) - # ax.set_xlabel('norms') - # ax.set_yticks([]) - # ax.set_ylim(ymin, ymax) +# def plot_comparison_positions(benchmarks, mode="average"): +# norms = np.linalg.norm(benchmarks[0].waveforms.get_all_templates(mode=mode), axis=(1, 2)) +# distances_to_center = np.linalg.norm(benchmarks[0].gt_positions[:, :2], axis=1) +# zdx = np.argsort(distances_to_center) +# idx = np.argsort(norms) + +# snrs_tmp = compute_snrs(benchmarks[0].waveforms) +# snrs = np.zeros(len(snrs_tmp)) +# for k, v in snrs_tmp.items(): +# snrs[int(k[1:])] = v - ax = axs[1, 0] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) +# wdx = np.argsort(snrs) + +# plt.rc("font", size=13) +# plt.rc("xtick", labelsize=12) +# plt.rc("ytick", labelsize=12) + +# fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(15, 10)) +# ax = axs[0, 0] +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) +# # ax.set_title(title) + +# from scipy.signal import savgol_filter + +# smoothing_factor = 31 + +# for bench in benchmarks: +# errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) +# ax.plot(snrs[wdx], savgol_filter(errors[wdx], smoothing_factor, 3), label=bench.title) + +# ax.legend() +# # ax.set_xlabel('norm') +# ax.set_ylabel("error (um)") +# ymin, ymax = ax.get_ylim() +# ax.set_ylim(0, ymax) + +# ax = axs[0, 1] +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) + +# for bench in benchmarks: +# errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) +# ax.plot(distances_to_center[zdx], savgol_filter(errors[zdx], smoothing_factor, 3), label=bench.title) - for bench in benchmarks: - ax.plot( - snrs[wdx], savgol_filter(bench.medians_over_templates[wdx], smoothing_factor, 3), lw=2, label=bench.title - ) - ymin = savgol_filter((bench.medians_over_templates - bench.mads_over_templates)[wdx], smoothing_factor, 3) - ymax = savgol_filter((bench.medians_over_templates + bench.mads_over_templates)[wdx], smoothing_factor, 3) - - ax.fill_between(snrs[wdx], ymin, ymax, alpha=0.5) - - ax.set_xlabel("snr") - ax.set_ylabel("error (um)") - # ymin, ymax = ax.get_ylim() - # ax.set_ylim(0, ymax) - # ax.set_yscale('log') - - ax = axs[1, 1] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - for bench in benchmarks: - ax.plot( - distances_to_center[zdx], - savgol_filter(bench.medians_over_templates[zdx], smoothing_factor, 3), - lw=2, - label=bench.title, - ) - ymin = savgol_filter((bench.medians_over_templates - bench.mads_over_templates)[zdx], smoothing_factor, 3) - ymax = savgol_filter((bench.medians_over_templates + bench.mads_over_templates)[zdx], smoothing_factor, 3) +# # ax.set_xlabel('distance to center (um)') +# # ax.set_yticks([]) - ax.fill_between(distances_to_center[zdx], ymin, ymax, alpha=0.5) +# ax = axs[0, 2] +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) +# # ax.set_title(title) - ax.set_xlabel("distance to center (um)") +# for count, bench in enumerate(benchmarks): +# errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) +# ax.bar([count], np.mean(errors), yerr=np.std(errors)) + +# # ax.set_xlabel('norms') +# # ax.set_yticks([]) +# # ax.set_ylim(ymin, ymax) - x_means = [] - x_stds = [] - for count, bench in enumerate(benchmarks): - x_means += [np.median(bench.medians_over_templates)] - x_stds += [np.std(bench.medians_over_templates)] - - # ax.set_yticks([]) - # ax.set_ylim(ymin, ymax) - - ax = axs[1, 2] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - y_means = [] - y_stds = [] - for count, bench in enumerate(benchmarks): - y_means += [np.median(bench.mads_over_templates)] - y_stds += [np.std(bench.mads_over_templates)] - - colors = [f"C{i}" for i in range(len(x_means))] - ax.errorbar(x_means, y_means, xerr=x_stds, yerr=y_stds, fmt=".", c="0.5", alpha=0.5) - ax.scatter(x_means, y_means, c=colors, s=200) +# ax = axs[1, 0] +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) + +# for bench in benchmarks: +# ax.plot( +# snrs[wdx], savgol_filter(bench.medians_over_templates[wdx], smoothing_factor, 3), lw=2, label=bench.title +# ) +# ymin = savgol_filter((bench.medians_over_templates - bench.mads_over_templates)[wdx], smoothing_factor, 3) +# ymax = savgol_filter((bench.medians_over_templates + bench.mads_over_templates)[wdx], smoothing_factor, 3) - ax.set_ylabel("error mads (um)") - ax.set_xlabel("error medians (um)") - # ax.set_yticks([] - ymin, ymax = ax.get_ylim() - ax.set_ylim(0, 12) - - -def plot_comparison_inferences(benchmarks, bin_size=np.arange(0.1, 20, 1)): - import numpy as np - import sklearn - import scipy.stats - import spikeinterface.full as si - - plt.rc("font", size=11) - plt.rc("xtick", labelsize=12) - plt.rc("ytick", labelsize=12) - - from scipy.signal import savgol_filter - - smoothing_factor = 5 - - fig = plt.figure(figsize=(10, 12)) - gs = fig.add_gridspec(8, 10) - - ax3 = fig.add_subplot(gs[0:2, 6:10]) - ax4 = fig.add_subplot(gs[2:4, 6:10]) - - ax3.spines["top"].set_visible(False) - ax3.spines["right"].set_visible(False) - ax3.set_ylabel("correlation coefficient") - ax3.set_xticks([]) - - ax4.spines["top"].set_visible(False) - ax4.spines["right"].set_visible(False) - ax4.set_ylabel("chi squared") - ax4.set_xlabel("bin size (um)") - - def chiSquared(p, q): - return 0.5 * np.sum((p - q) ** 2 / (p + q + 1e-6)) - - for count, benchmark in enumerate(benchmarks): - spikes = benchmark.spike_positions[0] - units = benchmark.waveforms.sorting.unit_ids - all_x = np.concatenate([spikes[unit_id]["x"] for unit_id in units]) - all_y = np.concatenate([spikes[unit_id]["y"] for unit_id in units]) - - gt_positions = benchmark.gt_positions[:, :2] - real_x = np.concatenate([gt_positions[c, 0] * np.ones(len(spikes[i]["x"])) for c, i in enumerate(units)]) - real_y = np.concatenate([gt_positions[c, 1] * np.ones(len(spikes[i]["x"])) for c, i in enumerate(units)]) - - r_y = np.zeros(len(bin_size)) - c_y = np.zeros(len(bin_size)) - for i, b in enumerate(bin_size): - all_bins = np.arange(all_y.min(), all_y.max(), b) - x1, y2 = np.histogram(all_y, bins=all_bins) - x2, y2 = np.histogram(real_y, bins=all_bins) - - r_y[i] = np.corrcoef(x1, x2)[0, 1] - c_y[i] = chiSquared(x1, x2) - - r_x = np.zeros(len(bin_size)) - c_x = np.zeros(len(bin_size)) - for i, b in enumerate(bin_size): - all_bins = np.arange(all_x.min(), all_x.max(), b) - x1, y2 = np.histogram(all_x, bins=all_bins) - x2, y2 = np.histogram(real_x, bins=all_bins) - - r_x[i] = np.corrcoef(x1, x2)[0, 1] - c_x[i] = chiSquared(x1, x2) - - ax3.plot(bin_size, savgol_filter((r_y + r_x) / 2, smoothing_factor, 3), c=f"C{count}", label=benchmark.title) - ax4.plot(bin_size, savgol_filter((c_y + c_x) / 2, smoothing_factor, 3), c=f"C{count}", label=benchmark.title) - - r_control_y = np.zeros(len(bin_size)) - c_control_y = np.zeros(len(bin_size)) - for i, b in enumerate(bin_size): - all_bins = np.arange(all_y.min(), all_y.max(), b) - random_y = all_y.min() + (all_y.max() - all_y.min()) * np.random.rand(len(all_y)) - x1, y2 = np.histogram(random_y, bins=all_bins) - x2, y2 = np.histogram(real_y, bins=all_bins) - - r_control_y[i] = np.corrcoef(x1, x2)[0, 1] - c_control_y[i] = chiSquared(x1, x2) - - r_control_x = np.zeros(len(bin_size)) - c_control_x = np.zeros(len(bin_size)) - for i, b in enumerate(bin_size): - all_bins = np.arange(all_x.min(), all_x.max(), b) - random_x = all_x.min() + (all_x.max() - all_x.min()) * np.random.rand(len(all_y)) - x1, y2 = np.histogram(random_x, bins=all_bins) - x2, y2 = np.histogram(real_x, bins=all_bins) - - r_control_x[i] = np.corrcoef(x1, x2)[0, 1] - c_control_x[i] = chiSquared(x1, x2) - - ax3.plot(bin_size, savgol_filter((r_control_y + r_control_x) / 2, smoothing_factor, 3), "0.5", label="Control") - ax4.plot(bin_size, savgol_filter((c_control_y + c_control_x) / 2, smoothing_factor, 3), "0.5", label="Control") - - ax4.legend() - - ax0 = fig.add_subplot(gs[0:3, 0:3]) - - si.plot_probe_map(benchmarks[0].recording, ax=ax0) - ax0.scatter(all_x, all_y, alpha=0.5) - ax0.scatter(gt_positions[:, 0], gt_positions[:, 1], c="k") - ax0.set_xticks([]) - ymin, ymax = ax0.get_ylim() - xmin, xmax = ax0.get_xlim() - ax0.spines["top"].set_visible(False) - ax0.spines["right"].set_visible(False) - # ax0.spines['left'].set_visible(False) - ax0.spines["bottom"].set_visible(False) - ax0.set_xlabel("") - - ax1 = fig.add_subplot(gs[0:3, 3]) - ax1.hist(all_y, bins=100, orientation="horizontal", alpha=0.5) - ax1.hist(real_y, bins=100, orientation="horizontal", color="k", alpha=0.5) - ax1.spines["top"].set_visible(False) - ax1.spines["right"].set_visible(False) - ax1.set_yticks([]) - ax1.set_ylim(ymin, ymax) - ax1.set_xlabel("# spikes") - - ax2 = fig.add_subplot(gs[3, 0:3]) - ax2.hist(all_x, bins=100, alpha=0.5) - ax2.hist(real_x, bins=100, color="k", alpha=0.5) - ax2.spines["top"].set_visible(False) - ax2.spines["right"].set_visible(False) - ax2.set_xlim(xmin, xmax) - ax2.set_xlabel(r"x ($\mu$m)") - ax2.set_ylabel("# spikes") - - -def plot_comparison_precision(benchmarks): - import pylab as plt - - fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(15, 10), squeeze=False) - - for bench in benchmarks: - # gt_positions = bench.gt_positions - # template_positions = bench.template_positions - # dx = np.abs(gt_positions[:, 0] - template_positions[:, 0]) - # dy = np.abs(gt_positions[:, 1] - template_positions[:, 1]) - # dz = np.abs(gt_positions[:, 2] - template_positions[:, 2]) - # ax = axes[0, 0] - # ax.errorbar(np.arange(3), [dx.mean(), dy.mean(), dz.mean()], yerr=[dx.std(), dy.std(), dz.std()], label=bench.title) - - spikes = bench.spike_positions[0] - units = bench.waveforms.sorting.unit_ids - all_x = np.concatenate([spikes[unit_id]["x"] for unit_id in units]) - all_y = np.concatenate([spikes[unit_id]["y"] for unit_id in units]) - all_z = np.concatenate([spikes[unit_id]["z"] for unit_id in units]) - - gt_positions = bench.gt_positions - real_x = np.concatenate([gt_positions[c, 0] * np.ones(len(spikes[i]["x"])) for c, i in enumerate(units)]) - real_y = np.concatenate([gt_positions[c, 1] * np.ones(len(spikes[i]["y"])) for c, i in enumerate(units)]) - real_z = np.concatenate([gt_positions[c, 2] * np.ones(len(spikes[i]["z"])) for c, i in enumerate(units)]) - - dx = np.abs(all_x - real_x) - dy = np.abs(all_y - real_y) - dz = np.abs(all_z - real_z) - ax = axes[0, 0] - ax.errorbar( - np.arange(3), [dx.mean(), dy.mean(), dz.mean()], yerr=[dx.std(), dy.std(), dz.std()], label=bench.title - ) - ax.legend() - ax.set_ylabel("error (um)") - ax.set_xticks(np.arange(3), ["x", "y", "z"]) - _simpleaxis(ax) - - x_means = [] - x_stds = [] - for count, bench in enumerate(benchmarks): - x_means += [np.mean(bench.means_over_templates)] - x_stds += [np.std(bench.means_over_templates)] - - # ax.set_yticks([]) - # ax.set_ylim(ymin, ymax) - - ax = axes[0, 1] - _simpleaxis(ax) - - y_means = [] - y_stds = [] - for count, bench in enumerate(benchmarks): - y_means += [np.mean(bench.stds_over_templates)] - y_stds += [np.std(bench.stds_over_templates)] - - colors = [f"C{i}" for i in range(len(x_means))] - ax.errorbar(x_means, y_means, xerr=x_stds, yerr=y_stds, fmt=".", c="0.5", alpha=0.5) - ax.scatter(x_means, y_means, c=colors, s=200) - - ax.set_ylabel("error variances (um)") - ax.set_xlabel("error means (um)") - # ax.set_yticks([] - ymin, ymax = ax.get_ylim() - # ax.set_ylim(0, 25) - ax.legend() - - -def plot_figure_1(benchmark, mode="average", cell_ind="auto"): - if cell_ind == "auto": - norms = np.linalg.norm(benchmark.gt_positions[:, :2], axis=1) - cell_ind = np.argsort(norms)[0] - - import pylab as plt - - fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10)) - plot_probe_map(benchmark.recording, ax=axs[0, 0]) - axs[0, 0].scatter(benchmark.gt_positions[:, 0], benchmark.gt_positions[:, 1], c="k") - axs[0, 0].scatter(benchmark.gt_positions[cell_ind, 0], benchmark.gt_positions[cell_ind, 1], c="r") - plt.rc("font", size=13) - plt.rc("xtick", labelsize=12) - plt.rc("ytick", labelsize=12) - - import spikeinterface.full as si - - sorting = benchmark.waveforms.sorting - unit_id = sorting.unit_ids[cell_ind] - - spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0] - mask = spikes_seg0["unit_index"] == cell_ind - times = spikes_seg0[mask] / sorting.get_sampling_frequency() - - print(benchmark.recording) - # si.plot_traces(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) - # axs[0, 1].set_ylabel('Neurons') - - # si.plot_spikes_on_traces(benchmark.waveforms, unit_ids=[unit_id], time_range=(times[0]-0.01, times[0] + 0.1), unit_colors={unit_id : 'r'}, ax=axs[0, 1], - # channel_ids=benchmark.recording.channel_ids[120:180], ) - - waveforms = extract_waveforms( - benchmark.recording, - benchmark.gt_sorting, - None, - mode="memory", - ms_before=2.5, - ms_after=2.5, - max_spikes_per_unit=100, - return_scaled=False, - **benchmark.job_kwargs, - sparse=True, - method="radius", - radius_um=100, - ) - - unit_id = waveforms.sorting.unit_ids[cell_ind] - - si.plot_unit_templates(waveforms, unit_ids=[unit_id], ax=axs[1, 0], same_axis=True, unit_colors={unit_id: "r"}) - ymin, ymax = axs[1, 0].get_ylim() - xmin, xmax = axs[1, 0].get_xlim() - axs[1, 0].set_title("Averaged template") - si.plot_unit_waveforms(waveforms, unit_ids=[unit_id], ax=axs[1, 1], same_axis=True, unit_colors={unit_id: "r"}) - axs[1, 1].set_xlim(xmin, xmax) - axs[1, 1].set_ylim(ymin, ymax) - axs[1, 1].set_title("Single spikes") - - for i in [0, 1]: - for j in [0, 1]: - axs[i, j].spines["top"].set_visible(False) - axs[i, j].spines["right"].set_visible(False) - - for i in [1]: - for j in [0, 1]: - axs[i, j].spines["left"].set_visible(False) - axs[i, j].spines["bottom"].set_visible(False) - axs[i, j].set_xticks([]) - axs[i, j].set_yticks([]) - axs[i, j].set_title("") +# ax.fill_between(snrs[wdx], ymin, ymax, alpha=0.5) + +# ax.set_xlabel("snr") +# ax.set_ylabel("error (um)") +# # ymin, ymax = ax.get_ylim() +# # ax.set_ylim(0, ymax) +# # ax.set_yscale('log') + +# ax = axs[1, 1] +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) + +# for bench in benchmarks: +# ax.plot( +# distances_to_center[zdx], +# savgol_filter(bench.medians_over_templates[zdx], smoothing_factor, 3), +# lw=2, +# label=bench.title, +# ) +# ymin = savgol_filter((bench.medians_over_templates - bench.mads_over_templates)[zdx], smoothing_factor, 3) +# ymax = savgol_filter((bench.medians_over_templates + bench.mads_over_templates)[zdx], smoothing_factor, 3) + +# ax.fill_between(distances_to_center[zdx], ymin, ymax, alpha=0.5) + +# ax.set_xlabel("distance to center (um)") + +# x_means = [] +# x_stds = [] +# for count, bench in enumerate(benchmarks): +# x_means += [np.median(bench.medians_over_templates)] +# x_stds += [np.std(bench.medians_over_templates)] + +# # ax.set_yticks([]) +# # ax.set_ylim(ymin, ymax) + +# ax = axs[1, 2] +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) + +# y_means = [] +# y_stds = [] +# for count, bench in enumerate(benchmarks): +# y_means += [np.median(bench.mads_over_templates)] +# y_stds += [np.std(bench.mads_over_templates)] + +# colors = [f"C{i}" for i in range(len(x_means))] +# ax.errorbar(x_means, y_means, xerr=x_stds, yerr=y_stds, fmt=".", c="0.5", alpha=0.5) +# ax.scatter(x_means, y_means, c=colors, s=200) + +# ax.set_ylabel("error mads (um)") +# ax.set_xlabel("error medians (um)") +# # ax.set_yticks([] +# ymin, ymax = ax.get_ylim() +# ax.set_ylim(0, 12) + + +# def plot_comparison_inferences(benchmarks, bin_size=np.arange(0.1, 20, 1)): +# import numpy as np +# import sklearn +# import scipy.stats +# import spikeinterface.full as si + +# plt.rc("font", size=11) +# plt.rc("xtick", labelsize=12) +# plt.rc("ytick", labelsize=12) + +# from scipy.signal import savgol_filter + +# smoothing_factor = 5 + +# fig = plt.figure(figsize=(10, 12)) +# gs = fig.add_gridspec(8, 10) + +# ax3 = fig.add_subplot(gs[0:2, 6:10]) +# ax4 = fig.add_subplot(gs[2:4, 6:10]) + +# ax3.spines["top"].set_visible(False) +# ax3.spines["right"].set_visible(False) +# ax3.set_ylabel("correlation coefficient") +# ax3.set_xticks([]) + +# ax4.spines["top"].set_visible(False) +# ax4.spines["right"].set_visible(False) +# ax4.set_ylabel("chi squared") +# ax4.set_xlabel("bin size (um)") + +# def chiSquared(p, q): +# return 0.5 * np.sum((p - q) ** 2 / (p + q + 1e-6)) + +# for count, benchmark in enumerate(benchmarks): +# spikes = benchmark.spike_positions[0] +# units = benchmark.waveforms.sorting.unit_ids +# all_x = np.concatenate([spikes[unit_id]["x"] for unit_id in units]) +# all_y = np.concatenate([spikes[unit_id]["y"] for unit_id in units]) + +# gt_positions = benchmark.gt_positions[:, :2] +# real_x = np.concatenate([gt_positions[c, 0] * np.ones(len(spikes[i]["x"])) for c, i in enumerate(units)]) +# real_y = np.concatenate([gt_positions[c, 1] * np.ones(len(spikes[i]["x"])) for c, i in enumerate(units)]) + +# r_y = np.zeros(len(bin_size)) +# c_y = np.zeros(len(bin_size)) +# for i, b in enumerate(bin_size): +# all_bins = np.arange(all_y.min(), all_y.max(), b) +# x1, y2 = np.histogram(all_y, bins=all_bins) +# x2, y2 = np.histogram(real_y, bins=all_bins) + +# r_y[i] = np.corrcoef(x1, x2)[0, 1] +# c_y[i] = chiSquared(x1, x2) + +# r_x = np.zeros(len(bin_size)) +# c_x = np.zeros(len(bin_size)) +# for i, b in enumerate(bin_size): +# all_bins = np.arange(all_x.min(), all_x.max(), b) +# x1, y2 = np.histogram(all_x, bins=all_bins) +# x2, y2 = np.histogram(real_x, bins=all_bins) + +# r_x[i] = np.corrcoef(x1, x2)[0, 1] +# c_x[i] = chiSquared(x1, x2) + +# ax3.plot(bin_size, savgol_filter((r_y + r_x) / 2, smoothing_factor, 3), c=f"C{count}", label=benchmark.title) +# ax4.plot(bin_size, savgol_filter((c_y + c_x) / 2, smoothing_factor, 3), c=f"C{count}", label=benchmark.title) + +# r_control_y = np.zeros(len(bin_size)) +# c_control_y = np.zeros(len(bin_size)) +# for i, b in enumerate(bin_size): +# all_bins = np.arange(all_y.min(), all_y.max(), b) +# random_y = all_y.min() + (all_y.max() - all_y.min()) * np.random.rand(len(all_y)) +# x1, y2 = np.histogram(random_y, bins=all_bins) +# x2, y2 = np.histogram(real_y, bins=all_bins) + +# r_control_y[i] = np.corrcoef(x1, x2)[0, 1] +# c_control_y[i] = chiSquared(x1, x2) + +# r_control_x = np.zeros(len(bin_size)) +# c_control_x = np.zeros(len(bin_size)) +# for i, b in enumerate(bin_size): +# all_bins = np.arange(all_x.min(), all_x.max(), b) +# random_x = all_x.min() + (all_x.max() - all_x.min()) * np.random.rand(len(all_y)) +# x1, y2 = np.histogram(random_x, bins=all_bins) +# x2, y2 = np.histogram(real_x, bins=all_bins) + +# r_control_x[i] = np.corrcoef(x1, x2)[0, 1] +# c_control_x[i] = chiSquared(x1, x2) + +# ax3.plot(bin_size, savgol_filter((r_control_y + r_control_x) / 2, smoothing_factor, 3), "0.5", label="Control") +# ax4.plot(bin_size, savgol_filter((c_control_y + c_control_x) / 2, smoothing_factor, 3), "0.5", label="Control") + +# ax4.legend() + +# ax0 = fig.add_subplot(gs[0:3, 0:3]) + +# si.plot_probe_map(benchmarks[0].recording, ax=ax0) +# ax0.scatter(all_x, all_y, alpha=0.5) +# ax0.scatter(gt_positions[:, 0], gt_positions[:, 1], c="k") +# ax0.set_xticks([]) +# ymin, ymax = ax0.get_ylim() +# xmin, xmax = ax0.get_xlim() +# ax0.spines["top"].set_visible(False) +# ax0.spines["right"].set_visible(False) +# # ax0.spines['left'].set_visible(False) +# ax0.spines["bottom"].set_visible(False) +# ax0.set_xlabel("") + +# ax1 = fig.add_subplot(gs[0:3, 3]) +# ax1.hist(all_y, bins=100, orientation="horizontal", alpha=0.5) +# ax1.hist(real_y, bins=100, orientation="horizontal", color="k", alpha=0.5) +# ax1.spines["top"].set_visible(False) +# ax1.spines["right"].set_visible(False) +# ax1.set_yticks([]) +# ax1.set_ylim(ymin, ymax) +# ax1.set_xlabel("# spikes") + +# ax2 = fig.add_subplot(gs[3, 0:3]) +# ax2.hist(all_x, bins=100, alpha=0.5) +# ax2.hist(real_x, bins=100, color="k", alpha=0.5) +# ax2.spines["top"].set_visible(False) +# ax2.spines["right"].set_visible(False) +# ax2.set_xlim(xmin, xmax) +# ax2.set_xlabel(r"x ($\mu$m)") +# ax2.set_ylabel("# spikes") + + +# def plot_comparison_precision(benchmarks): +# import pylab as plt + +# fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(15, 10), squeeze=False) + +# for bench in benchmarks: +# # gt_positions = bench.gt_positions +# # template_positions = bench.template_positions +# # dx = np.abs(gt_positions[:, 0] - template_positions[:, 0]) +# # dy = np.abs(gt_positions[:, 1] - template_positions[:, 1]) +# # dz = np.abs(gt_positions[:, 2] - template_positions[:, 2]) +# # ax = axes[0, 0] +# # ax.errorbar(np.arange(3), [dx.mean(), dy.mean(), dz.mean()], yerr=[dx.std(), dy.std(), dz.std()], label=bench.title) + +# spikes = bench.spike_positions[0] +# units = bench.waveforms.sorting.unit_ids +# all_x = np.concatenate([spikes[unit_id]["x"] for unit_id in units]) +# all_y = np.concatenate([spikes[unit_id]["y"] for unit_id in units]) +# all_z = np.concatenate([spikes[unit_id]["z"] for unit_id in units]) + +# gt_positions = bench.gt_positions +# real_x = np.concatenate([gt_positions[c, 0] * np.ones(len(spikes[i]["x"])) for c, i in enumerate(units)]) +# real_y = np.concatenate([gt_positions[c, 1] * np.ones(len(spikes[i]["y"])) for c, i in enumerate(units)]) +# real_z = np.concatenate([gt_positions[c, 2] * np.ones(len(spikes[i]["z"])) for c, i in enumerate(units)]) + +# dx = np.abs(all_x - real_x) +# dy = np.abs(all_y - real_y) +# dz = np.abs(all_z - real_z) +# ax = axes[0, 0] +# ax.errorbar( +# np.arange(3), [dx.mean(), dy.mean(), dz.mean()], yerr=[dx.std(), dy.std(), dz.std()], label=bench.title +# ) +# ax.legend() +# ax.set_ylabel("error (um)") +# ax.set_xticks(np.arange(3), ["x", "y", "z"]) +# _simpleaxis(ax) + +# x_means = [] +# x_stds = [] +# for count, bench in enumerate(benchmarks): +# x_means += [np.mean(bench.means_over_templates)] +# x_stds += [np.std(bench.means_over_templates)] + +# # ax.set_yticks([]) +# # ax.set_ylim(ymin, ymax) + +# ax = axes[0, 1] +# _simpleaxis(ax) + +# y_means = [] +# y_stds = [] +# for count, bench in enumerate(benchmarks): +# y_means += [np.mean(bench.stds_over_templates)] +# y_stds += [np.std(bench.stds_over_templates)] + +# colors = [f"C{i}" for i in range(len(x_means))] +# ax.errorbar(x_means, y_means, xerr=x_stds, yerr=y_stds, fmt=".", c="0.5", alpha=0.5) +# ax.scatter(x_means, y_means, c=colors, s=200) + +# ax.set_ylabel("error variances (um)") +# ax.set_xlabel("error means (um)") +# # ax.set_yticks([] +# ymin, ymax = ax.get_ylim() +# # ax.set_ylim(0, 25) +# ax.legend() + + +# def plot_figure_1(benchmark, mode="average", cell_ind="auto"): +# if cell_ind == "auto": +# norms = np.linalg.norm(benchmark.gt_positions[:, :2], axis=1) +# cell_ind = np.argsort(norms)[0] + +# import pylab as plt + +# fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10)) +# plot_probe_map(benchmark.recording, ax=axs[0, 0]) +# axs[0, 0].scatter(benchmark.gt_positions[:, 0], benchmark.gt_positions[:, 1], c="k") +# axs[0, 0].scatter(benchmark.gt_positions[cell_ind, 0], benchmark.gt_positions[cell_ind, 1], c="r") +# plt.rc("font", size=13) +# plt.rc("xtick", labelsize=12) +# plt.rc("ytick", labelsize=12) + +# import spikeinterface.full as si + +# sorting = benchmark.waveforms.sorting +# unit_id = sorting.unit_ids[cell_ind] + +# spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0] +# mask = spikes_seg0["unit_index"] == cell_ind +# times = spikes_seg0[mask] / sorting.get_sampling_frequency() + +# print(benchmark.recording) +# # si.plot_traces(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) +# # axs[0, 1].set_ylabel('Neurons') + +# # si.plot_spikes_on_traces(benchmark.waveforms, unit_ids=[unit_id], time_range=(times[0]-0.01, times[0] + 0.1), unit_colors={unit_id : 'r'}, ax=axs[0, 1], +# # channel_ids=benchmark.recording.channel_ids[120:180], ) + +# waveforms = extract_waveforms( +# benchmark.recording, +# benchmark.gt_sorting, +# None, +# mode="memory", +# ms_before=2.5, +# ms_after=2.5, +# max_spikes_per_unit=100, +# return_scaled=False, +# **benchmark.job_kwargs, +# sparse=True, +# method="radius", +# radius_um=100, +# ) + +# unit_id = waveforms.sorting.unit_ids[cell_ind] + +# si.plot_unit_templates(waveforms, unit_ids=[unit_id], ax=axs[1, 0], same_axis=True, unit_colors={unit_id: "r"}) +# ymin, ymax = axs[1, 0].get_ylim() +# xmin, xmax = axs[1, 0].get_xlim() +# axs[1, 0].set_title("Averaged template") +# si.plot_unit_waveforms(waveforms, unit_ids=[unit_id], ax=axs[1, 1], same_axis=True, unit_colors={unit_id: "r"}) +# axs[1, 1].set_xlim(xmin, xmax) +# axs[1, 1].set_ylim(ymin, ymax) +# axs[1, 1].set_title("Single spikes") + +# for i in [0, 1]: +# for j in [0, 1]: +# axs[i, j].spines["top"].set_visible(False) +# axs[i, j].spines["right"].set_visible(False) + +# for i in [1]: +# for j in [0, 1]: +# axs[i, j].spines["left"].set_visible(False) +# axs[i, j].spines["bottom"].set_visible(False) +# axs[i, j].set_xticks([]) +# axs[i, j].set_yticks([]) +# axs[i, j].set_title("") From 874e899ef8c5c8491ea04ae0fe2ff733e4ab09e9 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 21 Feb 2024 12:51:44 +0100 Subject: [PATCH 114/192] WIP --- .../benchmark/benchmark_peak_localization.py | 194 ++++++------------ 1 file changed, 61 insertions(+), 133 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 4f0d8fa52e..cb2af1b369 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -35,16 +35,22 @@ def __init__(self, recording, gt_sorting, gt_positions, params): self.recording = recording self.gt_sorting = gt_sorting self.gt_positions = gt_positions - self.method = params['method'] - self.method_kwargs = params['method_kwargs'] + self.params = params self.result = {} + self.templates_params = {} + for key in ["ms_before", "ms_after"]: + if key in self.params: + self.templates_params[key] = self.params[key] + else: + self.templates_params[key] = 2 + self.params[key] = 2 def run(self, **job_kwargs): sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) sorting_analyzer.select_random_spikes() - ext = sorting_analyzer.compute('fast_templates') + ext = sorting_analyzer.compute('fast_templates', **self.templates_params) templates = ext.get_data(outputs='Templates') - ext = sorting_analyzer.compute("spike_locations", method=self.method, **self.method_kwargs) + ext = sorting_analyzer.compute("spike_locations", **self.params) spikes_locations = ext.get_data(outputs="by_unit") self.result = {'spikes_locations' : spikes_locations} self.result['templates'] = templates @@ -124,7 +130,11 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): snrs = metrics["snr"].values result = self.get_result(key) norms = np.linalg.norm(result['templates'].templates_array, axis=(1, 2)) - distances_to_center = np.linalg.norm(self.benchmarks[key].gt_positions[:, :2], axis=1) + + coordinates = self.benchmarks[key].gt_positions[:, :2].copy() + coordinates[:, 0] -= coordinates[:, 0].mean() + coordinates[:, 1] -= coordinates[:, 1].mean() + distances_to_center = np.linalg.norm(coordinates, axis=1) zdx = np.argsort(distances_to_center) idx = np.argsort(norms) @@ -142,9 +152,6 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): axs[0].fill_between(snrs[wdx], ymin, ymax, alpha=0.5) axs[0].set_xlabel("snr") axs[0].set_ylabel("error (um)") - # ymin, ymax = ax.get_ylim() - # ax.set_ylim(0, ymax) - # ax.set_yscale('log') axs[1].plot( distances_to_center[zdx], @@ -180,6 +187,7 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): axs[2].set_xlabel("error medians (um)") ymin, ymax = axs[2].get_ylim() axs[2].set_ylim(0, 12) + axs[1].legend() @@ -192,15 +200,18 @@ def __init__(self, recording, gt_sorting, gt_positions, params): self.method = params['method'] self.method_kwargs = params['method_kwargs'] self.result = {} + self.waveforms_params = {} + for key in ["ms_before", "ms_after"]: + if key in self.method_kwargs: + self.waveforms_params[key] = self.method_kwargs.pop(key) + else: + self.waveforms_params[key] = 2 def run(self, **job_kwargs): - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory') sorting_analyzer.select_random_spikes() - params = {} - for key in ["ms_before", "ms_after"]: - if key in self.method_kwargs: - params[key] = self.method_kwargs[key] - sorting_analyzer.compute('waveforms', **params, **job_kwargs) + + sorting_analyzer.compute('waveforms', **self.waveforms_params, **job_kwargs) ext = sorting_analyzer.compute('templates') templates = ext.get_data(outputs='Templates') @@ -218,7 +229,7 @@ def run(self, **job_kwargs): self.result['templates'] = templates def compute_result(self, **result_params): - errors = np.linalg.norm(self.gt_positions - self.result['unit_locations'], axis=1) + errors = np.linalg.norm(self.gt_positions[:, :2] - self.result['unit_locations'][:, :2], axis=1) self.result['errors'] = errors def save_run(self, folder): @@ -251,7 +262,7 @@ def create_benchmark(self, key): return benchmark def plot_template_errors(self, case_keys=None): - + if case_keys is None: case_keys = list(self.cases.keys()) fig, axs = plt.subplots(ncols=1, nrows=1, figsize=(15, 5)) @@ -271,132 +282,49 @@ def plot_template_errors(self, case_keys=None): axs.legend() -# def plot_comparison_positions(benchmarks, mode="average"): -# norms = np.linalg.norm(benchmarks[0].waveforms.get_all_templates(mode=mode), axis=(1, 2)) -# distances_to_center = np.linalg.norm(benchmarks[0].gt_positions[:, :2], axis=1) -# zdx = np.argsort(distances_to_center) -# idx = np.argsort(norms) - -# snrs_tmp = compute_snrs(benchmarks[0].waveforms) -# snrs = np.zeros(len(snrs_tmp)) -# for k, v in snrs_tmp.items(): -# snrs[int(k[1:])] = v - -# wdx = np.argsort(snrs) - -# plt.rc("font", size=13) -# plt.rc("xtick", labelsize=12) -# plt.rc("ytick", labelsize=12) - -# fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(15, 10)) -# ax = axs[0, 0] -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) -# # ax.set_title(title) - -# from scipy.signal import savgol_filter - -# smoothing_factor = 31 - -# for bench in benchmarks: -# errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) -# ax.plot(snrs[wdx], savgol_filter(errors[wdx], smoothing_factor, 3), label=bench.title) - -# ax.legend() -# # ax.set_xlabel('norm') -# ax.set_ylabel("error (um)") -# ymin, ymax = ax.get_ylim() -# ax.set_ylim(0, ymax) - -# ax = axs[0, 1] -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) - -# for bench in benchmarks: -# errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) -# ax.plot(distances_to_center[zdx], savgol_filter(errors[zdx], smoothing_factor, 3), label=bench.title) - -# # ax.set_xlabel('distance to center (um)') -# # ax.set_yticks([]) - -# ax = axs[0, 2] -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) -# # ax.set_title(title) - -# for count, bench in enumerate(benchmarks): -# errors = np.linalg.norm(bench.template_positions[:, :2] - bench.gt_positions[:, :2], axis=1) -# ax.bar([count], np.mean(errors), yerr=np.std(errors)) - -# # ax.set_xlabel('norms') -# # ax.set_yticks([]) -# # ax.set_ylim(ymin, ymax) - -# ax = axs[1, 0] -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) - -# for bench in benchmarks: -# ax.plot( -# snrs[wdx], savgol_filter(bench.medians_over_templates[wdx], smoothing_factor, 3), lw=2, label=bench.title -# ) -# ymin = savgol_filter((bench.medians_over_templates - bench.mads_over_templates)[wdx], smoothing_factor, 3) -# ymax = savgol_filter((bench.medians_over_templates + bench.mads_over_templates)[wdx], smoothing_factor, 3) - -# ax.fill_between(snrs[wdx], ymin, ymax, alpha=0.5) - -# ax.set_xlabel("snr") -# ax.set_ylabel("error (um)") -# # ymin, ymax = ax.get_ylim() -# # ax.set_ylim(0, ymax) -# # ax.set_yscale('log') - -# ax = axs[1, 1] -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) + def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): -# for bench in benchmarks: -# ax.plot( -# distances_to_center[zdx], -# savgol_filter(bench.medians_over_templates[zdx], smoothing_factor, 3), -# lw=2, -# label=bench.title, -# ) -# ymin = savgol_filter((bench.medians_over_templates - bench.mads_over_templates)[zdx], smoothing_factor, 3) -# ymax = savgol_filter((bench.medians_over_templates + bench.mads_over_templates)[zdx], smoothing_factor, 3) + if case_keys is None: + case_keys = list(self.cases.keys()) -# ax.fill_between(distances_to_center[zdx], ymin, ymax, alpha=0.5) + fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) -# ax.set_xlabel("distance to center (um)") + for count, key in enumerate(case_keys): + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension('quality_metrics').get_data() + snrs = metrics["snr"].values + result = self.get_result(key) + norms = np.linalg.norm(result['templates'].templates_array, axis=(1, 2)) -# x_means = [] -# x_stds = [] -# for count, bench in enumerate(benchmarks): -# x_means += [np.median(bench.medians_over_templates)] -# x_stds += [np.std(bench.medians_over_templates)] + coordinates = self.benchmarks[key].gt_positions[:, :2].copy() + coordinates[:, 0] -= coordinates[:, 0].mean() + coordinates[:, 1] -= coordinates[:, 1].mean() + distances_to_center = np.linalg.norm(coordinates, axis=1) + zdx = np.argsort(distances_to_center) + idx = np.argsort(norms) -# # ax.set_yticks([]) -# # ax.set_ylim(ymin, ymax) + from scipy.signal import savgol_filter + wdx = np.argsort(snrs) -# ax = axs[1, 2] -# ax.spines["top"].set_visible(False) -# ax.spines["right"].set_visible(False) + data = result["errors"] -# y_means = [] -# y_stds = [] -# for count, bench in enumerate(benchmarks): -# y_means += [np.median(bench.mads_over_templates)] -# y_stds += [np.std(bench.mads_over_templates)] + axs[0].plot( + snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]['label'] + ) + + axs[0].set_xlabel("snr") + axs[0].set_ylabel("error (um)") -# colors = [f"C{i}" for i in range(len(x_means))] -# ax.errorbar(x_means, y_means, xerr=x_stds, yerr=y_stds, fmt=".", c="0.5", alpha=0.5) -# ax.scatter(x_means, y_means, c=colors, s=200) + axs[1].plot( + distances_to_center[zdx], + savgol_filter(data[zdx], smoothing_factor, 3), + lw=2, + label=self.cases[key]['label'], + ) -# ax.set_ylabel("error mads (um)") -# ax.set_xlabel("error medians (um)") -# # ax.set_yticks([] -# ymin, ymax = ax.get_ylim() -# ax.set_ylim(0, 12) + axs[1].legend() + axs[1].set_xlabel("distance to center (um)") + axs[2].bar([count], np.mean(data), yerr=np.std(data)) # def plot_comparison_inferences(benchmarks, bin_size=np.arange(0.1, 20, 1)): From c70f76fa354292c7aa69e3e95c32ae9812bba609 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Feb 2024 13:48:33 +0100 Subject: [PATCH 115/192] refactor benchmark estimation --- .../benchmark/benchmark_motion_estimation.py | 1373 +++++++++++------ .../benchmark/benchmark_tools.py | 293 ++-- 2 files changed, 1008 insertions(+), 658 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index b2cf95881f..6c715c19d9 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -5,6 +5,7 @@ import time from pathlib import Path +import pickle from spikeinterface.core import get_noise_levels from spikeinterface.extractors import read_mearec @@ -15,7 +16,7 @@ from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import BenchmarkBase, _simpleaxis +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis from spikeinterface.widgets import plot_probe_map @@ -27,590 +28,934 @@ import MEArec as mr -class BenchmarkMotionEstimationMearec(BenchmarkBase): - _array_names = ( - "noise_levels", - "gt_unit_positions", - "peaks", - "selected_peaks", - "motion", - "temporal_bins", - "spatial_bins", - "peak_locations", - "gt_motion", - ) - - def __init__( - self, - mearec_filename, - title="", - detect_kwargs={}, - select_kwargs=None, - localize_kwargs={}, - estimate_motion_kwargs={}, - folder=None, - do_preprocessing=True, - job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, - overwrite=False, - parent_benchmark=None, - ): - BenchmarkBase.__init__( - self, folder=folder, title=title, overwrite=overwrite, job_kwargs=job_kwargs, parent_benchmark=None - ) - - self._args.extend([str(mearec_filename)]) - - self.mearec_filename = mearec_filename - self.raw_recording, self.gt_sorting = read_mearec(self.mearec_filename) - self.do_preprocessing = do_preprocessing - - self._recording = None - self.detect_kwargs = detect_kwargs.copy() - self.select_kwargs = select_kwargs.copy() if select_kwargs is not None else None - self.localize_kwargs = localize_kwargs.copy() - self.estimate_motion_kwargs = estimate_motion_kwargs.copy() - - self._kwargs.update( - dict( - detect_kwargs=self.detect_kwargs, - select_kwargs=self.select_kwargs, - localize_kwargs=self.localize_kwargs, - estimate_motion_kwargs=self.estimate_motion_kwargs, - ) - ) - - @property - def recording(self): - if self._recording is None: - if self.do_preprocessing: - self._recording = bandpass_filter(self.raw_recording) - self._recording = common_reference(self._recording) - self._recording = zscore(self._recording) - else: - self._recording = self.raw_recording - return self._recording +class MotionEstimationBenchmark(Benchmark): + def __init__(self, recording, gt_sorting, params, + unit_locations, unit_displacements, displacement_sampling_frequency, + direction="y"): + Benchmark.__init__(self) + self.recording = recording + self.gt_sorting = gt_sorting + self.params = params + self.unit_locations = unit_locations + self.unit_displacements = unit_displacements + self.displacement_sampling_frequency = displacement_sampling_frequency + self.direction = direction + self.direction_dim = ["x", "y"].index(direction) - def run(self): - if self.folder is not None: - if self.folder.exists() and not self.overwrite: - raise ValueError(f"The folder {self.folder} is not empty") + def run(self, **job_kwargs): + p = self.params - self.noise_levels = get_noise_levels(self.recording, return_scaled=False) + noise_levels = get_noise_levels(self.recording, return_scaled=False) t0 = time.perf_counter() - self.peaks = detect_peaks( - self.recording, noise_levels=self.noise_levels, **self.detect_kwargs, **self.job_kwargs + peaks = detect_peaks( + self.recording, noise_levels=noise_levels, **p["detect_kwargs"], **job_kwargs ) t1 = time.perf_counter() - if self.select_kwargs is not None: - self.selected_peaks = select_peaks(self.peaks, **self.select_kwargs, **self.job_kwargs) + if p["select_kwargs"] is not None: + selected_peaks = select_peaks(self.peaks, **p["select_kwargs"], **job_kwargs) else: - self.selected_peaks = self.peaks + selected_peaks = peaks + t2 = time.perf_counter() - self.peak_locations = localize_peaks( - self.recording, self.selected_peaks, **self.localize_kwargs, **self.job_kwargs + peak_locations = localize_peaks( + self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs ) t3 = time.perf_counter() - self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( - self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs + motion, temporal_bins, spatial_bins = estimate_motion( + self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] ) - t4 = time.perf_counter() - self.run_times = dict( + step_run_times = dict( detect_peaks=t1 - t0, select_peaks=t2 - t1, localize_peaks=t3 - t2, estimate_motion=t4 - t3, ) - self.compute_gt_motion() + self.result["step_run_times"] = step_run_times + self.result["raw_motion"] = motion + self.result["temporal_bins"] = temporal_bins + self.result["spatial_bins"] = spatial_bins - # align globally gt_motion and motion to avoid offsets - self.motion += np.median(self.gt_motion - self.motion) - - ## save folder - if self.folder is not None: - self.save_to_folder() - - def run_estimate_motion(self): - # usefull to re run only the motion estimate with peak localization - t3 = time.perf_counter() - self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( - self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs - ) - t4 = time.perf_counter() - self.compute_gt_motion() + # self.compute_gt_motion() # align globally gt_motion and motion to avoid offsets - self.motion += np.median(self.gt_motion - self.motion) - self.run_times["estimate_motion"] = t4 - t3 + # self.motion += np.median(self.gt_motion - self.motion) - ## save folder - if self.folder is not None: - self.save_to_folder() - def compute_gt_motion(self): - self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins) - template_locations = np.array(mr.load_recordings(self.mearec_filename).template_locations) - assert len(template_locations.shape) == 3 - mid = template_locations.shape[1] // 2 - unit_mid_positions = template_locations[:, mid, 2] - unit_motions = self.gt_unit_positions - unit_mid_positions - # unit_positions = np.mean(self.gt_unit_positions, axis=0) + # self.result = {'sorting' : sorting} + # self.result['templates'] = self.templates - if self.spatial_bins is None: - self.gt_motion = np.mean(unit_motions, axis=1)[:, None] - channel_positions = self.recording.get_channel_locations() - probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() - center = (probe_y_min + probe_y_max) // 2 - self.spatial_bins = np.array([center]) - else: - # time, units - self.gt_motion = np.zeros_like(self.motion) - for t in range(self.gt_unit_positions.shape[0]): - f = scipy.interpolate.interp1d(unit_mid_positions, unit_motions[t, :], fill_value="extrapolate") - self.gt_motion[t, :] = f(self.spatial_bins) - - def plot_true_drift(self, scaling_probe=1.5, figsize=(15, 10), axes=None): - if axes is None: - fig = plt.figure(figsize=figsize) - gs = fig.add_gridspec(1, 8, wspace=0) + def compute_result(self, **result_params): + raw_motion = self.result["raw_motion"] + temporal_bins = self.result["temporal_bins"] + spatial_bins = self.result["spatial_bins"] - if axes is None: - ax = fig.add_subplot(gs[:2]) - else: - ax = axes[0] - plot_probe_map(self.recording, ax=ax) - _simpleaxis(ax) - - mr_recording = mr.load_recordings(self.mearec_filename) - - for loc in mr_recording.template_locations[::2]: - if len(mr_recording.template_locations.shape) == 3: - ax.plot([loc[0, 1], loc[-1, 1]], [loc[0, 2], loc[-1, 2]], alpha=0.7, lw=2) - else: - ax.scatter([loc[1]], [loc[2]], alpha=0.7, s=100) - - # ymin, ymax = ax.get_ylim() - ax.set_ylabel("depth (um)") - ax.set_xlabel(None) - # ax.set_yticks(np.arange(-600,600,100), np.arange(-600,600,100)) + # interpolation units to gt_motion + num_units = self.unit_locations.shape[0] - # ax.set_ylim(scaling_probe*probe_y_min, scaling_probe*probe_y_max) - if axes is None: - ax = fig.add_subplot(gs[2:7]) - else: - ax = axes[1] + print(self.unit_locations.shape) + # self.unit_displacements = unit_displacements + # self.displacement_sampling_frequency = displacement_sampling_frequency - for i in range(self.gt_unit_positions.shape[1]): - ax.plot(self.temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") - for i in range(self.gt_motion.shape[1]): - depth = self.spatial_bins[i] - ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=4) + # self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins) - # ax.set_ylim(ymin, ymax) - ax.set_xlabel("time (s)") - _simpleaxis(ax) - ax.set_yticks([]) - ax.spines["left"].set_visible(False) + # template_locations = np.array(mr.load_recordings(self.mearec_filename).template_locations) + # assert len(template_locations.shape) == 3 + # mid = template_locations.shape[1] // 2 + # unit_mid_positions = template_locations[:, mid, 2] - channel_positions = self.recording.get_channel_locations() - probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() - ax.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) + # unit_motions = self.gt_unit_positions - unit_mid_positions + # # unit_positions = np.mean(self.gt_unit_positions, axis=0) - ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) - ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) + # time interpolatation of unit displacements + times = np.arange(self.unit_displacements.shape[0]) / self.displacement_sampling_frequency + f = scipy.interpolate.interp1d(times, self.unit_displacements, axis=0) + unit_displacements = f(temporal_bins) - if axes is None: - ax = fig.add_subplot(gs[7]) + # spatial interpolataion of units discplacement + if spatial_bins.shape[0] == 1: + # rigid + gt_motion = np.mean(unit_displacements, axis=1)[:, None] else: - ax = axes[2] - # plot_probe_map(self.recording, ax=ax) - _simpleaxis(ax) + gt_motion = np.zeros_like(raw_motion) + for t in range(temporal_bins.shape[0]): + f = scipy.interpolate.interp1d(self.unit_locations[:, self.direction_dim], unit_displacements[t, :], fill_value="extrapolate") + gt_motion[t, :] = f(spatial_bins) + # print("gt_motion", gt_motion.shape, raw_motion.shape) - ax.hist(self.gt_unit_positions[30, :], 50, orientation="horizontal", color="0.5") - ax.set_yticks([]) - ax.set_xlabel("# neurons") - - def plot_peaks_probe(self, alpha=0.05, figsize=(15, 10)): - fig, axs = plt.subplots(ncols=2, sharey=True, figsize=figsize) - ax = axs[0] - plot_probe_map(self.recording, ax=ax) - ax.scatter(self.peak_locations["x"], self.peak_locations["y"], color="k", s=1, alpha=alpha) - ax.set_xlabel("x") - ax.set_ylabel("y") - if "z" in self.peak_locations.dtype.fields: - ax = axs[1] - ax.scatter(self.peak_locations["z"], self.peak_locations["y"], color="k", s=1, alpha=alpha) - ax.set_xlabel("z") - ax.set_xlim(0, 100) - - def plot_peaks(self, scaling_probe=1.5, show_drift=True, show_histogram=True, alpha=0.05, figsize=(15, 10)): - fig = plt.figure(figsize=figsize) - if show_histogram: - gs = fig.add_gridspec(1, 4) - else: - gs = fig.add_gridspec(1, 3) - # Create the Axes. + # align globally gt_motion and motion to avoid offsets + motion = raw_motion.copy() + motion += np.median(gt_motion - motion) + self.result["gt_motion"] = gt_motion + self.result["motion"] = motion + + def save_run(self, folder): + for k in ("raw_motion", "temporal_bins", "spatial_bins"): + np.save(folder / f"{k}.npy", self.result[k]) + + for k in ('step_run_times', ): + with open(folder / f"{k}.pickle", mode="wb") as f: + pickle.dump(self.result[k], f) + + def save_result(self, folder): + for k in ("gt_motion", "motion"): + np.save(folder / f"{k}.npy", self.result[k]) + + @classmethod + def load_folder(cls, folder): + result = {} + # run + for k in ("raw_motion", "temporal_bins", "spatial_bins"): + result[k] = np.load(folder / f"{k}.npy") + + for k in ('step_run_times', ): + with open(folder / f"{k}.pickle", "rb") as f: + result[k] = pickle.load(f) + + # result + for k in ("gt_motion", "motion"): + file = folder / f"{k}.npy" + if file.exists(): + result[k] = np.load(file) + + return result + + + +class MotionEstimationStudy(BenchmarkStudy): + + benchmark_class = MotionEstimationBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + init_kwargs = self.cases[key]["init_kwargs"] + benchmark = MotionEstimationBenchmark(recording, gt_sorting, params, **init_kwargs) + return benchmark + + def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): + if case_keys is None: + case_keys = list(self.cases.keys()) + + for key in case_keys: + + bench = self.benchmarks[key] - ax0 = fig.add_subplot(gs[0]) - plot_probe_map(self.recording, ax=ax0) - _simpleaxis(ax0) + fig = plt.figure(figsize=figsize) + gs = fig.add_gridspec(1, 8, wspace=0) - # ymin, ymax = ax.get_ylim() - ax0.set_ylabel("depth (um)") - ax0.set_xlabel(None) + # probe and units + ax = ax0 = fig.add_subplot(gs[:2]) + plot_probe_map(bench.recording, ax=ax) + _simpleaxis(ax) + unit_locations = bench.unit_locations + ax.scatter(unit_locations[:, 0], unit_locations[:, 1], alpha=0.7, s=100) + ax.set_ylabel("depth (um)") + ax.set_xlabel(None) - ax = ax1 = fig.add_subplot(gs[1:3]) - x = self.selected_peaks["sample_index"] / self.recording.get_sampling_frequency() - y = self.peak_locations["y"] - ax.scatter(x, y, s=1, color="k", alpha=alpha) + ax.set_aspect('auto') - ax.set_title(self.title) - # xmin, xmax = ax.get_xlim() - # ax.plot([xmin, xmax], [probe_y_min, probe_y_min], 'k--', alpha=0.5) - # ax.plot([xmin, xmax], [probe_y_max, probe_y_max], 'k--', alpha=0.5) + # dirft + ax = ax1 = fig.add_subplot(gs[2:7]) + ax1.sharey(ax0) + temporal_bins = bench.result["temporal_bins"] + spatial_bins = bench.result["spatial_bins"] + gt_motion = bench.result["gt_motion"] - _simpleaxis(ax) - # ax.set_yticks([]) - # ax.set_ylim(scaling_probe*probe_y_min, scaling_probe*probe_y_max) - ax.spines["left"].set_visible(False) - ax.set_xlabel("time (s)") - channel_positions = self.recording.get_channel_locations() - probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() - ax.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) - - ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) - ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) - - if show_drift: - if self.spatial_bins is None: - center = (probe_y_min + probe_y_max) // 2 - ax.plot(self.temporal_bins, self.gt_motion[:, 0] + center, color="green", lw=1.5) - ax.plot(self.temporal_bins, self.motion[:, 0] + center, color="orange", lw=1.5) - else: - for i in range(self.gt_motion.shape[1]): - depth = self.spatial_bins[i] - ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=1.5) - ax.plot(self.temporal_bins, self.motion[:, i] + depth, color="orange", lw=1.5) - - if show_histogram: - ax2 = fig.add_subplot(gs[3]) - ax2.hist(self.peak_locations["y"], bins=1000, orientation="horizontal") - - ax2.axhline(probe_y_min, color="k", ls="--", alpha=0.5) - ax2.axhline(probe_y_max, color="k", ls="--", alpha=0.5) - - ax2.set_xlabel("density") - _simpleaxis(ax2) - # ax.set_ylabel('') + # for i in range(self.gt_unit_positions.shape[1]): + # ax.plot(temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + + for i in range(gt_motion.shape[1]): + depth = spatial_bins[i] + ax.plot(temporal_bins, gt_motion[:, i] + depth, color="green", lw=4) + ax.set_xlabel("time (s)") + _simpleaxis(ax) ax.set_yticks([]) - ax2.sharey(ax0) + ax.spines["left"].set_visible(False) - ax1.sharey(ax0) + channel_positions = bench.recording.get_channel_locations() + probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + # ax.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) - def plot_motion_corrected_peaks(self, scaling_probe=1.5, alpha=0.05, figsize=(15, 10), show_probe=True, axes=None): - if axes is None: - fig = plt.figure(figsize=figsize) - if show_probe: - gs = fig.add_gridspec(1, 5) - else: - gs = fig.add_gridspec(1, 4) - # Create the Axes. - - if show_probe: - if axes is None: - ax0 = ax = fig.add_subplot(gs[0]) - else: - ax0 = ax = axes[0] - plot_probe_map(self.recording, ax=ax) - _simpleaxis(ax) + ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) + ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) - ymin, ymax = ax.get_ylim() - ax.set_ylabel("depth (um)") - ax.set_xlabel(None) - channel_positions = self.recording.get_channel_locations() - probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + ax = ax2= fig.add_subplot(gs[7]) + ax2.sharey(ax0) + _simpleaxis(ax) + ax.hist(unit_locations[:, bench.direction_dim], bins=50, orientation="horizontal", color="0.5") + ax.set_yticks([]) + ax.set_xlabel("# neurons") - peak_locations_corrected = correct_motion_on_peaks( - self.selected_peaks, - self.peak_locations, - self.recording.sampling_frequency, - self.motion, - self.temporal_bins, - self.spatial_bins, - direction="y", - ) - if axes is None: - if show_probe: - ax1 = ax = fig.add_subplot(gs[1:3]) - else: - ax1 = ax = fig.add_subplot(gs[0:2]) - else: - if show_probe: - ax1 = ax = axes[1] - else: - ax1 = ax = axes[0] + label = self.cases[key]["label"] + ax1.set_title(label) - _simpleaxis(ax) + # ax0.set_ylim() - x = self.selected_peaks["sample_index"] / self.recording.get_sampling_frequency() - y = self.peak_locations["y"] - ax.scatter(x, y, s=1, color="k", alpha=alpha) - ax.set_title(self.title) + def plot_errors(self, case_keys=None, figsize=None, lim=None): - ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) - ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) + if case_keys is None: + case_keys = list(self.cases.keys()) - ax.set_xlabel("time (s)") + for key in case_keys: - if axes is None: - if show_probe: - ax2 = ax = fig.add_subplot(gs[3:5]) - else: - ax2 = ax = fig.add_subplot(gs[2:4]) - else: - if show_probe: - ax2 = ax = axes[2] - else: - ax2 = ax = axes[1] + bench = self.benchmarks[key] + label = self.cases[key]["label"] - _simpleaxis(ax) - y = peak_locations_corrected["y"] - ax.scatter(x, y, s=1, color="k", alpha=alpha) + gt_motion = bench.result["gt_motion"] + motion = bench.result["motion"] + temporal_bins = bench.result["temporal_bins"] + spatial_bins = bench.result["spatial_bins"] - ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) - ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) - ax.set_xlabel("time (s)") + fig = plt.figure(figsize=figsize) - if show_probe: - ax0.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) - ax1.sharey(ax0) - ax2.sharey(ax0) - else: - ax1.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) - ax2.sharey(ax1) - - def estimation_vs_depth(self, show_only=8, figsize=(15, 10)): - fig, axs = plt.subplots(ncols=2, figsize=figsize, sharey=True) - - n = self.motion.shape[1] - step = int(np.ceil(max(1, n / show_only))) - colors = plt.cm.get_cmap("jet", n) - for i in range(0, n, step): - ax = axs[0] - ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) - ax.plot( - self.temporal_bins, - self.motion[:, i], - lw=1.5, - ls="-", - color=colors(i), - label=f"{self.spatial_bins[i]:0.1f}", - ) + gs = fig.add_gridspec(2, 2) - ax = axs[1] - ax.plot(self.temporal_bins, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) + errors = gt_motion - motion - ax = axs[0] - ax.set_title(self.title) - ax.legend() - ax.set_ylabel("drift estimated and GT(um)") - ax.set_xlabel("time (s)") - _simpleaxis(ax) + channel_positions = bench.recording.get_channel_locations() + probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() - ax = axs[1] - ax.set_ylabel("error (um)") - ax.set_xlabel("time (s)") - _simpleaxis(ax) + ax = fig.add_subplot(gs[0, :]) + im = ax.imshow( + np.abs(errors).T, + aspect="auto", + interpolation="nearest", + origin="lower", + extent=(temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1]), + ) + plt.colorbar(im, ax=ax, label="error") + ax.set_ylabel("depth (um)") + ax.set_xlabel("time (s)") + ax.set_title(label) + if lim is not None: + im.set_clim(0, lim) + + ax = fig.add_subplot(gs[1, 0]) + mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) + ax.plot(temporal_bins, mean_error) + ax.set_xlabel("time (s)") + ax.set_ylabel("error") + _simpleaxis(ax) + if lim is not None: + ax.set_ylim(0, lim) + + ax = fig.add_subplot(gs[1, 1]) + depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) + ax.plot(spatial_bins, depth_error) + ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) + ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) + ax.set_xlabel("depth (um)") + ax.set_ylabel("error") + _simpleaxis(ax) + if lim is not None: + ax.set_ylim(0, lim) - def view_errors(self, figsize=(15, 10), lim=None): - fig = plt.figure(figsize=figsize) - gs = fig.add_gridspec(2, 2) + def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, figsize=(15, 5)): - errors = self.gt_motion - self.motion + if case_keys is None: + case_keys = list(self.cases.keys()) - channel_positions = self.recording.get_channel_locations() - probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + fig, axes = plt.subplots(1, 3, figsize=figsize) + + for count, key in enumerate(case_keys): - ax = fig.add_subplot(gs[0, :]) - im = ax.imshow( - np.abs(errors).T, - aspect="auto", - interpolation="nearest", - origin="lower", - extent=(self.temporal_bins[0], self.temporal_bins[-1], self.spatial_bins[0], self.spatial_bins[-1]), - ) - plt.colorbar(im, ax=ax, label="error") - ax.set_ylabel("depth (um)") - ax.set_xlabel("time (s)") - ax.set_title(self.title) - if lim is not None: - im.set_clim(0, lim) - - ax = fig.add_subplot(gs[1, 0]) - mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) - ax.plot(self.temporal_bins, mean_error) - ax.set_xlabel("time (s)") - ax.set_ylabel("error") - _simpleaxis(ax) - if lim is not None: - ax.set_ylim(0, lim) - - ax = fig.add_subplot(gs[1, 1]) - depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - ax.plot(self.spatial_bins, depth_error) - ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) - ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) - ax.set_xlabel("depth (um)") - ax.set_ylabel("error") - _simpleaxis(ax) - if lim is not None: - ax.set_ylim(0, lim) - - return fig - - -def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colors=None): - if axes is None: - fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - - for count, benchmark in enumerate(benchmarks): - c = colors[count] if colors is not None else None - errors = benchmark.gt_motion - benchmark.motion - mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) - depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - - axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) - parts = axes[1].violinplot(mean_error, [count], showmeans=True) - if c is not None: - for pc in parts["bodies"]: - pc.set_facecolor(c) - pc.set_edgecolor(c) - for k in parts: - if k != "bodies": - # for line in parts[k]: - parts[k].set_color(c) - axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) - - ax0 = ax = axes[0] - ax.set_xlabel("Time [s]") - ax.set_ylabel("Error [μm]") - if show_legend: - ax.legend() - _simpleaxis(ax) - - ax1 = axes[1] - # ax.set_ylabel('error') - ax1.set_yticks([]) - ax1.set_xticks([]) - _simpleaxis(ax1) - - ax2 = axes[2] - ax2.set_yticks([]) - ax2.set_xlabel("Depth [μm]") - # ax.set_ylabel('error') - channel_positions = benchmark.recording.get_channel_locations() - probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() - ax2.axvline(probe_y_min, color="k", ls="--", alpha=0.5) - ax2.axvline(probe_y_max, color="k", ls="--", alpha=0.5) - - _simpleaxis(ax2) - - # ax1.sharey(ax0) - # ax2.sharey(ax0) - - -def plot_error_map_several_benchmarks(benchmarks, axes=None, lim=15, figsize=(10, 10)): - if axes is None: - fig, axes = plt.subplots(nrows=len(benchmarks), sharex=True, sharey=True, figsize=figsize) - else: - fig = axes[0].figure - - for count, benchmark in enumerate(benchmarks): - errors = benchmark.gt_motion - benchmark.motion - - channel_positions = benchmark.recording.get_channel_locations() - probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + bench = self.benchmarks[key] + label = self.cases[key]["label"] - ax = axes[count] - im = ax.imshow( - np.abs(errors).T, - aspect="auto", - interpolation="nearest", - origin="lower", - extent=( - benchmark.temporal_bins[0], - benchmark.temporal_bins[-1], - benchmark.spatial_bins[0], - benchmark.spatial_bins[-1], - ), - ) - fig.colorbar(im, ax=ax, label="error") - ax.set_ylabel("depth (um)") + gt_motion = bench.result["gt_motion"] + motion = bench.result["motion"] + temporal_bins = bench.result["temporal_bins"] + spatial_bins = bench.result["spatial_bins"] - ax.set_title(benchmark.title) - if lim is not None: - im.set_clim(0, lim) - axes[-1].set_xlabel("time (s)") - return fig + c = colors[count] if colors is not None else None + errors = gt_motion - motion + mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) + depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) -def plot_motions_several_benchmarks(benchmarks): - fig, ax = plt.subplots(figsize=(15, 5)) + axes[0].plot(temporal_bins, mean_error, lw=1, label=label, color=c) + parts = axes[1].violinplot(mean_error, [count], showmeans=True) + if c is not None: + for pc in parts["bodies"]: + pc.set_facecolor(c) + pc.set_edgecolor(c) + for k in parts: + if k != "bodies": + # for line in parts[k]: + parts[k].set_color(c) + axes[2].plot(spatial_bins, depth_error, label=label, color=c) - ax.plot(list(benchmarks)[0].temporal_bins, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") - for count, benchmark in enumerate(benchmarks): - ax.plot(benchmark.temporal_bins, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) - ax.fill_between( - benchmark.temporal_bins, - benchmark.motion.mean(1) - benchmark.motion.std(1), - benchmark.motion.mean(1) + benchmark.motion.std(1), - color=f"C{count}", - alpha=0.25, - ) + ax0 = ax = axes[0] + ax.set_xlabel("Time [s]") + ax.set_ylabel("Error [μm]") + if show_legend: + ax.legend() + _simpleaxis(ax) - # ax.legend() - ax.set_ylabel("depth (um)") - ax.set_xlabel("time (s)") - _simpleaxis(ax) - - -def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): - if ax is None: - fig, ax = plt.subplots(figsize=(5, 5)) - - for count, benchmark in enumerate(benchmarks): - color = colors[count] if colors is not None else None - - if detailed: - bottom = 0 - i = 0 - patterns = ["/", "\\", "|", "*"] - for key, value in benchmark.run_times.items(): - if count == 0: - label = key.replace("_", " ") - else: - label = None - ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) - bottom += value - i += 1 - else: - total_run_time = np.sum([value for key, value in benchmark.run_times.items()]) - ax.bar([count], [total_run_time], color=color, edgecolor="black") - - # ax.legend() - ax.set_ylabel("speed (s)") - _simpleaxis(ax) - ax.set_xticks([]) - # ax.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) + ax1 = axes[1] + # ax.set_ylabel('error') + ax1.set_yticks([]) + ax1.set_xticks([]) + _simpleaxis(ax1) + + ax2 = axes[2] + ax2.set_yticks([]) + ax2.set_xlabel("Depth [μm]") + # ax.set_ylabel('error') + channel_positions = bench.recording.get_channel_locations() + probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + ax2.axvline(probe_y_min, color="k", ls="--", alpha=0.5) + ax2.axvline(probe_y_max, color="k", ls="--", alpha=0.5) + + _simpleaxis(ax2) + + # ax1.sharey(ax0) + # ax2.sharey(ax0) + + +# TODO : plot_peaks +# TODO : plot_motion_corrected_peaks +# TODO : plot_error_map_several_benchmarks + + +# class BenchmarkMotionEstimationMearec(BenchmarkBase): +# _array_names = ( +# "noise_levels", +# "gt_unit_positions", +# "peaks", +# "selected_peaks", +# "motion", +# "temporal_bins", +# "spatial_bins", +# "peak_locations", +# "gt_motion", +# ) + +# def __init__( +# self, +# mearec_filename, +# title="", +# detect_kwargs={}, +# select_kwargs=None, +# localize_kwargs={}, +# estimate_motion_kwargs={}, +# folder=None, +# do_preprocessing=True, +# job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, +# overwrite=False, +# parent_benchmark=None, +# ): +# BenchmarkBase.__init__( +# self, folder=folder, title=title, overwrite=overwrite, job_kwargs=job_kwargs, parent_benchmark=None +# ) + +# self._args.extend([str(mearec_filename)]) + +# self.mearec_filename = mearec_filename +# self.raw_recording, self.gt_sorting = read_mearec(self.mearec_filename) +# self.do_preprocessing = do_preprocessing + +# self._recording = None +# self.detect_kwargs = detect_kwargs.copy() +# self.select_kwargs = select_kwargs.copy() if select_kwargs is not None else None +# self.localize_kwargs = localize_kwargs.copy() +# self.estimate_motion_kwargs = estimate_motion_kwargs.copy() + +# self._kwargs.update( +# dict( +# detect_kwargs=self.detect_kwargs, +# select_kwargs=self.select_kwargs, +# localize_kwargs=self.localize_kwargs, +# estimate_motion_kwargs=self.estimate_motion_kwargs, +# ) +# ) + +# @property +# def recording(self): +# if self._recording is None: +# if self.do_preprocessing: +# self._recording = bandpass_filter(self.raw_recording) +# self._recording = common_reference(self._recording) +# self._recording = zscore(self._recording) +# else: +# self._recording = self.raw_recording +# return self._recording + +# def run(self): +# if self.folder is not None: +# if self.folder.exists() and not self.overwrite: +# raise ValueError(f"The folder {self.folder} is not empty") + +# self.noise_levels = get_noise_levels(self.recording, return_scaled=False) + +# t0 = time.perf_counter() +# self.peaks = detect_peaks( +# self.recording, noise_levels=self.noise_levels, **self.detect_kwargs, **self.job_kwargs +# ) +# t1 = time.perf_counter() +# if self.select_kwargs is not None: +# self.selected_peaks = select_peaks(self.peaks, **self.select_kwargs, **self.job_kwargs) +# else: +# self.selected_peaks = self.peaks +# t2 = time.perf_counter() +# self.peak_locations = localize_peaks( +# self.recording, self.selected_peaks, **self.localize_kwargs, **self.job_kwargs +# ) +# t3 = time.perf_counter() +# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs +# ) + +# t4 = time.perf_counter() + +# self.run_times = dict( +# detect_peaks=t1 - t0, +# select_peaks=t2 - t1, +# localize_peaks=t3 - t2, +# estimate_motion=t4 - t3, +# ) + +# self.compute_gt_motion() + +# # align globally gt_motion and motion to avoid offsets +# self.motion += np.median(self.gt_motion - self.motion) + +# ## save folder +# if self.folder is not None: +# self.save_to_folder() + +# def run_estimate_motion(self): +# # usefull to re run only the motion estimate with peak localization +# t3 = time.perf_counter() +# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs +# ) +# t4 = time.perf_counter() + +# self.compute_gt_motion() + +# # align globally gt_motion and motion to avoid offsets +# self.motion += np.median(self.gt_motion - self.motion) +# self.run_times["estimate_motion"] = t4 - t3 + +# ## save folder +# if self.folder is not None: +# self.save_to_folder() + +# def compute_gt_motion(self): +# self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins) + +# template_locations = np.array(mr.load_recordings(self.mearec_filename).template_locations) +# assert len(template_locations.shape) == 3 +# mid = template_locations.shape[1] // 2 +# unit_mid_positions = template_locations[:, mid, 2] + +# unit_motions = self.gt_unit_positions - unit_mid_positions +# # unit_positions = np.mean(self.gt_unit_positions, axis=0) + +# if self.spatial_bins is None: +# self.gt_motion = np.mean(unit_motions, axis=1)[:, None] +# channel_positions = self.recording.get_channel_locations() +# probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() +# center = (probe_y_min + probe_y_max) // 2 +# self.spatial_bins = np.array([center]) +# else: +# # time, units +# self.gt_motion = np.zeros_like(self.motion) +# for t in range(self.gt_unit_positions.shape[0]): +# f = scipy.interpolate.interp1d(unit_mid_positions, unit_motions[t, :], fill_value="extrapolate") +# self.gt_motion[t, :] = f(self.spatial_bins) + +# def plot_true_drift(self, scaling_probe=1.5, figsize=(15, 10), axes=None): +# if axes is None: +# fig = plt.figure(figsize=figsize) +# gs = fig.add_gridspec(1, 8, wspace=0) + +# if axes is None: +# ax = fig.add_subplot(gs[:2]) +# else: +# ax = axes[0] +# plot_probe_map(self.recording, ax=ax) +# _simpleaxis(ax) + +# mr_recording = mr.load_recordings(self.mearec_filename) + +# for loc in mr_recording.template_locations[::2]: +# if len(mr_recording.template_locations.shape) == 3: +# ax.plot([loc[0, 1], loc[-1, 1]], [loc[0, 2], loc[-1, 2]], alpha=0.7, lw=2) +# else: +# ax.scatter([loc[1]], [loc[2]], alpha=0.7, s=100) + +# # ymin, ymax = ax.get_ylim() +# ax.set_ylabel("depth (um)") +# ax.set_xlabel(None) +# # ax.set_yticks(np.arange(-600,600,100), np.arange(-600,600,100)) + +# # ax.set_ylim(scaling_probe*probe_y_min, scaling_probe*probe_y_max) +# if axes is None: +# ax = fig.add_subplot(gs[2:7]) +# else: +# ax = axes[1] + +# for i in range(self.gt_unit_positions.shape[1]): +# ax.plot(self.temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + +# for i in range(self.gt_motion.shape[1]): +# depth = self.spatial_bins[i] +# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=4) + +# # ax.set_ylim(ymin, ymax) +# ax.set_xlabel("time (s)") +# _simpleaxis(ax) +# ax.set_yticks([]) +# ax.spines["left"].set_visible(False) + +# channel_positions = self.recording.get_channel_locations() +# probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() +# ax.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) + +# ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) +# ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) + +# if axes is None: +# ax = fig.add_subplot(gs[7]) +# else: +# ax = axes[2] +# # plot_probe_map(self.recording, ax=ax) +# _simpleaxis(ax) + +# ax.hist(self.gt_unit_positions[30, :], 50, orientation="horizontal", color="0.5") +# ax.set_yticks([]) +# ax.set_xlabel("# neurons") + +# def plot_peaks_probe(self, alpha=0.05, figsize=(15, 10)): +# fig, axs = plt.subplots(ncols=2, sharey=True, figsize=figsize) +# ax = axs[0] +# plot_probe_map(self.recording, ax=ax) +# ax.scatter(self.peak_locations["x"], self.peak_locations["y"], color="k", s=1, alpha=alpha) +# ax.set_xlabel("x") +# ax.set_ylabel("y") +# if "z" in self.peak_locations.dtype.fields: +# ax = axs[1] +# ax.scatter(self.peak_locations["z"], self.peak_locations["y"], color="k", s=1, alpha=alpha) +# ax.set_xlabel("z") +# ax.set_xlim(0, 100) + +# def plot_peaks(self, scaling_probe=1.5, show_drift=True, show_histogram=True, alpha=0.05, figsize=(15, 10)): +# fig = plt.figure(figsize=figsize) +# if show_histogram: +# gs = fig.add_gridspec(1, 4) +# else: +# gs = fig.add_gridspec(1, 3) +# # Create the Axes. + +# ax0 = fig.add_subplot(gs[0]) +# plot_probe_map(self.recording, ax=ax0) +# _simpleaxis(ax0) + +# # ymin, ymax = ax.get_ylim() +# ax0.set_ylabel("depth (um)") +# ax0.set_xlabel(None) + +# ax = ax1 = fig.add_subplot(gs[1:3]) +# x = self.selected_peaks["sample_index"] / self.recording.get_sampling_frequency() +# y = self.peak_locations["y"] +# ax.scatter(x, y, s=1, color="k", alpha=alpha) + +# ax.set_title(self.title) +# # xmin, xmax = ax.get_xlim() +# # ax.plot([xmin, xmax], [probe_y_min, probe_y_min], 'k--', alpha=0.5) +# # ax.plot([xmin, xmax], [probe_y_max, probe_y_max], 'k--', alpha=0.5) + +# _simpleaxis(ax) +# # ax.set_yticks([]) +# # ax.set_ylim(scaling_probe*probe_y_min, scaling_probe*probe_y_max) +# ax.spines["left"].set_visible(False) +# ax.set_xlabel("time (s)") + +# channel_positions = self.recording.get_channel_locations() +# probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() +# ax.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) + +# ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) +# ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) + +# if show_drift: +# if self.spatial_bins is None: +# center = (probe_y_min + probe_y_max) // 2 +# ax.plot(self.temporal_bins, self.gt_motion[:, 0] + center, color="green", lw=1.5) +# ax.plot(self.temporal_bins, self.motion[:, 0] + center, color="orange", lw=1.5) +# else: +# for i in range(self.gt_motion.shape[1]): +# depth = self.spatial_bins[i] +# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=1.5) +# ax.plot(self.temporal_bins, self.motion[:, i] + depth, color="orange", lw=1.5) + +# if show_histogram: +# ax2 = fig.add_subplot(gs[3]) +# ax2.hist(self.peak_locations["y"], bins=1000, orientation="horizontal") + +# ax2.axhline(probe_y_min, color="k", ls="--", alpha=0.5) +# ax2.axhline(probe_y_max, color="k", ls="--", alpha=0.5) + +# ax2.set_xlabel("density") +# _simpleaxis(ax2) +# # ax.set_ylabel('') +# ax.set_yticks([]) +# ax2.sharey(ax0) + +# ax1.sharey(ax0) + +# def plot_motion_corrected_peaks(self, scaling_probe=1.5, alpha=0.05, figsize=(15, 10), show_probe=True, axes=None): +# if axes is None: +# fig = plt.figure(figsize=figsize) +# if show_probe: +# gs = fig.add_gridspec(1, 5) +# else: +# gs = fig.add_gridspec(1, 4) +# # Create the Axes. + +# if show_probe: +# if axes is None: +# ax0 = ax = fig.add_subplot(gs[0]) +# else: +# ax0 = ax = axes[0] +# plot_probe_map(self.recording, ax=ax) +# _simpleaxis(ax) + +# ymin, ymax = ax.get_ylim() +# ax.set_ylabel("depth (um)") +# ax.set_xlabel(None) + +# channel_positions = self.recording.get_channel_locations() +# probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + +# peak_locations_corrected = correct_motion_on_peaks( +# self.selected_peaks, +# self.peak_locations, +# self.recording.sampling_frequency, +# self.motion, +# self.temporal_bins, +# self.spatial_bins, +# direction="y", +# ) +# if axes is None: +# if show_probe: +# ax1 = ax = fig.add_subplot(gs[1:3]) +# else: +# ax1 = ax = fig.add_subplot(gs[0:2]) +# else: +# if show_probe: +# ax1 = ax = axes[1] +# else: +# ax1 = ax = axes[0] + +# _simpleaxis(ax) + +# x = self.selected_peaks["sample_index"] / self.recording.get_sampling_frequency() +# y = self.peak_locations["y"] +# ax.scatter(x, y, s=1, color="k", alpha=alpha) +# ax.set_title(self.title) + +# ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) +# ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) + +# ax.set_xlabel("time (s)") + +# if axes is None: +# if show_probe: +# ax2 = ax = fig.add_subplot(gs[3:5]) +# else: +# ax2 = ax = fig.add_subplot(gs[2:4]) +# else: +# if show_probe: +# ax2 = ax = axes[2] +# else: +# ax2 = ax = axes[1] + +# _simpleaxis(ax) +# y = peak_locations_corrected["y"] +# ax.scatter(x, y, s=1, color="k", alpha=alpha) + +# ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) +# ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) + +# ax.set_xlabel("time (s)") + +# if show_probe: +# ax0.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) +# ax1.sharey(ax0) +# ax2.sharey(ax0) +# else: +# ax1.set_ylim(scaling_probe * probe_y_min, scaling_probe * probe_y_max) +# ax2.sharey(ax1) + +# def estimation_vs_depth(self, show_only=8, figsize=(15, 10)): +# fig, axs = plt.subplots(ncols=2, figsize=figsize, sharey=True) + +# n = self.motion.shape[1] +# step = int(np.ceil(max(1, n / show_only))) +# colors = plt.cm.get_cmap("jet", n) +# for i in range(0, n, step): +# ax = axs[0] +# ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) +# ax.plot( +# self.temporal_bins, +# self.motion[:, i], +# lw=1.5, +# ls="-", +# color=colors(i), +# label=f"{self.spatial_bins[i]:0.1f}", +# ) + +# ax = axs[1] +# ax.plot(self.temporal_bins, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) + +# ax = axs[0] +# ax.set_title(self.title) +# ax.legend() +# ax.set_ylabel("drift estimated and GT(um)") +# ax.set_xlabel("time (s)") +# _simpleaxis(ax) + +# ax = axs[1] +# ax.set_ylabel("error (um)") +# ax.set_xlabel("time (s)") +# _simpleaxis(ax) + +# def view_errors(self, figsize=(15, 10), lim=None): +# fig = plt.figure(figsize=figsize) +# gs = fig.add_gridspec(2, 2) + +# errors = self.gt_motion - self.motion + +# channel_positions = self.recording.get_channel_locations() +# probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + +# ax = fig.add_subplot(gs[0, :]) +# im = ax.imshow( +# np.abs(errors).T, +# aspect="auto", +# interpolation="nearest", +# origin="lower", +# extent=(self.temporal_bins[0], self.temporal_bins[-1], self.spatial_bins[0], self.spatial_bins[-1]), +# ) +# plt.colorbar(im, ax=ax, label="error") +# ax.set_ylabel("depth (um)") +# ax.set_xlabel("time (s)") +# ax.set_title(self.title) +# if lim is not None: +# im.set_clim(0, lim) + +# ax = fig.add_subplot(gs[1, 0]) +# mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) +# ax.plot(self.temporal_bins, mean_error) +# ax.set_xlabel("time (s)") +# ax.set_ylabel("error") +# _simpleaxis(ax) +# if lim is not None: +# ax.set_ylim(0, lim) + +# ax = fig.add_subplot(gs[1, 1]) +# depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) +# ax.plot(self.spatial_bins, depth_error) +# ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) +# ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) +# ax.set_xlabel("depth (um)") +# ax.set_ylabel("error") +# _simpleaxis(ax) +# if lim is not None: +# ax.set_ylim(0, lim) + +# return fig + + +# def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colors=None): +# if axes is None: +# fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + +# for count, benchmark in enumerate(benchmarks): +# c = colors[count] if colors is not None else None +# errors = benchmark.gt_motion - benchmark.motion +# mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) +# depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) + +# axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) +# parts = axes[1].violinplot(mean_error, [count], showmeans=True) +# if c is not None: +# for pc in parts["bodies"]: +# pc.set_facecolor(c) +# pc.set_edgecolor(c) +# for k in parts: +# if k != "bodies": +# # for line in parts[k]: +# parts[k].set_color(c) +# axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) + +# ax0 = ax = axes[0] +# ax.set_xlabel("Time [s]") +# ax.set_ylabel("Error [μm]") +# if show_legend: +# ax.legend() +# _simpleaxis(ax) + +# ax1 = axes[1] +# # ax.set_ylabel('error') +# ax1.set_yticks([]) +# ax1.set_xticks([]) +# _simpleaxis(ax1) + +# ax2 = axes[2] +# ax2.set_yticks([]) +# ax2.set_xlabel("Depth [μm]") +# # ax.set_ylabel('error') +# channel_positions = benchmark.recording.get_channel_locations() +# probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() +# ax2.axvline(probe_y_min, color="k", ls="--", alpha=0.5) +# ax2.axvline(probe_y_max, color="k", ls="--", alpha=0.5) + +# _simpleaxis(ax2) + +# # ax1.sharey(ax0) +# # ax2.sharey(ax0) + + +# def plot_error_map_several_benchmarks(benchmarks, axes=None, lim=15, figsize=(10, 10)): +# if axes is None: +# fig, axes = plt.subplots(nrows=len(benchmarks), sharex=True, sharey=True, figsize=figsize) +# else: +# fig = axes[0].figure + +# for count, benchmark in enumerate(benchmarks): +# errors = benchmark.gt_motion - benchmark.motion + +# channel_positions = benchmark.recording.get_channel_locations() +# probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() + +# ax = axes[count] +# im = ax.imshow( +# np.abs(errors).T, +# aspect="auto", +# interpolation="nearest", +# origin="lower", +# extent=( +# benchmark.temporal_bins[0], +# benchmark.temporal_bins[-1], +# benchmark.spatial_bins[0], +# benchmark.spatial_bins[-1], +# ), +# ) +# fig.colorbar(im, ax=ax, label="error") +# ax.set_ylabel("depth (um)") + +# ax.set_title(benchmark.title) +# if lim is not None: +# im.set_clim(0, lim) + +# axes[-1].set_xlabel("time (s)") + +# return fig + + +# def plot_motions_several_benchmarks(benchmarks): +# fig, ax = plt.subplots(figsize=(15, 5)) + +# ax.plot(list(benchmarks)[0].temporal_bins, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") +# for count, benchmark in enumerate(benchmarks): +# ax.plot(benchmark.temporal_bins, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) +# ax.fill_between( +# benchmark.temporal_bins, +# benchmark.motion.mean(1) - benchmark.motion.std(1), +# benchmark.motion.mean(1) + benchmark.motion.std(1), +# color=f"C{count}", +# alpha=0.25, +# ) + +# # ax.legend() +# ax.set_ylabel("depth (um)") +# ax.set_xlabel("time (s)") +# _simpleaxis(ax) + + +# def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): +# if ax is None: +# fig, ax = plt.subplots(figsize=(5, 5)) + +# for count, benchmark in enumerate(benchmarks): +# color = colors[count] if colors is not None else None + +# if detailed: +# bottom = 0 +# i = 0 +# patterns = ["/", "\\", "|", "*"] +# for key, value in benchmark.run_times.items(): +# if count == 0: +# label = key.replace("_", " ") +# else: +# label = None +# ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) +# bottom += value +# i += 1 +# else: +# total_run_time = np.sum([value for key, value in benchmark.run_times.items()]) +# ax.bar([count], [total_run_time], color=color, edgecolor="black") + +# # ax.legend() +# ax.set_ylabel("speed (s)") +# _simpleaxis(ax) +# ax.set_xticks([]) +# # ax.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index b58b831b17..43a4bf4586 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -141,7 +141,7 @@ def remove_benchmark(self, key): def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): if case_keys is None: - case_keys = self.cases.keys() + case_keys = list(self.cases.keys()) job_keys = [] for key in case_keys: @@ -158,7 +158,9 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): benchmark = self.create_benchmark(key) benchmark.run() self.benchmarks[key] = benchmark - benchmark.save_run(self.folder / "results" / self.key_to_str(key)) + bench_folder = self.folder / "results" / self.key_to_str(key) + bench_folder.mkdir(exist_ok=True) + benchmark.save_run(bench_folder) def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: @@ -253,6 +255,9 @@ def get_result(self, key): class Benchmark: """ """ + def __init__(self): + self.result = {} + @classmethod def load_folder(cls, folder): raise NotImplementedError @@ -277,145 +282,145 @@ def _simpleaxis(ax): ax.get_yaxis().tick_left() -class BenchmarkBaseOld: - _array_names = () - _waveform_names = () - _sorting_names = () - - _array_names_from_parent = () - _waveform_names_from_parent = () - _sorting_names_from_parent = () - - def __init__( - self, - folder=None, - title="", - overwrite=None, - job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, - parent_benchmark=None, - ): - self.folder = Path(folder) - self.title = title - self.overwrite = overwrite - self.job_kwargs = job_kwargs - self.run_times = None - - self._args = [] - self._kwargs = dict(title=title, overwrite=overwrite, job_kwargs=job_kwargs) - - self.waveforms = {} - self.sortings = {} - - self.parent_benchmark = parent_benchmark - - if self.parent_benchmark is not None: - for name in self._array_names_from_parent: - setattr(self, name, getattr(parent_benchmark, name)) - - for name in self._waveform_names_from_parent: - self.waveforms[name] = parent_benchmark.waveforms[name] - - for key in parent_benchmark.sortings.keys(): - if isinstance(key, str) and key in self._sorting_names_from_parent: - self.sortings[key] = parent_benchmark.sortings[key] - elif isinstance(key, tuple) and key[0] in self._sorting_names_from_parent: - self.sortings[key] = parent_benchmark.sortings[key] - - def save_to_folder(self): - if self.folder.exists(): - import glob, os - - pattern = "*.*" - files = self.folder.glob(pattern) - for file in files: - if file.is_file(): - os.remove(file) - else: - self.folder.mkdir(parents=True) - - if self.parent_benchmark is None: - parent_folder = None - else: - parent_folder = str(self.parent_benchmark.folder) - - info = { - "args": self._args, - "kwargs": self._kwargs, - "parent_folder": parent_folder, - } - info = check_json(info) - (self.folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - - for name in self._array_names: - if self.parent_benchmark is not None and name in self._array_names_from_parent: - continue - value = getattr(self, name) - if value is not None: - np.save(self.folder / f"{name}.npy", value) - - if self.run_times is not None: - run_times_filename = self.folder / "run_times.json" - run_times_filename.write_text(json.dumps(self.run_times, indent=4), encoding="utf8") - - for key, sorting in self.sortings.items(): - (self.folder / "sortings").mkdir(exist_ok=True) - if isinstance(key, str): - npz_file = self.folder / "sortings" / (str(key) + ".npz") - elif isinstance(key, tuple): - npz_file = self.folder / "sortings" / ("_###_".join(key) + ".npz") - NpzSortingExtractor.write_sorting(sorting, npz_file) - - @classmethod - def load_from_folder(cls, folder, parent_benchmark=None): - folder = Path(folder) - assert folder.exists() - - with open(folder / "info.json", "r") as f: - info = json.load(f) - args = info["args"] - kwargs = info["kwargs"] - - if info["parent_folder"] is None: - parent_benchmark = None - else: - if parent_benchmark is None: - parent_benchmark = cls.load_from_folder(info["parent_folder"]) - - import os - - kwargs["folder"] = folder - - bench = cls(*args, **kwargs, parent_benchmark=parent_benchmark) - - for name in cls._array_names: - filename = folder / f"{name}.npy" - if filename.exists(): - arr = np.load(filename) - else: - arr = None - setattr(bench, name, arr) - - if (folder / "run_times.json").exists(): - with open(folder / "run_times.json", "r") as f: - bench.run_times = json.load(f) - else: - bench.run_times = None - - for key in bench._waveform_names: - if parent_benchmark is not None and key in bench._waveform_names_from_parent: - continue - waveforms_folder = folder / "waveforms" / key - if waveforms_folder.exists(): - bench.waveforms[key] = load_waveforms(waveforms_folder, with_recording=True) - - sorting_folder = folder / "sortings" - if sorting_folder.exists(): - for npz_file in sorting_folder.glob("*.npz"): - name = npz_file.stem - if "_###_" in name: - key = tuple(name.split("_###_")) - else: - key = name - bench.sortings[key] = NpzSortingExtractor(npz_file) - - return bench +# class BenchmarkBaseOld: +# _array_names = () +# _waveform_names = () +# _sorting_names = () + +# _array_names_from_parent = () +# _waveform_names_from_parent = () +# _sorting_names_from_parent = () + +# def __init__( +# self, +# folder=None, +# title="", +# overwrite=None, +# job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, +# parent_benchmark=None, +# ): +# self.folder = Path(folder) +# self.title = title +# self.overwrite = overwrite +# self.job_kwargs = job_kwargs +# self.run_times = None + +# self._args = [] +# self._kwargs = dict(title=title, overwrite=overwrite, job_kwargs=job_kwargs) + +# self.waveforms = {} +# self.sortings = {} + +# self.parent_benchmark = parent_benchmark + +# if self.parent_benchmark is not None: +# for name in self._array_names_from_parent: +# setattr(self, name, getattr(parent_benchmark, name)) + +# for name in self._waveform_names_from_parent: +# self.waveforms[name] = parent_benchmark.waveforms[name] + +# for key in parent_benchmark.sortings.keys(): +# if isinstance(key, str) and key in self._sorting_names_from_parent: +# self.sortings[key] = parent_benchmark.sortings[key] +# elif isinstance(key, tuple) and key[0] in self._sorting_names_from_parent: +# self.sortings[key] = parent_benchmark.sortings[key] + +# def save_to_folder(self): +# if self.folder.exists(): +# import glob, os + +# pattern = "*.*" +# files = self.folder.glob(pattern) +# for file in files: +# if file.is_file(): +# os.remove(file) +# else: +# self.folder.mkdir(parents=True) + +# if self.parent_benchmark is None: +# parent_folder = None +# else: +# parent_folder = str(self.parent_benchmark.folder) + +# info = { +# "args": self._args, +# "kwargs": self._kwargs, +# "parent_folder": parent_folder, +# } +# info = check_json(info) +# (self.folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") + +# for name in self._array_names: +# if self.parent_benchmark is not None and name in self._array_names_from_parent: +# continue +# value = getattr(self, name) +# if value is not None: +# np.save(self.folder / f"{name}.npy", value) + +# if self.run_times is not None: +# run_times_filename = self.folder / "run_times.json" +# run_times_filename.write_text(json.dumps(self.run_times, indent=4), encoding="utf8") + +# for key, sorting in self.sortings.items(): +# (self.folder / "sortings").mkdir(exist_ok=True) +# if isinstance(key, str): +# npz_file = self.folder / "sortings" / (str(key) + ".npz") +# elif isinstance(key, tuple): +# npz_file = self.folder / "sortings" / ("_###_".join(key) + ".npz") +# NpzSortingExtractor.write_sorting(sorting, npz_file) + +# @classmethod +# def load_from_folder(cls, folder, parent_benchmark=None): +# folder = Path(folder) +# assert folder.exists() + +# with open(folder / "info.json", "r") as f: +# info = json.load(f) +# args = info["args"] +# kwargs = info["kwargs"] + +# if info["parent_folder"] is None: +# parent_benchmark = None +# else: +# if parent_benchmark is None: +# parent_benchmark = cls.load_from_folder(info["parent_folder"]) + +# import os + +# kwargs["folder"] = folder + +# bench = cls(*args, **kwargs, parent_benchmark=parent_benchmark) + +# for name in cls._array_names: +# filename = folder / f"{name}.npy" +# if filename.exists(): +# arr = np.load(filename) +# else: +# arr = None +# setattr(bench, name, arr) + +# if (folder / "run_times.json").exists(): +# with open(folder / "run_times.json", "r") as f: +# bench.run_times = json.load(f) +# else: +# bench.run_times = None + +# for key in bench._waveform_names: +# if parent_benchmark is not None and key in bench._waveform_names_from_parent: +# continue +# waveforms_folder = folder / "waveforms" / key +# if waveforms_folder.exists(): +# bench.waveforms[key] = load_waveforms(waveforms_folder, with_recording=True) + +# sorting_folder = folder / "sortings" +# if sorting_folder.exists(): +# for npz_file in sorting_folder.glob("*.npz"): +# name = npz_file.stem +# if "_###_" in name: +# key = tuple(name.split("_###_")) +# else: +# key = name +# bench.sortings[key] = NpzSortingExtractor(npz_file) + +# return bench From 578c954354955f482f9f84b421f2f063422dfe7a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 21 Feb 2024 14:17:35 +0100 Subject: [PATCH 116/192] WIP --- .../benchmark/benchmark_clustering.py | 60 +++++++++++++++++++ .../benchmark/benchmark_matching.py | 6 +- .../benchmark/benchmark_tools.py | 2 +- 3 files changed, 65 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index f1bfa35959..220dd3258f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -27,6 +27,66 @@ import numpy as np +from .benchmark_tools import BenchmarkStudy, Benchmark +from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer + + +class ClusteringBenchmark(Benchmark): + + def __init__(self, recording, gt_sorting, gt_positions, params): + self.recording = recording + self.gt_sorting = gt_sorting + self.gt_positions = gt_positions + self.params = params + self.result = {} + self.templates_params = {} + for key in ["ms_before", "ms_after"]: + if key in self.params: + self.templates_params[key] = self.params[key] + else: + self.templates_params[key] = 2 + self.params[key] = 2 + + def run(self, **job_kwargs): + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates', **self.templates_params) + templates = ext.get_data(outputs='Templates') + ext = sorting_analyzer.compute("spike_locations", **self.params) + spikes_locations = ext.get_data(outputs="by_unit") + self.result = {'spikes_locations' : spikes_locations} + self.result['templates'] = templates + + def compute_result(self, **result_params): + + + def save_run(self, folder): + + + def save_result(self, folder): + + + @classmethod + def load_folder(cls, folder): + + return result + + +class ClusteringStudy(BenchmarkStudy): + + benchmark_class = ClusteringBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + gt_positions = self.cases[key]["gt_positions"] + params = self.cases[key]["params"] + benchmark = ClusteringBenchmark(recording, gt_sorting, gt_positions, params) + return benchmark + + class BenchmarkClustering: def __init__(self, recording, gt_sorting, method, exhaustive_gt=True, tmp_folder=None, job_kwargs={}, verbose=True): self.method = method diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 962730f4fc..296faeaa51 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -177,12 +177,14 @@ def plot_comparison_matching(self, case_keys=None, ax.spines[["right", "top"]].set_visible(False) ax.set_aspect("equal") + label1 = self.cases[key1]['label'] + label2 = self.cases[key2]['label'] if j == i: - ax.set_ylabel(f"{key1}") + ax.set_ylabel(f"{label1}") else: ax.set_yticks([]) if i == j: - ax.set_xlabel(f"{key2}") + ax.set_xlabel(f"{label2}") else: ax.set_xticks([]) if i == num_methods - 1 and j == num_methods - 1: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index b58b831b17..d8f108211e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -124,7 +124,7 @@ def key_to_str(self, key): if isinstance(key, str): return key elif isinstance(key, tuple): - return _key_separator.join(key) + return _key_separator.join([str(k) for k in key]) else: raise ValueError("Keys for cases must str or tuple") From 2449e3c4a745044d8d6ebe8216878fad0e7bba07 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Feb 2024 14:40:12 +0100 Subject: [PATCH 117/192] benchmark : save/load more generic --- .../benchmark/benchmark_motion_estimation.py | 99 +++++-------------- .../benchmark_motion_interpolation.py | 42 +++++++- .../benchmark/benchmark_tools.py | 41 +++++++- 3 files changed, 102 insertions(+), 80 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 6c715c19d9..ee9fa885c4 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -1,31 +1,33 @@ from __future__ import annotations import json -import numpy as np import time from pathlib import Path - import pickle +import numpy as np +import scipy.interpolate + from spikeinterface.core import get_noise_levels -from spikeinterface.extractors import read_mearec from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.sortingcomponents.motion_estimation import estimate_motion -from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks -from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference - from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +import matplotlib.pyplot as plt from spikeinterface.widgets import plot_probe_map -import scipy.interpolate +# import MEArec as mr + +# TODO : plot_peaks +# TODO : plot_motion_corrected_peaks +# TODO : plot_error_map_several_benchmarks +# TODO : plot_speed_several_benchmarks +# TODO : read from mearec -import matplotlib.pyplot as plt -import MEArec as mr class MotionEstimationBenchmark(Benchmark): @@ -79,41 +81,11 @@ def run(self, **job_kwargs): self.result["temporal_bins"] = temporal_bins self.result["spatial_bins"] = spatial_bins - - # self.compute_gt_motion() - - # align globally gt_motion and motion to avoid offsets - # self.motion += np.median(self.gt_motion - self.motion) - - - - - # self.result = {'sorting' : sorting} - # self.result['templates'] = self.templates - def compute_result(self, **result_params): raw_motion = self.result["raw_motion"] temporal_bins = self.result["temporal_bins"] spatial_bins = self.result["spatial_bins"] - # interpolation units to gt_motion - num_units = self.unit_locations.shape[0] - - print(self.unit_locations.shape) - # self.unit_displacements = unit_displacements - # self.displacement_sampling_frequency = displacement_sampling_frequency - - - # self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins) - - # template_locations = np.array(mr.load_recordings(self.mearec_filename).template_locations) - # assert len(template_locations.shape) == 3 - # mid = template_locations.shape[1] // 2 - # unit_mid_positions = template_locations[:, mid, 2] - - # unit_motions = self.gt_unit_positions - unit_mid_positions - # # unit_positions = np.mean(self.gt_unit_positions, axis=0) - # time interpolatation of unit displacements times = np.arange(self.unit_displacements.shape[0]) / self.displacement_sampling_frequency f = scipy.interpolate.interp1d(times, self.unit_displacements, axis=0) @@ -124,11 +96,11 @@ def compute_result(self, **result_params): # rigid gt_motion = np.mean(unit_displacements, axis=1)[:, None] else: + # non rigid gt_motion = np.zeros_like(raw_motion) for t in range(temporal_bins.shape[0]): f = scipy.interpolate.interp1d(self.unit_locations[:, self.direction_dim], unit_displacements[t, :], fill_value="extrapolate") gt_motion[t, :] = f(spatial_bins) - # print("gt_motion", gt_motion.shape, raw_motion.shape) # align globally gt_motion and motion to avoid offsets motion = raw_motion.copy() @@ -136,36 +108,19 @@ def compute_result(self, **result_params): self.result["gt_motion"] = gt_motion self.result["motion"] = motion - def save_run(self, folder): - for k in ("raw_motion", "temporal_bins", "spatial_bins"): - np.save(folder / f"{k}.npy", self.result[k]) - - for k in ('step_run_times', ): - with open(folder / f"{k}.pickle", mode="wb") as f: - pickle.dump(self.result[k], f) - - def save_result(self, folder): - for k in ("gt_motion", "motion"): - np.save(folder / f"{k}.npy", self.result[k]) - - @classmethod - def load_folder(cls, folder): - result = {} - # run - for k in ("raw_motion", "temporal_bins", "spatial_bins"): - result[k] = np.load(folder / f"{k}.npy") - - for k in ('step_run_times', ): - with open(folder / f"{k}.pickle", "rb") as f: - result[k] = pickle.load(f) - - # result - for k in ("gt_motion", "motion"): - file = folder / f"{k}.npy" - if file.exists(): - result[k] = np.load(file) - return result + _run_key_saved = [ + ("raw_motion", "npy"), + ("temporal_bins", "npy"), + ("spatial_bins", "npy"), + ("step_run_times", "pickle"), + ] + _result_key_saved = [ + ("gt_motion", "npy",), + ("motion", "npy",) + ] + + @@ -182,6 +137,8 @@ def create_benchmark(self, key): return benchmark def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): + + if case_keys is None: case_keys = list(self.cases.keys()) @@ -367,9 +324,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fi # ax2.sharey(ax0) -# TODO : plot_peaks -# TODO : plot_motion_corrected_peaks -# TODO : plot_error_map_several_benchmarks + # class BenchmarkMotionEstimationMearec(BenchmarkBase): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 61ad457217..ad19ff08aa 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -6,10 +6,7 @@ from pathlib import Path import shutil -from spikeinterface.core import extract_waveforms, precompute_sparsity, WaveformExtractor - -from spikeinterface.extractors import read_mearec from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten from spikeinterface.sorters import run_sorter, read_sorter_folder @@ -24,11 +21,50 @@ import sklearn +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis + + import matplotlib.pyplot as plt import MEArec as mr + + +class MotionInterpolationBenchmark(Benchmark): + def __init__(self, recording, gt_sorting, params, + unit_locations, unit_displacements, displacement_sampling_frequency, + direction="y"): + Benchmark.__init__(self) + self.recording = recording + self.gt_sorting = gt_sorting + self.params = params + self.unit_locations = unit_locations + self.unit_displacements = unit_displacements + self.displacement_sampling_frequency = displacement_sampling_frequency + self.direction = direction + self.direction_dim = ["x", "y"].index(direction) + + def run(self, **job_kwargs): + p = self.params + + +class MotionInterpolationStudy(BenchmarkStudy): + + benchmark_class = MotionInterpolationBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + init_kwargs = self.cases[key]["init_kwargs"] + benchmark = MotionInterpolationBenchmark(recording, gt_sorting, params, **init_kwargs) + return benchmark + + + + + class BenchmarkMotionInterpolationMearec(BenchmarkBase): _array_names = ("gt_motion", "estimated_motion", "temporal_bins", "spatial_bins") _waveform_names = ("static", "drifting", "corrected_gt", "corrected_estimated") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 43a4bf4586..6a95bb4b0b 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -258,17 +258,48 @@ class Benchmark: def __init__(self): self.result = {} + _run_key_saved = [] + _result_key_saved = [] + + def _save_keys(self, saved_keys, folder): + for k, format in saved_keys: + if format == "npy": + np.save(folder / f"{k}.npy", self.result[k]) + elif format =="pickle": + with open(folder / f"{k}.pickle", mode="wb") as f: + pickle.dump(self.result[k], f) + else: + raise ValueError(f"Save error {k} {format}") + + def save_run(self, folder): + self._save_keys(self._run_key_saved) + + def save_result(self, folder): + self._save_keys(self._result_key_saved) + @classmethod def load_folder(cls, folder): - raise NotImplementedError - - def save_to_folder(self, folder, result): - raise NotImplementedError + result = {} + for k, format in cls._run_key_saved + cls._result_key_saved: + if format == "npy": + file = folder / f"{k}.npy" + if file.exists(): + result[k] = np.load(file) + elif format =="pickle": + file = folder / f"{k}.pickle" + if file.exists(): + with open(file, mode="rb") as f: + result[k] = pickle.load(f) + + return result def run(self): - # run method and metrics!! + # run method raise NotImplementedError + def compute_result(self): + # run becnhmark result + raise NotImplementedError From 28b4cc4e5b16d2da15e8c62ccd401cf10091432a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 21 Feb 2024 15:16:04 +0100 Subject: [PATCH 118/192] WIP --- .../benchmark/benchmark_clustering.py | 1134 ++++++++--------- .../benchmark/benchmark_matching.py | 33 +- .../benchmark/benchmark_peak_localization.py | 68 +- .../benchmark/benchmark_tools.py | 16 +- .../clustering/clustering_tools.py | 5 +- 5 files changed, 597 insertions(+), 659 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 220dd3258f..8b1347f332 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -17,7 +17,7 @@ ) from spikeinterface.postprocessing import compute_principal_components from spikeinterface.comparison.comparisontools import make_matching_events -from spikeinterface.postprocessing import get_template_extremum_channel +#from spikeinterface.postprocessing import get_template_extremum_channel from spikeinterface.core import get_noise_levels import time @@ -35,42 +35,35 @@ class ClusteringBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, gt_positions, params): + def __init__(self, recording, gt_sorting, params, peaks): self.recording = recording self.gt_sorting = gt_sorting - self.gt_positions = gt_positions + self.peaks = peaks self.params = params + self.method = params['method'] + self.method_kwargs = params['method_kwargs'] self.result = {} - self.templates_params = {} - for key in ["ms_before", "ms_after"]: - if key in self.params: - self.templates_params[key] = self.params[key] - else: - self.templates_params[key] = 2 - self.params[key] = 2 - - def run(self, **job_kwargs): - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() - ext = sorting_analyzer.compute('fast_templates', **self.templates_params) - templates = ext.get_data(outputs='Templates') - ext = sorting_analyzer.compute("spike_locations", **self.params) - spikes_locations = ext.get_data(outputs="by_unit") - self.result = {'spikes_locations' : spikes_locations} - self.result['templates'] = templates + + def run(self, **job_kwargs): + labels, peak_labels = find_cluster_from_peaks( + self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs + ) + self.result['peak_labels'] = peak_labels def compute_result(self, **result_params): - + pass def save_run(self, folder): - + np.save(folder / "peak_labels", self.result['peak_labels']) def save_result(self, folder): - + pass @classmethod def load_folder(cls, folder): - + result = {} + if (folder / "peak_labels.npy").exists(): + result["peak_labels"] = np.load(folder / "peak_labels.npy") return result @@ -81,564 +74,543 @@ class ClusteringStudy(BenchmarkStudy): def create_benchmark(self, key): dataset_key = self.cases[key]["dataset"] recording, gt_sorting = self.datasets[dataset_key] - gt_positions = self.cases[key]["gt_positions"] params = self.cases[key]["params"] - benchmark = ClusteringBenchmark(recording, gt_sorting, gt_positions, params) + init_kwargs = self.cases[key]["init_kwargs"] + benchmark = ClusteringBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark -class BenchmarkClustering: - def __init__(self, recording, gt_sorting, method, exhaustive_gt=True, tmp_folder=None, job_kwargs={}, verbose=True): - self.method = method - - assert method in clustering_methods, f"Clustering method should be in {clustering_methods.keys()}" - - self.verbose = verbose - self.recording = recording - self.gt_sorting = gt_sorting - self.job_kwargs = job_kwargs - self.exhaustive_gt = exhaustive_gt - self.recording_f = recording - self.sampling_rate = self.recording_f.get_sampling_frequency() - self.job_kwargs = job_kwargs - - self.tmp_folder = tmp_folder - if self.tmp_folder is None: - self.tmp_folder = os.path.join(".", "".join(random.choices(string.ascii_uppercase + string.digits, k=8))) - - self._peaks = None - self._selected_peaks = None - self._positions = None - self._gt_positions = None - self.gt_peaks = None - - self.waveforms = {} - self.pcas = {} - self.templates = {} - - def __del__(self): - import shutil - - shutil.rmtree(self.tmp_folder) - - def set_peaks(self, peaks): - self._peaks = peaks - - def set_positions(self, positions): - self._positions = positions - - def set_gt_positions(self, gt_positions): - self._gt_positions = gt_positions - - @property - def peaks(self): - if self._peaks is None: - self.detect_peaks() - return self._peaks - - @property - def selected_peaks(self): - if self._selected_peaks is None: - self.select_peaks() - return self._selected_peaks - - @property - def positions(self): - if self._positions is None: - self.localize_peaks() - return self._positions - - @property - def gt_positions(self): - if self._gt_positions is None: - self.localize_gt_peaks() - return self._gt_positions - - def detect_peaks(self, method_kwargs={"method": "locally_exclusive"}): - from spikeinterface.sortingcomponents.peak_detection import detect_peaks - - if self.verbose: - method = method_kwargs["method"] - print(f"Detecting peaks with method {method}") - self._peaks = detect_peaks(self.recording_f, **method_kwargs, **self.job_kwargs) - - def select_peaks(self, method_kwargs={"method": "uniform", "n_peaks": 100}): - from spikeinterface.sortingcomponents.peak_selection import select_peaks - - if self.verbose: - method = method_kwargs["method"] - print(f"Selecting peaks with method {method}") - self._selected_peaks = select_peaks(self.peaks, **method_kwargs, **self.job_kwargs) - if self.verbose: - ratio = len(self._selected_peaks) / len(self.peaks) - print(f"The ratio of peaks kept for clustering is {ratio}%") - - def localize_peaks(self, method_kwargs={"method": "center_of_mass"}): - from spikeinterface.sortingcomponents.peak_localization import localize_peaks - - if self.verbose: - method = method_kwargs["method"] - print(f"Localizing peaks with method {method}") - self._positions = localize_peaks(self.recording_f, self.selected_peaks, **method_kwargs, **self.job_kwargs) - - def localize_gt_peaks(self, method_kwargs={"method": "center_of_mass"}): - from spikeinterface.sortingcomponents.peak_localization import localize_peaks - - if self.verbose: - method = method_kwargs["method"] - print(f"Localizing gt peaks with method {method}") - self._gt_positions = localize_peaks(self.recording_f, self.gt_peaks, **method_kwargs, **self.job_kwargs) - - def run(self, peaks=None, positions=None, method=None, method_kwargs={}, delta=0.2): - t_start = time.time() - if method is not None: - self.method = method - if peaks is not None: - self._peaks = peaks - self._selected_peaks = peaks - - nb_peaks = len(self.selected_peaks) - if self.verbose: - print(f"Launching the {self.method} clustering algorithm with {nb_peaks} peaks") - - if positions is not None: - self._positions = positions - - labels, peak_labels = find_cluster_from_peaks( - self.recording_f, self.selected_peaks, method=self.method, method_kwargs=method_kwargs, **self.job_kwargs - ) - nb_clusters = len(labels) - if self.verbose: - print(f"{nb_clusters} clusters have been found") - self.noise = peak_labels == -1 - self.run_time = time.time() - t_start - self.selected_peaks_labels = peak_labels - self.labels = labels - - self.clustering = NumpySorting.from_times_labels( - self.selected_peaks["sample_index"][~self.noise], - self.selected_peaks_labels[~self.noise], - self.sampling_rate, - ) - if self.verbose: - print("Performing the comparison with (sliced) ground truth") - - spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] - spikes2 = self.clustering.to_spike_vector(concatenated=False)[0] - - matches = make_matching_events( - spikes1["sample_index"], spikes2["sample_index"], int(delta * self.sampling_rate / 1000) - ) - - self.matches = matches - idx = matches["index1"] - self.sliced_gt_sorting = NumpySorting(spikes1[idx], self.sampling_rate, self.gt_sorting.unit_ids) - - self.comp = GroundTruthComparison(self.sliced_gt_sorting, self.clustering, exhaustive_gt=self.exhaustive_gt) - - for label, sorting in zip( - ["gt", "clustering", "full_gt"], [self.sliced_gt_sorting, self.clustering, self.gt_sorting] - ): - tmp_folder = os.path.join(self.tmp_folder, label) - if os.path.exists(tmp_folder): - import shutil - - shutil.rmtree(tmp_folder) - - if not (label == "full_gt" and label in self.waveforms): - if self.verbose: - print(f"Extracting waveforms for {label}") - - self.waveforms[label] = extract_waveforms( - self.recording_f, - sorting, - tmp_folder, - load_if_exists=True, - ms_before=2.5, - ms_after=3.5, - max_spikes_per_unit=500, - return_scaled=False, - **self.job_kwargs, - ) - - # self.pcas[label] = compute_principal_components(self.waveforms[label], load_if_exists=True, - # n_components=5, mode='by_channel_local', - # whiten=True, dtype='float32') - - self.templates[label] = self.waveforms[label].get_all_templates(mode="median") - - if self.gt_peaks is None: - if self.verbose: - print("Computing gt peaks") - gt_peaks_ = self.gt_sorting.to_spike_vector() - self.gt_peaks = np.zeros( - gt_peaks_.size, dtype=[("sample_index", " 0.1) - - ax.plot( - metrics["snr"][unit_ids1][inds_1[:nb_potentials]], - nb_peaks[inds_1[:nb_potentials]], - markersize=10, - marker=".", - ls="", - c="k", - label="Cluster potentially found", - ) - ax.plot( - metrics["snr"][unit_ids1][inds_1[nb_potentials:]], - nb_peaks[inds_1[nb_potentials:]], - markersize=10, - marker=".", - ls="", - c="r", - label="Cluster clearly missed", - ) - - if annotations: - for l, x, y in zip( - unit_ids1[: len(inds_2)], - metrics["snr"][unit_ids1][inds_1[: len(inds_2)]], - nb_peaks[inds_1[: len(inds_2)]], - ): - ax.annotate(l, (x, y)) - - for l, x, y in zip( - unit_ids1[len(inds_2) :], - metrics["snr"][unit_ids1][inds_1[len(inds_2) :]], - nb_peaks[inds_1[len(inds_2) :]], - ): - ax.annotate(l, (x, y), c="r") - - if detect_threshold is not None: - ymin, ymax = ax.get_ylim() - ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") - - ax.legend() - ax.set_xlabel("template snr") - ax.set_ylabel("nb spikes") - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - ax = axs[0, 2] - im = ax.imshow(distances, aspect="auto") - ax.set_title(metric) - fig.colorbar(im, ax=ax) - - if detect_threshold is not None: - for count, snr in enumerate(snrs): - if snr < detect_threshold: - ax.plot([xmin, xmax], [count, count], "w") - - ymin, ymax = ax.get_ylim() - ax.plot([nb_detectable + 0.5, nb_detectable + 0.5], [ymin, ymax], "r") - - ax.set_yticks(np.arange(0, len(scores.index))) - ax.set_yticklabels(scores.index, fontsize=8) - - res = [] - nb_spikes = [] - energy = [] - nb_channels = [] - - noise_levels = get_noise_levels(self.recording_f, return_scaled=False) - - for found, real in zip(unit_ids2, unit_ids1): - wfs = self.waveforms["clustering"].get_waveforms(found) - wfs_real = self.waveforms["gt"].get_waveforms(real) - template = self.waveforms["clustering"].get_template(found) - template_real = self.waveforms["gt"].get_template(real) - nb_channels += [np.sum(np.std(template_real, 0) < noise_levels)] - - wfs = wfs.reshape(len(wfs), -1) - template = template.reshape(template.size, 1).T - template_real = template_real.reshape(template_real.size, 1).T - - if metric == "cosine": - dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real).flatten().tolist() - else: - dist = sklearn.metrics.pairwise_distances(template, template_real, metric).flatten().tolist() - res += dist - nb_spikes += [self.sliced_gt_sorting.get_unit_spike_train(real).size] - energy += [np.linalg.norm(template_real)] - - ax = axs[1, 0] - res = np.array(res) - nb_spikes = np.array(nb_spikes) - nb_channels = np.array(nb_channels) - energy = np.array(energy) - - snrs = metrics["snr"][unit_ids1][inds_1[: len(inds_2)]] - cm = ax.scatter(snrs, nb_spikes, c=res) - ax.set_xlabel("template snr") - ax.set_ylabel("nb spikes") - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - cb = fig.colorbar(cm, ax=ax) - cb.set_label(metric) - if detect_threshold is not None: - ymin, ymax = ax.get_ylim() - ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") - - if annotations: - for l, x, y in zip(unit_ids1[: len(inds_2)], snrs, nb_spikes): - ax.annotate(l, (x, y)) - - ax = axs[1, 1] - cm = ax.scatter(energy, nb_channels, c=res) - ax.set_xlabel("template energy") - ax.set_ylabel("nb channels") - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - cb = fig.colorbar(cm, ax=ax) - cb.set_label(metric) - - if annotations: - for l, x, y in zip(unit_ids1[: len(inds_2)], energy, nb_channels): - ax.annotate(l, (x, y)) - - ax = axs[1, 2] - for performance_name in ["accuracy", "recall", "precision"]: - perf = self.comp.get_performance()[performance_name] - ax.plot(metrics["snr"], perf, markersize=10, marker=".", ls="", label=performance_name) - ax.set_xlabel("template snr") - ax.set_ylabel("performance") - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.legend() - if detect_threshold is not None: - ymin, ymax = ax.get_ylim() - ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") - - plt.tight_layout() +# class BenchmarkClustering: +# def __init__(self, recording, gt_sorting, method, exhaustive_gt=True, tmp_folder=None, job_kwargs={}, verbose=True): +# self.method = method + +# self.waveforms = {} +# self.pcas = {} +# self.templates = {} + +# def __del__(self): +# import shutil + +# shutil.rmtree(self.tmp_folder) + +# def set_peaks(self, peaks): +# self._peaks = peaks + +# def set_positions(self, positions): +# self._positions = positions + +# def set_gt_positions(self, gt_positions): +# self._gt_positions = gt_positions + +# @property +# def peaks(self): +# if self._peaks is None: +# self.detect_peaks() +# return self._peaks + +# @property +# def selected_peaks(self): +# if self._selected_peaks is None: +# self.select_peaks() +# return self._selected_peaks + +# @property +# def positions(self): +# if self._positions is None: +# self.localize_peaks() +# return self._positions + +# @property +# def gt_positions(self): +# if self._gt_positions is None: +# self.localize_gt_peaks() +# return self._gt_positions + +# def detect_peaks(self, method_kwargs={"method": "locally_exclusive"}): +# from spikeinterface.sortingcomponents.peak_detection import detect_peaks + +# if self.verbose: +# method = method_kwargs["method"] +# print(f"Detecting peaks with method {method}") +# self._peaks = detect_peaks(self.recording_f, **method_kwargs, **self.job_kwargs) + +# def select_peaks(self, method_kwargs={"method": "uniform", "n_peaks": 100}): +# from spikeinterface.sortingcomponents.peak_selection import select_peaks + +# if self.verbose: +# method = method_kwargs["method"] +# print(f"Selecting peaks with method {method}") +# self._selected_peaks = select_peaks(self.peaks, **method_kwargs, **self.job_kwargs) +# if self.verbose: +# ratio = len(self._selected_peaks) / len(self.peaks) +# print(f"The ratio of peaks kept for clustering is {ratio}%") + +# def localize_peaks(self, method_kwargs={"method": "center_of_mass"}): +# from spikeinterface.sortingcomponents.peak_localization import localize_peaks + +# if self.verbose: +# method = method_kwargs["method"] +# print(f"Localizing peaks with method {method}") +# self._positions = localize_peaks(self.recording_f, self.selected_peaks, **method_kwargs, **self.job_kwargs) + +# def localize_gt_peaks(self, method_kwargs={"method": "center_of_mass"}): +# from spikeinterface.sortingcomponents.peak_localization import localize_peaks + +# if self.verbose: +# method = method_kwargs["method"] +# print(f"Localizing gt peaks with method {method}") +# self._gt_positions = localize_peaks(self.recording_f, self.gt_peaks, **method_kwargs, **self.job_kwargs) + +# def run(self, peaks=None, positions=None, method=None, method_kwargs={}, delta=0.2): +# t_start = time.time() +# if method is not None: +# self.method = method +# if peaks is not None: +# self._peaks = peaks +# self._selected_peaks = peaks + +# nb_peaks = len(self.selected_peaks) +# if self.verbose: +# print(f"Launching the {self.method} clustering algorithm with {nb_peaks} peaks") + +# if positions is not None: +# self._positions = positions + +# labels, peak_labels = find_cluster_from_peaks( +# self.recording_f, self.selected_peaks, method=self.method, method_kwargs=method_kwargs, **self.job_kwargs +# ) +# nb_clusters = len(labels) +# if self.verbose: +# print(f"{nb_clusters} clusters have been found") +# self.noise = peak_labels == -1 +# self.run_time = time.time() - t_start +# self.selected_peaks_labels = peak_labels +# self.labels = labels + +# self.clustering = NumpySorting.from_times_labels( +# self.selected_peaks["sample_index"][~self.noise], +# self.selected_peaks_labels[~self.noise], +# self.sampling_rate, +# ) +# if self.verbose: +# print("Performing the comparison with (sliced) ground truth") + +# spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] +# spikes2 = self.clustering.to_spike_vector(concatenated=False)[0] + +# matches = make_matching_events( +# spikes1["sample_index"], spikes2["sample_index"], int(delta * self.sampling_rate / 1000) +# ) + +# self.matches = matches +# idx = matches["index1"] +# self.sliced_gt_sorting = NumpySorting(spikes1[idx], self.sampling_rate, self.gt_sorting.unit_ids) + +# self.comp = GroundTruthComparison(self.sliced_gt_sorting, self.clustering, exhaustive_gt=self.exhaustive_gt) + +# for label, sorting in zip( +# ["gt", "clustering", "full_gt"], [self.sliced_gt_sorting, self.clustering, self.gt_sorting] +# ): +# tmp_folder = os.path.join(self.tmp_folder, label) +# if os.path.exists(tmp_folder): +# import shutil + +# shutil.rmtree(tmp_folder) + +# if not (label == "full_gt" and label in self.waveforms): +# if self.verbose: +# print(f"Extracting waveforms for {label}") + +# self.waveforms[label] = extract_waveforms( +# self.recording_f, +# sorting, +# tmp_folder, +# load_if_exists=True, +# ms_before=2.5, +# ms_after=3.5, +# max_spikes_per_unit=500, +# return_scaled=False, +# **self.job_kwargs, +# ) + +# # self.pcas[label] = compute_principal_components(self.waveforms[label], load_if_exists=True, +# # n_components=5, mode='by_channel_local', +# # whiten=True, dtype='float32') + +# self.templates[label] = self.waveforms[label].get_all_templates(mode="median") + +# if self.gt_peaks is None: +# if self.verbose: +# print("Computing gt peaks") +# gt_peaks_ = self.gt_sorting.to_spike_vector() +# self.gt_peaks = np.zeros( +# gt_peaks_.size, dtype=[("sample_index", " 0.1) + +# ax.plot( +# metrics["snr"][unit_ids1][inds_1[:nb_potentials]], +# nb_peaks[inds_1[:nb_potentials]], +# markersize=10, +# marker=".", +# ls="", +# c="k", +# label="Cluster potentially found", +# ) +# ax.plot( +# metrics["snr"][unit_ids1][inds_1[nb_potentials:]], +# nb_peaks[inds_1[nb_potentials:]], +# markersize=10, +# marker=".", +# ls="", +# c="r", +# label="Cluster clearly missed", +# ) + +# if annotations: +# for l, x, y in zip( +# unit_ids1[: len(inds_2)], +# metrics["snr"][unit_ids1][inds_1[: len(inds_2)]], +# nb_peaks[inds_1[: len(inds_2)]], +# ): +# ax.annotate(l, (x, y)) + +# for l, x, y in zip( +# unit_ids1[len(inds_2) :], +# metrics["snr"][unit_ids1][inds_1[len(inds_2) :]], +# nb_peaks[inds_1[len(inds_2) :]], +# ): +# ax.annotate(l, (x, y), c="r") + +# if detect_threshold is not None: +# ymin, ymax = ax.get_ylim() +# ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") + +# ax.legend() +# ax.set_xlabel("template snr") +# ax.set_ylabel("nb spikes") +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) + +# ax = axs[0, 2] +# im = ax.imshow(distances, aspect="auto") +# ax.set_title(metric) +# fig.colorbar(im, ax=ax) + +# if detect_threshold is not None: +# for count, snr in enumerate(snrs): +# if snr < detect_threshold: +# ax.plot([xmin, xmax], [count, count], "w") + +# ymin, ymax = ax.get_ylim() +# ax.plot([nb_detectable + 0.5, nb_detectable + 0.5], [ymin, ymax], "r") + +# ax.set_yticks(np.arange(0, len(scores.index))) +# ax.set_yticklabels(scores.index, fontsize=8) + +# res = [] +# nb_spikes = [] +# energy = [] +# nb_channels = [] + +# noise_levels = get_noise_levels(self.recording_f, return_scaled=False) + +# for found, real in zip(unit_ids2, unit_ids1): +# wfs = self.waveforms["clustering"].get_waveforms(found) +# wfs_real = self.waveforms["gt"].get_waveforms(real) +# template = self.waveforms["clustering"].get_template(found) +# template_real = self.waveforms["gt"].get_template(real) +# nb_channels += [np.sum(np.std(template_real, 0) < noise_levels)] + +# wfs = wfs.reshape(len(wfs), -1) +# template = template.reshape(template.size, 1).T +# template_real = template_real.reshape(template_real.size, 1).T + +# if metric == "cosine": +# dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real).flatten().tolist() +# else: +# dist = sklearn.metrics.pairwise_distances(template, template_real, metric).flatten().tolist() +# res += dist +# nb_spikes += [self.sliced_gt_sorting.get_unit_spike_train(real).size] +# energy += [np.linalg.norm(template_real)] + +# ax = axs[1, 0] +# res = np.array(res) +# nb_spikes = np.array(nb_spikes) +# nb_channels = np.array(nb_channels) +# energy = np.array(energy) + +# snrs = metrics["snr"][unit_ids1][inds_1[: len(inds_2)]] +# cm = ax.scatter(snrs, nb_spikes, c=res) +# ax.set_xlabel("template snr") +# ax.set_ylabel("nb spikes") +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) +# cb = fig.colorbar(cm, ax=ax) +# cb.set_label(metric) +# if detect_threshold is not None: +# ymin, ymax = ax.get_ylim() +# ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") + +# if annotations: +# for l, x, y in zip(unit_ids1[: len(inds_2)], snrs, nb_spikes): +# ax.annotate(l, (x, y)) + +# ax = axs[1, 1] +# cm = ax.scatter(energy, nb_channels, c=res) +# ax.set_xlabel("template energy") +# ax.set_ylabel("nb channels") +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) +# cb = fig.colorbar(cm, ax=ax) +# cb.set_label(metric) + +# if annotations: +# for l, x, y in zip(unit_ids1[: len(inds_2)], energy, nb_channels): +# ax.annotate(l, (x, y)) + +# ax = axs[1, 2] +# for performance_name in ["accuracy", "recall", "precision"]: +# perf = self.comp.get_performance()[performance_name] +# ax.plot(metrics["snr"], perf, markersize=10, marker=".", ls="", label=performance_name) +# ax.set_xlabel("template snr") +# ax.set_ylabel("performance") +# ax.spines["top"].set_visible(False) +# ax.spines["right"].set_visible(False) +# ax.legend() +# if detect_threshold is not None: +# ymin, ymax = ax.get_ylim() +# ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") + +# plt.tight_layout() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 296faeaa51..60922d3be5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -61,32 +61,15 @@ def compute_result(self, **result_params): comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) self.result['gt_comparison'] = comp self.result['gt_collision'] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) - - def save_run(self, folder): - self.result['sorting'].save(folder = folder / "sorting", format="numpy_folder") - self.result['templates'].to_zarr(folder / "templates") - def save_result(self, folder): - comparison_file = folder / "gt_comparison.pickle" - with open(comparison_file, mode="wb") as f: - pickle.dump(self.result['gt_comparison'], f) - - collision_file = folder / "gt_collision.pickle" - with open(collision_file, mode="wb") as f: - pickle.dump(self.result['gt_collision'], f) - - @classmethod - def load_folder(cls, folder): - result = {} - result['sorting'] = load_extractor(folder / "sorting") - result['templates'] = Templates.from_zarr(folder / "templates") - if (folder / "gt_comparison.pickle").exists(): - with open(folder / "gt_comparison.pickle", "rb") as f: - result['gt_comparison'] = pickle.load(f) - if (folder / "gt_collision.pickle").exists(): - with open(folder / "gt_collision.pickle", "rb") as f: - result['gt_collision'] = pickle.load(f) - return result + _run_key_saved = [ + ("sorting", "sorting"), + ("templates", "zarr_templates"), + ] + _result_key_saved = [ + ("gt_collision", "pickle"), + ("gt_comparison", "pickle") + ] class MatchingStudy(BenchmarkStudy): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index cb2af1b369..e818d418fa 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -31,7 +31,7 @@ class PeakLocalizationBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, gt_positions, params): + def __init__(self, recording, gt_sorting, params, gt_positions): self.recording = recording self.gt_sorting = gt_sorting self.gt_positions = gt_positions @@ -75,33 +75,15 @@ def compute_result(self, **result_params): ) self.result['errors'] = errors - def save_run(self, folder): - self.result['templates'].to_zarr(folder / "templates") - locations_file = folder / "spikes_locations.pickle" - with open(locations_file, mode="wb") as f: - pickle.dump(self.result['spikes_locations'], f) - - def save_result(self, folder): - errors_file = folder / "errors.pickle" - with open(errors_file, mode="wb") as f: - pickle.dump(self.result['errors'], f) - np.save(folder / "medians_over_templates", self.result['medians_over_templates']) - np.save(folder / "mads_over_templates", self.result['mads_over_templates']) - - @classmethod - def load_folder(cls, folder): - result = {} - result['templates'] = Templates.from_zarr(folder / "templates") - if (folder / "errors.pickle").exists(): - with open(folder / "errors.pickle", "rb") as f: - result['errors'] = pickle.load(f) - if (folder / "spikes_locations.pickle").exists(): - with open(folder / "spikes_locations.pickle", "rb") as f: - result['spikes_locations'] = pickle.load(f) - if (folder / "medians_over_templates.npy").exists(): - result["medians_over_templates"] = np.load(folder / "medians_over_templates.npy") - result["mads_over_templates"] = np.load(folder / "mads_over_templates.npy") - return result + _run_key_saved = [ + ("spikes_locations", "pickle"), + ("templates", "zarr_templates"), + ] + _result_key_saved = [ + ("errors", "pickle"), + ("medians_over_templates", "npy"), + ("mads_over_templates", "npy"), + ] class PeakLocalizationStudy(BenchmarkStudy): @@ -111,9 +93,9 @@ class PeakLocalizationStudy(BenchmarkStudy): def create_benchmark(self, key): dataset_key = self.cases[key]["dataset"] recording, gt_sorting = self.datasets[dataset_key] - gt_positions = self.cases[key]["gt_positions"] params = self.cases[key]["params"] - benchmark = PeakLocalizationBenchmark(recording, gt_sorting, gt_positions, params) + init_kwargs = self.cases[key]["init_kwargs"] + benchmark = PeakLocalizationBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): @@ -230,23 +212,15 @@ def run(self, **job_kwargs): def compute_result(self, **result_params): errors = np.linalg.norm(self.gt_positions[:, :2] - self.result['unit_locations'][:, :2], axis=1) - self.result['errors'] = errors - - def save_run(self, folder): - self.result['templates'].to_zarr(folder / "templates") - np.save(folder / "unit_locations", self.result['unit_locations']) - - def save_result(self, folder): - np.save(folder / "errors", self.result['errors']) - - @classmethod - def load_folder(cls, folder): - result = {} - result['templates'] = Templates.from_zarr(folder / "templates") - result["unit_locations"] = np.load(folder / "unit_locations.npy") - if (folder / "errors.npy").exists(): - result["errors"] = np.load(folder / "errors.npy") - return result + self.result['errors'] = errors + + _run_key_saved = [ + ("unit_locations", "npy"), + ("templates", "zarr_templates"), + ] + _result_key_saved = [ + ("errors", "npy") + ] class UnitLocalizationStudy(BenchmarkStudy): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index b047ee45d5..868ecd58db 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -268,14 +268,20 @@ def _save_keys(self, saved_keys, folder): elif format =="pickle": with open(folder / f"{k}.pickle", mode="wb") as f: pickle.dump(self.result[k], f) + elif format == 'sorting': + self.result[k].save(folder = folder / k, format="numpy_folder") + elif format == 'zarr_templates': + self.result[k].to_zarr(folder / k) + elif format == 'sorting_analyzer': + pass else: raise ValueError(f"Save error {k} {format}") def save_run(self, folder): - self._save_keys(self._run_key_saved) + self._save_keys(self._run_key_saved, folder) def save_result(self, folder): - self._save_keys(self._result_key_saved) + self._save_keys(self._result_key_saved, folder) @classmethod def load_folder(cls, folder): @@ -290,6 +296,12 @@ def load_folder(cls, folder): if file.exists(): with open(file, mode="rb") as f: result[k] = pickle.load(f) + elif format =="sorting": + from spikeinterface.core import load_extractor + result[k] = load_extractor(folder / k) + elif format =="zarr_templates": + from spikeinterface.core.template import Templates + result[k] = Templates.from_zarr(folder / k) return result diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 17b8dea89a..0441943d5f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -539,9 +539,8 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core import NumpySorting - from spikeinterface.core import extract_waveforms from spikeinterface.core import get_global_tmp_folder - import string, random, shutil, os + import os from pathlib import Path job_kwargs = fix_job_kwargs(job_kwargs) @@ -579,8 +578,6 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job margin = 2 * max(templates.nbefore, templates.nafter) half_marging = margin // 2 - chunk_size = duration + 3 * margin - local_params = method_kwargs.copy() local_params.update({"templates": templates, "amplitudes": [0.975, 1.025]}) From bcea8b8d365b37736104b948a2bfefc4def2ea72 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 21 Feb 2024 17:31:01 +0100 Subject: [PATCH 119/192] WIP --- .../benchmark/benchmark_clustering.py | 366 +++++++++--------- .../benchmark/benchmark_matching.py | 1 + 2 files changed, 176 insertions(+), 191 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 8b1347f332..f1c04c23d9 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -31,15 +31,27 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from spikeinterface.core.template_tools import get_template_extremum_channel class ClusteringBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, params, peaks): + def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.recording = recording self.gt_sorting = gt_sorting - self.peaks = peaks + self.indices = indices + + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates') + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + + peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + if self.indices is None: + self.indices = np.arange(len(peaks)) + self.peaks = peaks[self.indices] self.params = params + self.exhaustive_gt = exhaustive_gt self.method = params['method'] self.method_kwargs = params['method_kwargs'] self.result = {} @@ -51,20 +63,45 @@ def run(self, **job_kwargs): self.result['peak_labels'] = peak_labels def compute_result(self, **result_params): - pass + self.noise = self.result['peak_labels'] < 0 + + spikes = self.gt_sorting.to_spike_vector() + self.result['sliced_gt_sorting'] = NumpySorting(spikes[self.indices], + self.recording.sampling_frequency, + self.gt_sorting.unit_ids) - def save_run(self, folder): - np.save(folder / "peak_labels", self.result['peak_labels']) + data = spikes[self.indices][~self.noise] + data["unit_index"] = self.result['peak_labels'][~self.noise] + + self.result['clustering'] = NumpySorting.from_times_labels(data["sample_index"], + self.result['peak_labels'][~self.noise], + self.recording.sampling_frequency) + + self.result['gt_comparison'] = GroundTruthComparison(self.result['sliced_gt_sorting'], + self.result['clustering'], + exhaustive_gt=self.exhaustive_gt) - def save_result(self, folder): - pass + sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates') + self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") - @classmethod - def load_folder(cls, folder): - result = {} - if (folder / "peak_labels.npy").exists(): - result["peak_labels"] = np.load(folder / "peak_labels.npy") - return result + sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates') + self.result['clustering_templates'] = ext.get_data(outputs="Templates") + + _run_key_saved = [ + ("peak_labels", "npy") + ] + + _result_key_saved = [ + ("gt_comparison", "pickle"), + ("sliced_gt_sorting", "sorting"), + ("clustering", "sorting"), + ("sliced_gt_templates", "zarr_templates"), + ("clustering_templates", "zarr_templates") + ] class ClusteringStudy(BenchmarkStudy): @@ -79,186 +116,133 @@ def create_benchmark(self, key): benchmark = ClusteringBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark + def homogeneity_score(self, ignore_noise=True, case_keys=None): + + if case_keys is None: + case_keys = list(self.cases.keys()) + + for count, key in enumerate(case_keys): + result = self.get_result(key) + noise = result["peak_labels"] < 0 + from sklearn.metrics import homogeneity_score + gt_labels = self.benchmarks[key].gt_sorting.to_spike_vector()["unit_index"] + gt_labels = gt_labels[self.benchmarks[key].indices] + found_labels = result['peak_labels'] + if ignore_noise: + gt_labels = gt_labels[~noise] + found_labels = found_labels[~noise] + print(self.cases[key]['label'], homogeneity_score(gt_labels, found_labels), np.mean(noise)) + + def plot_agreements(self, case_keys=None, figsize=(15,15)): + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + + for count, key in enumerate(case_keys): + ax = axs[count] + ax.set_title(self.cases[key]['label']) + plot_agreement_matrix(self.get_result(key)['gt_comparison'], ax=ax) + + def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + + for count, k in enumerate(("accuracy", "recall", "precision")): + + ax = axs[count] + for key in case_keys: + label = self.cases[key]["label"] + + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension('quality_metrics').get_data() + x = metrics["snr"].values + y = self.get_result(key)['gt_comparison'].get_performance()[k].values + ax.scatter(x, y, marker=".", label=label) + ax.set_title(k) + + if count == 2: + ax.legend() + + def plot_error_metrics(self, metric='cosine', case_keys=None, figsize=(15,5)): + + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + + for count, key in enumerate(case_keys): + + result = self.get_result(key) + scores = result['gt_comparison'].get_ordered_agreement_scores() + + unit_ids1 = scores.index.values + unit_ids2 = scores.columns.values + inds_1 = result['gt_comparison'].sorting1.ids_to_indices(unit_ids1) + inds_2 = result['gt_comparison'].sorting2.ids_to_indices(unit_ids2) + t1 = result["sliced_gt_templates"].templates_array + t2 = result['clustering_templates'].templates_array + a = t1.reshape(len(t1), -1)[inds_1] + b = t2.reshape(len(t2), -1)[inds_2] + + import sklearn + + if metric == "cosine": + distances = sklearn.metrics.pairwise.cosine_similarity(a, b) + else: + distances = sklearn.metrics.pairwise_distances(a, b, metric) + + im = axs[count].imshow(distances, aspect="auto") + axs[count].set_title(metric) + fig.colorbar(im, ax=axs[count]) + label = self.cases[key]["label"] + axs[count].set_title(label) + + + def plot_metrics_vs_snr(self, metric='cosine', case_keys=None, figsize=(15,5)): + + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + + for count, key in enumerate(case_keys): + + result = self.get_result(key) + scores = result['gt_comparison'].get_ordered_agreement_scores() + + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension('quality_metrics').get_data() + + unit_ids1 = scores.index.values + unit_ids2 = scores.columns.values + inds_1 = result['gt_comparison'].sorting1.ids_to_indices(unit_ids1) + inds_2 = result['gt_comparison'].sorting2.ids_to_indices(unit_ids2) + t1 = result["sliced_gt_templates"].templates_array + t2 = result['clustering_templates'].templates_array + a = t1.reshape(len(t1), -1) + b = t2.reshape(len(t2), -1) + + import sklearn + + if metric == "cosine": + distances = sklearn.metrics.pairwise.cosine_similarity(a, b) + else: + distances = sklearn.metrics.pairwise_distances(a, b, metric) + + snr = metrics["snr"][unit_ids1][inds_1[: len(inds_2)]] + to_plot = [] + for found, real in zip(inds_2, inds_1): + to_plot += [distances[real, found]] + axs[count].plot(snr, to_plot, '.') + axs[count].set_xlabel('snr') + axs[count].set_ylabel(metric) + label = self.cases[key]["label"] + axs[count].set_title(label) -# class BenchmarkClustering: -# def __init__(self, recording, gt_sorting, method, exhaustive_gt=True, tmp_folder=None, job_kwargs={}, verbose=True): -# self.method = method - -# self.waveforms = {} -# self.pcas = {} -# self.templates = {} - -# def __del__(self): -# import shutil - -# shutil.rmtree(self.tmp_folder) - -# def set_peaks(self, peaks): -# self._peaks = peaks - -# def set_positions(self, positions): -# self._positions = positions - -# def set_gt_positions(self, gt_positions): -# self._gt_positions = gt_positions - -# @property -# def peaks(self): -# if self._peaks is None: -# self.detect_peaks() -# return self._peaks - -# @property -# def selected_peaks(self): -# if self._selected_peaks is None: -# self.select_peaks() -# return self._selected_peaks - -# @property -# def positions(self): -# if self._positions is None: -# self.localize_peaks() -# return self._positions - -# @property -# def gt_positions(self): -# if self._gt_positions is None: -# self.localize_gt_peaks() -# return self._gt_positions - -# def detect_peaks(self, method_kwargs={"method": "locally_exclusive"}): -# from spikeinterface.sortingcomponents.peak_detection import detect_peaks - -# if self.verbose: -# method = method_kwargs["method"] -# print(f"Detecting peaks with method {method}") -# self._peaks = detect_peaks(self.recording_f, **method_kwargs, **self.job_kwargs) - -# def select_peaks(self, method_kwargs={"method": "uniform", "n_peaks": 100}): -# from spikeinterface.sortingcomponents.peak_selection import select_peaks - -# if self.verbose: -# method = method_kwargs["method"] -# print(f"Selecting peaks with method {method}") -# self._selected_peaks = select_peaks(self.peaks, **method_kwargs, **self.job_kwargs) -# if self.verbose: -# ratio = len(self._selected_peaks) / len(self.peaks) -# print(f"The ratio of peaks kept for clustering is {ratio}%") - -# def localize_peaks(self, method_kwargs={"method": "center_of_mass"}): -# from spikeinterface.sortingcomponents.peak_localization import localize_peaks - -# if self.verbose: -# method = method_kwargs["method"] -# print(f"Localizing peaks with method {method}") -# self._positions = localize_peaks(self.recording_f, self.selected_peaks, **method_kwargs, **self.job_kwargs) - -# def localize_gt_peaks(self, method_kwargs={"method": "center_of_mass"}): -# from spikeinterface.sortingcomponents.peak_localization import localize_peaks - -# if self.verbose: -# method = method_kwargs["method"] -# print(f"Localizing gt peaks with method {method}") -# self._gt_positions = localize_peaks(self.recording_f, self.gt_peaks, **method_kwargs, **self.job_kwargs) - -# def run(self, peaks=None, positions=None, method=None, method_kwargs={}, delta=0.2): -# t_start = time.time() -# if method is not None: -# self.method = method -# if peaks is not None: -# self._peaks = peaks -# self._selected_peaks = peaks - -# nb_peaks = len(self.selected_peaks) -# if self.verbose: -# print(f"Launching the {self.method} clustering algorithm with {nb_peaks} peaks") - -# if positions is not None: -# self._positions = positions - -# labels, peak_labels = find_cluster_from_peaks( -# self.recording_f, self.selected_peaks, method=self.method, method_kwargs=method_kwargs, **self.job_kwargs -# ) -# nb_clusters = len(labels) -# if self.verbose: -# print(f"{nb_clusters} clusters have been found") -# self.noise = peak_labels == -1 -# self.run_time = time.time() - t_start -# self.selected_peaks_labels = peak_labels -# self.labels = labels - -# self.clustering = NumpySorting.from_times_labels( -# self.selected_peaks["sample_index"][~self.noise], -# self.selected_peaks_labels[~self.noise], -# self.sampling_rate, -# ) -# if self.verbose: -# print("Performing the comparison with (sliced) ground truth") - -# spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] -# spikes2 = self.clustering.to_spike_vector(concatenated=False)[0] - -# matches = make_matching_events( -# spikes1["sample_index"], spikes2["sample_index"], int(delta * self.sampling_rate / 1000) -# ) - -# self.matches = matches -# idx = matches["index1"] -# self.sliced_gt_sorting = NumpySorting(spikes1[idx], self.sampling_rate, self.gt_sorting.unit_ids) - -# self.comp = GroundTruthComparison(self.sliced_gt_sorting, self.clustering, exhaustive_gt=self.exhaustive_gt) - -# for label, sorting in zip( -# ["gt", "clustering", "full_gt"], [self.sliced_gt_sorting, self.clustering, self.gt_sorting] -# ): -# tmp_folder = os.path.join(self.tmp_folder, label) -# if os.path.exists(tmp_folder): -# import shutil - -# shutil.rmtree(tmp_folder) - -# if not (label == "full_gt" and label in self.waveforms): -# if self.verbose: -# print(f"Extracting waveforms for {label}") - -# self.waveforms[label] = extract_waveforms( -# self.recording_f, -# sorting, -# tmp_folder, -# load_if_exists=True, -# ms_before=2.5, -# ms_after=3.5, -# max_spikes_per_unit=500, -# return_scaled=False, -# **self.job_kwargs, -# ) -# # self.pcas[label] = compute_principal_components(self.waveforms[label], load_if_exists=True, -# # n_components=5, mode='by_channel_local', -# # whiten=True, dtype='float32') - -# self.templates[label] = self.waveforms[label].get_all_templates(mode="median") - -# if self.gt_peaks is None: -# if self.verbose: -# print("Computing gt peaks") -# gt_peaks_ = self.gt_sorting.to_spike_vector() -# self.gt_peaks = np.zeros( -# gt_peaks_.size, dtype=[("sample_index", " Date: Wed, 21 Feb 2024 18:32:50 +0100 Subject: [PATCH 120/192] wip --- .../benchmark/benchmark_motion_estimation.py | 60 + .../benchmark_motion_interpolation.py | 1294 +++++++++-------- 2 files changed, 725 insertions(+), 629 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index ee9fa885c4..be048ede50 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -30,6 +30,66 @@ +def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direction_dim = 1): + """ + Get final displacement vector unit per units. + + See drifting_tools for shapes. + + + Parameters + ---------- + + displacement_vectors: list of numpy array + The lenght of the list is the number of segment. + Per segment, the drift vector is a numpy array with shape (num_times, 2, num_motions) + num_motions is generally = 1 but can be > 1 in case of combining several drift vectors + displacement_unit_factor: numpy array or None, default: None + A array containing the factor per unit of the drift. + This is used to create non rigid with a factor gradient of depending on units position. + shape (num_units, num_motions) + If None then all unit have the same factor (1) and the drift is rigid. + + Returns + ------- + unit_displacements: numpy array + shape (num_times, num_units) + + + """ + num_units = displacement_unit_factor.shape[0] + unit_displacements = np.zeros((displacement_vectors.shape[0], num_units)) + for i in range(displacement_vectors.shape[2]): + m = displacement_vectors[:, direction_dim, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :] + unit_displacements[:, :] += m + + return unit_displacements + + +def get_gt_motion_from_unit_discplacement(unit_displacements, displacement_sampling_frequency, + unit_locations, + temporal_bins, spatial_bins, + direction_dim=1,): + + times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency + f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) + unit_displacements = f(temporal_bins) + + # spatial interpolataion of units discplacement + if spatial_bins.shape[0] == 1: + # rigid + gt_motion = np.mean(unit_displacements, axis=1)[:, None] + else: + # non rigid + gt_motion = np.zeros((temporal_bins.size, spatial_bins.size)) + for t in range(temporal_bins.shape[0]): + f = scipy.interpolate.interp1d(unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate") + gt_motion[t, :] = f(spatial_bins) + + return gt_motion + + + class MotionEstimationBenchmark(Benchmark): def __init__(self, recording, gt_sorting, params, unit_locations, unit_displacements, displacement_sampling_frequency, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index ad19ff08aa..a3b3b4172f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -7,674 +7,710 @@ import shutil -from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten +# from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten from spikeinterface.sorters import run_sorter, read_sorter_folder -from spikeinterface.comparison import GroundTruthComparison +# from spikeinterface.comparison import GroundTruthComparison from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import BenchmarkBase, _simpleaxis -from spikeinterface.qualitymetrics import compute_quality_metrics -from spikeinterface.widgets import plot_sorting_performance -from spikeinterface.qualitymetrics import compute_quality_metrics -from spikeinterface.curation import MergeUnitsSorting -from spikeinterface.core import get_template_extremum_channel +# from spikeinterface.qualitymetrics import compute_quality_metrics +# from spikeinterface.curation import MergeUnitsSorting +# from spikeinterface.core import get_template_extremum_channel -import sklearn + +# import sklearn from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis import matplotlib.pyplot as plt -import MEArec as mr - +# import MEArec as mr class MotionInterpolationBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, params, - unit_locations, unit_displacements, displacement_sampling_frequency, - direction="y"): + def __init__(self, static_recording, gt_sorting, params, + sorter_folder, drifting_recording, + motion, temporal_bins, spatial_bins, + ): Benchmark.__init__(self) - self.recording = recording + self.static_recording = static_recording self.gt_sorting = gt_sorting self.params = params - self.unit_locations = unit_locations - self.unit_displacements = unit_displacements - self.displacement_sampling_frequency = displacement_sampling_frequency - self.direction = direction - self.direction_dim = ["x", "y"].index(direction) - - def run(self, **job_kwargs): - p = self.params - - -class MotionInterpolationStudy(BenchmarkStudy): - - benchmark_class = MotionInterpolationBenchmark - - def create_benchmark(self, key): - dataset_key = self.cases[key]["dataset"] - recording, gt_sorting = self.datasets[dataset_key] - params = self.cases[key]["params"] - init_kwargs = self.cases[key]["init_kwargs"] - benchmark = MotionInterpolationBenchmark(recording, gt_sorting, params, **init_kwargs) - return benchmark - - - - - -class BenchmarkMotionInterpolationMearec(BenchmarkBase): - _array_names = ("gt_motion", "estimated_motion", "temporal_bins", "spatial_bins") - _waveform_names = ("static", "drifting", "corrected_gt", "corrected_estimated") - _sorting_names = () - - _array_names_from_parent = () - _waveform_names_from_parent = ("static", "drifting") - _sorting_names_from_parent = ("static", "drifting") - - def __init__( - self, - mearec_filename_drifting, - mearec_filename_static, - gt_motion, - estimated_motion, - temporal_bins, - spatial_bins, - do_preprocessing=True, - correct_motion_kwargs={}, - waveforms_kwargs=dict( - ms_before=1.0, - ms_after=3.0, - max_spikes_per_unit=500, - ), - sparse_kwargs=dict( - method="radius", - radius_um=100.0, - ), - sorter_cases={}, - folder=None, - title="", - job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, - overwrite=False, - delete_output_folder=True, - parent_benchmark=None, - ): - BenchmarkBase.__init__( - self, - folder=folder, - title=title, - overwrite=overwrite, - job_kwargs=job_kwargs, - parent_benchmark=parent_benchmark, - ) - - self._args.extend([str(mearec_filename_drifting), str(mearec_filename_static), None, None, None, None]) - - self.sorter_cases = sorter_cases.copy() - self.mearec_filenames = {} - self.keys = ["static", "drifting", "corrected_gt", "corrected_estimated"] - self.mearec_filenames["drifting"] = mearec_filename_drifting - self.mearec_filenames["static"] = mearec_filename_static + self.sorter_folder = sorter_folder + self.drifting_recording = drifting_recording + self.motion = motion self.temporal_bins = temporal_bins self.spatial_bins = spatial_bins - self.gt_motion = gt_motion - self.estimated_motion = estimated_motion - self.do_preprocessing = do_preprocessing - self.delete_output_folder = delete_output_folder - - self._recordings = None - _, self.sorting_gt = read_mearec(self.mearec_filenames["static"]) - - self.correct_motion_kwargs = correct_motion_kwargs.copy() - self.sparse_kwargs = sparse_kwargs.copy() - self.waveforms_kwargs = waveforms_kwargs.copy() - self.comparisons = {} - self.accuracies = {} - - self._kwargs.update( - dict( - correct_motion_kwargs=self.correct_motion_kwargs, - sorter_cases=self.sorter_cases, - do_preprocessing=do_preprocessing, - delete_output_folder=delete_output_folder, - waveforms_kwargs=waveforms_kwargs, - sparse_kwargs=sparse_kwargs, - ) - ) - @property - def recordings(self): - if self._recordings is None: - self._recordings = {} - - for key in ( - "drifting", - "static", - ): - rec, _ = read_mearec(self.mearec_filenames[key]) - self._recordings["raw_" + key] = rec - - if self.do_preprocessing: - # this processing chain is the same as the kilosort2.5 - # this is important if we want to skip the kilosort preprocessing - # * all computation are done in float32 - # * 150um is more or less 30 channels for the whittening - # * the lastet gain step is super important it is what KS2.5 is doing because the whiten traces - # have magnitude around 1 so a factor (200) is needed to go back to int16 - rec = common_reference(rec, dtype="float32") - rec = highpass_filter(rec, freq_min=150.0) - rec = whiten(rec, mode="local", radius_um=150.0, num_chunks_per_segment=40, chunk_size=32000) - rec = scale(rec, gain=200, dtype="int16") - self._recordings[key] = rec - - rec = self._recordings["drifting"] - self._recordings["corrected_gt"] = InterpolateMotionRecording( - rec, self.gt_motion, self.temporal_bins, self.spatial_bins, **self.correct_motion_kwargs - ) - self._recordings["corrected_estimated"] = InterpolateMotionRecording( - rec, self.estimated_motion, self.temporal_bins, self.spatial_bins, **self.correct_motion_kwargs - ) + def run(self, **job_kwargs): - return self._recordings - - def run(self): - self.extract_waveforms() - self.save_to_folder() - self.run_sorters() - self.save_to_folder() - - def extract_waveforms(self): - # the sparsity is estimated on the static recording and propagated to all of then - if self.parent_benchmark is None: - wf_kwargs = self.waveforms_kwargs.copy() - wf_kwargs.pop("max_spikes_per_unit", None) - sparsity = precompute_sparsity( - self.recordings["static"], - self.sorting_gt, - num_spikes_for_sparsity=200.0, - unit_batch_size=10000, - **wf_kwargs, - **self.sparse_kwargs, - **self.job_kwargs, + if self.params["recording_source"] == 'static': + recording = self.static_recording + elif self.params["recording_source"] == 'drifting': + recording = self.drifting_recording + elif self.params["recording_source"] == 'corrected': + correct_motion_kwargs = self.params["correct_motion_kwargs"] + recording = InterpolateMotionRecording( + self.drifting_recording, self.motion, self.temporal_bins, self.spatial_bins, **correct_motion_kwargs ) else: - sparsity = self.waveforms["static"].sparsity - - for key in self.keys: - if self.parent_benchmark is not None and key in self._waveform_names_from_parent: - continue - - waveforms_folder = self.folder / "waveforms" / key - we = WaveformExtractor.create( - self.recordings[key], - self.sorting_gt, - waveforms_folder, - mode="folder", - sparsity=sparsity, - remove_if_exists=True, - ) - we.set_params(**self.waveforms_kwargs, return_scaled=True) - we.run_extract_waveforms(seed=22051977, **self.job_kwargs) - self.waveforms[key] = we - - def run_sorters(self, skip_already_done=True): - for case in self.sorter_cases: - label = case["label"] - print("run sorter", label) - sorter_name = case["sorter_name"] - sorter_params = case["sorter_params"] - recording = self.recordings[case["recording"]] - output_folder = self.folder / f"tmp_sortings_{label}" - if output_folder.exists() and skip_already_done: - print("already done") - sorting = read_sorter_folder(output_folder) - else: - sorting = run_sorter( - sorter_name, - recording, - output_folder, - **sorter_params, - delete_output_folder=self.delete_output_folder, - ) - self.sortings[label] = sorting - - def compute_distances_to_static(self, force=False): - if hasattr(self, "distances") and not force: - return self.distances - - self.distances = {} - - n = len(self.waveforms["static"].unit_ids) - - sparsity = self.waveforms["static"].sparsity - - ref_templates = self.waveforms["static"].get_all_templates() - - for key in self.keys: - if self.parent_benchmark is not None and key in ("drifting", "static"): - continue - - print(key) - dist = self.distances[key] = { - "std": np.zeros(n), - "norm_std": np.zeros(n), - "template_norm_distance": np.zeros(n), - "template_cosine": np.zeros(n), - } - - templates = self.waveforms[key].get_all_templates() - - extremum_channel = get_template_extremum_channel(self.waveforms["static"], outputs="index") - - for unit_ind, unit_id in enumerate(self.waveforms[key].sorting.unit_ids): - mask = sparsity.mask[unit_ind, :] - ref_template = ref_templates[unit_ind][:, mask] - template = templates[unit_ind][:, mask] - - max_chan = extremum_channel[unit_id] - max_chan - - max_chan_sparse = list(np.nonzero(mask)[0]).index(max_chan) - - # this is already sparse - wfs = self.waveforms[key].get_waveforms(unit_id) - ref_wfs = self.waveforms["static"].get_waveforms(unit_id) - - rms = np.sqrt(np.mean(template**2)) - ref_rms = np.sqrt(np.mean(ref_template**2)) - if rms == 0: - print(key, unit_id, unit_ind, rms, ref_rms) - - dist["std"][unit_ind] = np.mean(np.std(wfs, axis=0), axis=(0, 1)) - dist["norm_std"][unit_ind] = np.mean(np.std(wfs, axis=0), axis=(0, 1)) / rms - dist["template_norm_distance"][unit_ind] = np.sum((ref_template - template) ** 2) / ref_rms - dist["template_cosine"][unit_ind] = sklearn.metrics.pairwise.cosine_similarity( - ref_template.reshape(1, -1), template.reshape(1, -1) - )[0] - - return self.distances - - def compute_residuals(self, force=True): - fr = int(self.recordings["static"].get_sampling_frequency()) - duration = int(self.recordings["static"].get_total_duration()) - - t_start = 0 - t_stop = duration - - if hasattr(self, "residuals") and not force: - return self.residuals, (t_start, t_stop) - - self.residuals = {} - - for key in ["corrected"]: - difference = ResidualRecording(self.recordings["static"], self.recordings[key]) - self.residuals[key] = np.zeros((self.recordings["static"].get_num_channels(), 0)) - - for i in np.arange(t_start * fr, t_stop * fr, fr): - data = np.linalg.norm(difference.get_traces(start_frame=i, end_frame=i + fr), axis=0) / np.sqrt(fr) - self.residuals[key] = np.hstack((self.residuals[key], data[:, np.newaxis])) - - return self.residuals, (t_start, t_stop) - - def compute_accuracies(self): - for case in self.sorter_cases: - label = case["label"] - sorting = self.sortings[label] - if label not in self.comparisons: - comp = GroundTruthComparison(self.sorting_gt, sorting, exhaustive_gt=True) - self.comparisons[label] = comp - self.accuracies[label] = comp.get_performance()["accuracy"].values - - def _plot_accuracy( - self, accuracies, mode="ordered_accuracy", figsize=(15, 5), axes=None, ax=None, ls="-", legend=True, colors=None - ): - if len(self.accuracies) != len(self.sorter_cases): - self.compute_accuracies() - - n = len(self.sorter_cases) - - if "depth" in mode: - # gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filenames['drifting'], time_vector=np.array([0., 1.])) - # unit_depth = gt_unit_positions[0, :] - - template_locations = np.array(mr.load_recordings(self.mearec_filenames["drifting"]).template_locations) - assert len(template_locations.shape) == 3 - mid = template_locations.shape[1] // 2 - unit_depth = template_locations[:, mid, 2] - - chan_locations = self.recordings["drifting"].get_channel_locations() - - if mode == "ordered_accuracy": - if ax is None: - fig, ax = plt.subplots(figsize=figsize) - else: - fig = ax.figure - - order = None - for i, case in enumerate(self.sorter_cases): - color = colors[i] if colors is not None else None - label = case["label"] - # comp = self.comparisons[label] - acc = accuracies[label] - order = np.argsort(acc)[::-1] - acc = acc[order] - ax.plot(acc, label=label, ls=ls, color=color) - if legend: - ax.legend() - ax.set_ylabel("accuracy") - ax.set_xlabel("units ordered by accuracy") - - elif mode == "depth_snr": - if axes is None: - fig, axs = plt.subplots(nrows=n, figsize=figsize, sharey=True, sharex=True) - else: - fig = axes[0].figure - axs = axes - - metrics = compute_quality_metrics(self.waveforms["static"], metric_names=["snr"], load_if_exists=True) - snr = metrics["snr"].values - - for i, case in enumerate(self.sorter_cases): - ax = axs[i] - label = case["label"] - acc = accuracies[label] - - points = ax.scatter(unit_depth, snr, c=acc) - points.set_clim(0.0, 1.0) - ax.set_title(label) - ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") - ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") - ax.set_ylabel("snr") - ax.set_xlabel("depth") - - cbar = fig.colorbar(points, ax=axs[:], location="right", shrink=0.6) - cbar.ax.set_ylabel("accuracy") - - elif mode == "snr": - fig, ax = plt.subplots(figsize=figsize) - - metrics = compute_quality_metrics(self.waveforms["static"], metric_names=["snr"], load_if_exists=True) - snr = metrics["snr"].values - - for i, case in enumerate(self.sorter_cases): - label = case["label"] - acc = self.accuracies[label] - ax.scatter(snr, acc, label=label) - ax.set_xlabel("snr") - ax.set_ylabel("accuracy") - - ax.legend() - - elif mode == "depth": - fig, ax = plt.subplots(figsize=figsize) - - for i, case in enumerate(self.sorter_cases): - label = case["label"] - acc = accuracies[label] - - ax.scatter(unit_depth, acc, label=label) - ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") - ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") - ax.legend() - ax.set_xlabel("depth") - ax.set_ylabel("accuracy") - - return fig - - def plot_sortings_accuracy(self, **kwargs): - if len(self.accuracies) != len(self.sorter_cases): - self.compute_accuracies() - - return self._plot_accuracy(self.accuracies, ls="-", **kwargs) - - def plot_best_merges_accuracy(self, **kwargs): - return self._plot_accuracy(self.merged_accuracies, **kwargs, ls="--") - - def plot_sorting_units_categories(self): - if len(self.accuracies) != len(self.sorter_cases): - self.compute_accuracies() - - for i, case in enumerate(self.sorter_cases): - label = case["label"] - comp = self.comparisons[label] - count = comp.count_units_categories() - if i == 0: - df = pd.DataFrame(columns=count.index) - df.loc[label, :] = count - df.plot.bar() - - def find_best_merges(self, merging_score=0.2): - # this find best merges having the ground truth - - self.merged_sortings = {} - self.merged_comparisons = {} - self.merged_accuracies = {} - self.units_to_merge = {} - for i, case in enumerate(self.sorter_cases): - label = case["label"] - # print() - # print(label) - gt_unit_ids = self.sorting_gt.unit_ids - sorting = self.sortings[label] - unit_ids = sorting.unit_ids - - comp = self.comparisons[label] - scores = comp.agreement_scores - - to_merge = [] - for gt_unit_id in gt_unit_ids: - (inds,) = np.nonzero(scores.loc[gt_unit_id, :].values > merging_score) - merge_ids = unit_ids[inds] - if merge_ids.size > 1: - to_merge.append(list(merge_ids)) - - self.units_to_merge[label] = to_merge - merged_sporting = MergeUnitsSorting(sorting, to_merge) - comp_merged = GroundTruthComparison(self.sorting_gt, merged_sporting, exhaustive_gt=True) - - self.merged_sortings[label] = merged_sporting - self.merged_comparisons[label] = comp_merged - self.merged_accuracies[label] = comp_merged.get_performance()["accuracy"].values - - -def plot_distances_to_static(benchmarks, metric="cosine", figsize=(15, 10)): - fig = plt.figure(figsize=figsize) - gs = fig.add_gridspec(4, 2) - - ax = fig.add_subplot(gs[0:2, 0]) - for count, bench in enumerate(benchmarks): - distances = bench.compute_distances_to_static(force=False) - print(distances.keys()) - ax.scatter( - distances["drifting"][f"template_{metric}"], - distances["corrected"][f"template_{metric}"], - c=f"C{count}", - alpha=0.5, - label=bench.title, + raise ValueError("recording_source") + + sorter_name = self.params["sorter_name"] + sorter_params = self.params["sorter_params"] + sorting = run_sorter( + sorter_name, + recording, + output_folder=self.sorter_folder, + **sorter_params, + delete_output_folder=False, ) - ax.legend() - - xmin, xmax = ax.get_xlim() - ax.plot([xmin, xmax], [xmin, xmax], "k--") - _simpleaxis(ax) - if metric == "euclidean": - ax.set_xlabel(r"$\|drift - static\|_2$") - ax.set_ylabel(r"$\|corrected - static\|_2$") - elif metric == "cosine": - ax.set_xlabel(r"$cosine(drift, static)$") - ax.set_ylabel(r"$cosine(corrected, static)$") - - recgen = mr.load_recordings(benchmarks[0].mearec_filenames["static"]) - nb_templates, nb_versions, _ = recgen.template_locations.shape - template_positions = recgen.template_locations[:, nb_versions // 2, 1:3] - distances_to_center = template_positions[:, 1] - - ax_1 = fig.add_subplot(gs[0, 1]) - ax_2 = fig.add_subplot(gs[1, 1]) - ax_3 = fig.add_subplot(gs[2:, 1]) - ax_4 = fig.add_subplot(gs[2:, 0]) - - for count, bench in enumerate(benchmarks): - # results = bench._compute_snippets_variability(metric=metric, num_channels=num_channels) - distances = bench.compute_distances_to_static(force=False) - - m_differences = distances["corrected"][f"wf_{metric}_mean"] / distances["static"][f"wf_{metric}_mean"] - s_differences = distances["corrected"][f"wf_{metric}_std"] / distances["static"][f"wf_{metric}_std"] - - ax_3.bar([count], [m_differences.mean()], yerr=[m_differences.std()], color=f"C{count}") - ax_4.bar([count], [s_differences.mean()], yerr=[s_differences.std()], color=f"C{count}") - idx = np.argsort(distances_to_center) - ax_1.scatter(distances_to_center[idx], m_differences[idx], color=f"C{count}") - ax_2.scatter(distances_to_center[idx], s_differences[idx], color=f"C{count}") - - for a in [ax_1, ax_2, ax_3, ax_4]: - _simpleaxis(a) - - if metric == "euclidean": - ax_1.set_ylabel(r"$\Delta mean(\|~\|_2)$ (% static)") - ax_2.set_ylabel(r"$\Delta std(\|~\|_2)$ (% static)") - ax_3.set_ylabel(r"$\Delta mean(\|~\|_2)$ (% static)") - ax_4.set_ylabel(r"$\Delta std(\|~\|_2)$ (% static)") - elif metric == "cosine": - ax_1.set_ylabel(r"$\Delta mean(cosine)$ (% static)") - ax_2.set_ylabel(r"$\Delta std(cosine)$ (% static)") - ax_3.set_ylabel(r"$\Delta mean(cosine)$ (% static)") - ax_4.set_ylabel(r"$\Delta std(cosine)$ (% static)") - ax_3.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) - ax_4.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) - xmin, xmax = ax_3.get_xlim() - ax_3.plot([xmin, xmax], [1, 1], "k--") - ax_4.plot([xmin, xmax], [1, 1], "k--") - ax_1.set_xticks([]) - ax_2.set_xlabel("depth (um)") - - xmin, xmax = ax_1.get_xlim() - ax_1.plot([xmin, xmax], [1, 1], "k--") - ax_2.plot([xmin, xmax], [1, 1], "k--") - plt.tight_layout() - - -def plot_snr_decrease(benchmarks, figsize=(15, 10)): - fig, axes = plt.subplots(2, 2, figsize=figsize, squeeze=False) - - recgen = mr.load_recordings(benchmarks[0].mearec_filenames["static"]) - nb_templates, nb_versions, _ = recgen.template_locations.shape - template_positions = recgen.template_locations[:, nb_versions // 2, 1:3] - distances_to_center = template_positions[:, 1] - idx = np.argsort(distances_to_center) - _simpleaxis(axes[0, 0]) - - snr_static = compute_quality_metrics(benchmarks[0].waveforms["static"], metric_names=["snr"], load_if_exists=True) - snr_drifting = compute_quality_metrics( - benchmarks[0].waveforms["drifting"], metric_names=["snr"], load_if_exists=True - ) - - m = np.max(snr_static) - axes[0, 0].scatter(snr_static.values, snr_drifting.values, c="0.5") - axes[0, 0].plot([0, m], [0, m], color="k") - - axes[0, 0].set_ylabel("units SNR for drifting") - _simpleaxis(axes[0, 0]) - - axes[0, 1].plot(distances_to_center[idx], (snr_drifting.values / snr_static.values)[idx], c="0.5") - axes[0, 1].plot(distances_to_center[idx], np.ones(len(idx)), "k--") - _simpleaxis(axes[0, 1]) - axes[0, 1].set_xticks([]) - axes[0, 0].set_xticks([]) - - for count, bench in enumerate(benchmarks): - snr_corrected = compute_quality_metrics(bench.waveforms["corrected"], metric_names=["snr"], load_if_exists=True) - axes[1, 0].scatter(snr_static.values, snr_corrected.values, label=bench.title) - axes[1, 0].plot([0, m], [0, m], color="k") - - axes[1, 1].plot(distances_to_center[idx], (snr_corrected.values / snr_static.values)[idx], c=f"C{count}") - - axes[1, 0].set_xlabel("units SNR for static") - axes[1, 0].set_ylabel("units SNR for corrected") - axes[1, 1].plot(distances_to_center[idx], np.ones(len(idx)), "k--") - axes[1, 0].legend() - _simpleaxis(axes[1, 0]) - _simpleaxis(axes[1, 1]) - axes[1, 1].set_ylabel(r"$\Delta(SNR)$") - axes[0, 1].set_ylabel(r"$\Delta(SNR)$") - - axes[1, 1].set_xlabel("depth (um)") - - -def plot_residuals_comparisons(benchmarks): - fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - for count, bench in enumerate(benchmarks): - residuals, (t_start, t_stop) = bench.compute_residuals(force=False) - time_axis = np.arange(t_start, t_stop) - axes[0].plot(time_axis, residuals["corrected"].mean(0), label=bench.title) - axes[0].legend() - axes[0].set_xlabel("time (s)") - axes[0].set_ylabel(r"$|S_{corrected} - S_{static}|$") - _simpleaxis(axes[0]) - - channel_positions = benchmarks[0].recordings["static"].get_channel_locations() - distances_to_center = channel_positions[:, 1] - idx = np.argsort(distances_to_center) - - for count, bench in enumerate(benchmarks): - residuals, (t_start, t_stop) = bench.compute_residuals(force=False) - time_axis = np.arange(t_start, t_stop) - axes[1].plot( - distances_to_center[idx], residuals["corrected"].mean(1)[idx], label=bench.title, lw=2, c=f"C{count}" - ) - axes[1].fill_between( - distances_to_center[idx], - residuals["corrected"].mean(1)[idx] - residuals["corrected"].std(1)[idx], - residuals["corrected"].mean(1)[idx] + residuals["corrected"].std(1)[idx], - color=f"C{count}", - alpha=0.25, - ) - axes[1].set_xlabel("depth (um)") - _simpleaxis(axes[1]) + self.result["sorting"] = sorting - for count, bench in enumerate(benchmarks): - residuals, (t_start, t_stop) = bench.compute_residuals(force=False) - axes[2].bar([count], [residuals["corrected"].mean()], yerr=[residuals["corrected"].std()], color=f"C{count}") + def compute_result(self, **result_params): + pass + # self.result[""] = - _simpleaxis(axes[2]) - axes[2].set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) + _run_key_saved = [ + ] + _result_key_saved = [ + ] -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +class MotionInterpolationStudy(BenchmarkStudy): -class ResidualRecording(BasePreprocessor): - name = "residual_recording" + benchmark_class = MotionInterpolationBenchmark - def __init__(self, recording_1, recording_2): - assert recording_1.get_num_segments() == recording_2.get_num_segments() - BasePreprocessor.__init__(self, recording_1) + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + init_kwargs = self.cases[key]["init_kwargs"] + sorter_folder = self.folder / "sorters" /self.key_to_str(key) + sorter_folder.parent.mkdir(exist_ok=True) + benchmark = MotionInterpolationBenchmark(recording, gt_sorting, params, + sorter_folder=sorter_folder, **init_kwargs) + return benchmark - for parent_recording_segment_1, parent_recording_segment_2 in zip( - recording_1._recording_segments, recording_2._recording_segments - ): - rec_segment = DifferenceRecordingSegment(parent_recording_segment_1, parent_recording_segment_2) - self.add_recording_segment(rec_segment) - self._kwargs = dict(recording_1=recording_1, recording_2=recording_2) -class DifferenceRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment_1, parent_recording_segment_2): - BasePreprocessorSegment.__init__(self, parent_recording_segment_1) - self.parent_recording_segment_1 = parent_recording_segment_1 - self.parent_recording_segment_2 = parent_recording_segment_2 - def get_traces(self, start_frame, end_frame, channel_indices): - traces_1 = self.parent_recording_segment_1.get_traces(start_frame, end_frame, channel_indices) - traces_2 = self.parent_recording_segment_2.get_traces(start_frame, end_frame, channel_indices) +# class BenchmarkMotionInterpolationMearec(BenchmarkBase): +# _array_names = ("gt_motion", "estimated_motion", "temporal_bins", "spatial_bins") +# _waveform_names = ("static", "drifting", "corrected_gt", "corrected_estimated") +# _sorting_names = () + +# _array_names_from_parent = () +# _waveform_names_from_parent = ("static", "drifting") +# _sorting_names_from_parent = ("static", "drifting") + +# def __init__( +# self, +# mearec_filename_drifting, +# mearec_filename_static, +# gt_motion, +# estimated_motion, +# temporal_bins, +# spatial_bins, +# do_preprocessing=True, +# correct_motion_kwargs={}, +# waveforms_kwargs=dict( +# ms_before=1.0, +# ms_after=3.0, +# max_spikes_per_unit=500, +# ), +# sparse_kwargs=dict( +# method="radius", +# radius_um=100.0, +# ), +# sorter_cases={}, +# folder=None, +# title="", +# job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, +# overwrite=False, +# delete_output_folder=True, +# parent_benchmark=None, +# ): +# BenchmarkBase.__init__( +# self, +# folder=folder, +# title=title, +# overwrite=overwrite, +# job_kwargs=job_kwargs, +# parent_benchmark=parent_benchmark, +# ) + +# self._args.extend([str(mearec_filename_drifting), str(mearec_filename_static), None, None, None, None]) + +# self.sorter_cases = sorter_cases.copy() +# self.mearec_filenames = {} +# self.keys = ["static", "drifting", "corrected_gt", "corrected_estimated"] +# self.mearec_filenames["drifting"] = mearec_filename_drifting +# self.mearec_filenames["static"] = mearec_filename_static + +# self.temporal_bins = temporal_bins +# self.spatial_bins = spatial_bins +# self.gt_motion = gt_motion +# self.estimated_motion = estimated_motion +# self.do_preprocessing = do_preprocessing +# self.delete_output_folder = delete_output_folder + +# self._recordings = None +# _, self.sorting_gt = read_mearec(self.mearec_filenames["static"]) + +# self.correct_motion_kwargs = correct_motion_kwargs.copy() +# self.sparse_kwargs = sparse_kwargs.copy() +# self.waveforms_kwargs = waveforms_kwargs.copy() +# self.comparisons = {} +# self.accuracies = {} + +# self._kwargs.update( +# dict( +# correct_motion_kwargs=self.correct_motion_kwargs, +# sorter_cases=self.sorter_cases, +# do_preprocessing=do_preprocessing, +# delete_output_folder=delete_output_folder, +# waveforms_kwargs=waveforms_kwargs, +# sparse_kwargs=sparse_kwargs, +# ) +# ) + +# @property +# def recordings(self): +# if self._recordings is None: +# self._recordings = {} + +# for key in ( +# "drifting", +# "static", +# ): +# rec, _ = read_mearec(self.mearec_filenames[key]) +# self._recordings["raw_" + key] = rec + +# if self.do_preprocessing: +# # this processing chain is the same as the kilosort2.5 +# # this is important if we want to skip the kilosort preprocessing +# # * all computation are done in float32 +# # * 150um is more or less 30 channels for the whittening +# # * the lastet gain step is super important it is what KS2.5 is doing because the whiten traces +# # have magnitude around 1 so a factor (200) is needed to go back to int16 +# rec = common_reference(rec, dtype="float32") +# rec = highpass_filter(rec, freq_min=150.0) +# rec = whiten(rec, mode="local", radius_um=150.0, num_chunks_per_segment=40, chunk_size=32000) +# rec = scale(rec, gain=200, dtype="int16") +# self._recordings[key] = rec + +# rec = self._recordings["drifting"] +# self._recordings["corrected_gt"] = InterpolateMotionRecording( +# rec, self.gt_motion, self.temporal_bins, self.spatial_bins, **self.correct_motion_kwargs +# ) + +# self._recordings["corrected_estimated"] = InterpolateMotionRecording( +# rec, self.estimated_motion, self.temporal_bins, self.spatial_bins, **self.correct_motion_kwargs +# ) + +# return self._recordings + +# def run(self): +# self.extract_waveforms() +# self.save_to_folder() +# self.run_sorters() +# self.save_to_folder() + +# def extract_waveforms(self): +# # the sparsity is estimated on the static recording and propagated to all of then +# if self.parent_benchmark is None: +# wf_kwargs = self.waveforms_kwargs.copy() +# wf_kwargs.pop("max_spikes_per_unit", None) +# sparsity = precompute_sparsity( +# self.recordings["static"], +# self.sorting_gt, +# num_spikes_for_sparsity=200.0, +# unit_batch_size=10000, +# **wf_kwargs, +# **self.sparse_kwargs, +# **self.job_kwargs, +# ) +# else: +# sparsity = self.waveforms["static"].sparsity + +# for key in self.keys: +# if self.parent_benchmark is not None and key in self._waveform_names_from_parent: +# continue + +# waveforms_folder = self.folder / "waveforms" / key +# we = WaveformExtractor.create( +# self.recordings[key], +# self.sorting_gt, +# waveforms_folder, +# mode="folder", +# sparsity=sparsity, +# remove_if_exists=True, +# ) +# we.set_params(**self.waveforms_kwargs, return_scaled=True) +# we.run_extract_waveforms(seed=22051977, **self.job_kwargs) +# self.waveforms[key] = we + +# def run_sorters(self, skip_already_done=True): +# for case in self.sorter_cases: +# label = case["label"] +# print("run sorter", label) +# sorter_name = case["sorter_name"] +# sorter_params = case["sorter_params"] +# recording = self.recordings[case["recording"]] +# output_folder = self.folder / f"tmp_sortings_{label}" +# if output_folder.exists() and skip_already_done: +# print("already done") +# sorting = read_sorter_folder(output_folder) +# else: +# sorting = run_sorter( +# sorter_name, +# recording, +# output_folder, +# **sorter_params, +# delete_output_folder=self.delete_output_folder, +# ) +# self.sortings[label] = sorting + +# def compute_distances_to_static(self, force=False): +# if hasattr(self, "distances") and not force: +# return self.distances + +# self.distances = {} + +# n = len(self.waveforms["static"].unit_ids) + +# sparsity = self.waveforms["static"].sparsity + +# ref_templates = self.waveforms["static"].get_all_templates() + +# for key in self.keys: +# if self.parent_benchmark is not None and key in ("drifting", "static"): +# continue + +# print(key) +# dist = self.distances[key] = { +# "std": np.zeros(n), +# "norm_std": np.zeros(n), +# "template_norm_distance": np.zeros(n), +# "template_cosine": np.zeros(n), +# } + +# templates = self.waveforms[key].get_all_templates() + +# extremum_channel = get_template_extremum_channel(self.waveforms["static"], outputs="index") + +# for unit_ind, unit_id in enumerate(self.waveforms[key].sorting.unit_ids): +# mask = sparsity.mask[unit_ind, :] +# ref_template = ref_templates[unit_ind][:, mask] +# template = templates[unit_ind][:, mask] + +# max_chan = extremum_channel[unit_id] +# max_chan + +# max_chan_sparse = list(np.nonzero(mask)[0]).index(max_chan) + +# # this is already sparse +# wfs = self.waveforms[key].get_waveforms(unit_id) +# ref_wfs = self.waveforms["static"].get_waveforms(unit_id) + +# rms = np.sqrt(np.mean(template**2)) +# ref_rms = np.sqrt(np.mean(ref_template**2)) +# if rms == 0: +# print(key, unit_id, unit_ind, rms, ref_rms) + +# dist["std"][unit_ind] = np.mean(np.std(wfs, axis=0), axis=(0, 1)) +# dist["norm_std"][unit_ind] = np.mean(np.std(wfs, axis=0), axis=(0, 1)) / rms +# dist["template_norm_distance"][unit_ind] = np.sum((ref_template - template) ** 2) / ref_rms +# dist["template_cosine"][unit_ind] = sklearn.metrics.pairwise.cosine_similarity( +# ref_template.reshape(1, -1), template.reshape(1, -1) +# )[0] + +# return self.distances + +# def compute_residuals(self, force=True): +# fr = int(self.recordings["static"].get_sampling_frequency()) +# duration = int(self.recordings["static"].get_total_duration()) + +# t_start = 0 +# t_stop = duration + +# if hasattr(self, "residuals") and not force: +# return self.residuals, (t_start, t_stop) + +# self.residuals = {} + +# for key in ["corrected"]: +# difference = ResidualRecording(self.recordings["static"], self.recordings[key]) +# self.residuals[key] = np.zeros((self.recordings["static"].get_num_channels(), 0)) + +# for i in np.arange(t_start * fr, t_stop * fr, fr): +# data = np.linalg.norm(difference.get_traces(start_frame=i, end_frame=i + fr), axis=0) / np.sqrt(fr) +# self.residuals[key] = np.hstack((self.residuals[key], data[:, np.newaxis])) + +# return self.residuals, (t_start, t_stop) + +# def compute_accuracies(self): +# for case in self.sorter_cases: +# label = case["label"] +# sorting = self.sortings[label] +# if label not in self.comparisons: +# comp = GroundTruthComparison(self.sorting_gt, sorting, exhaustive_gt=True) +# self.comparisons[label] = comp +# self.accuracies[label] = comp.get_performance()["accuracy"].values + +# def _plot_accuracy( +# self, accuracies, mode="ordered_accuracy", figsize=(15, 5), axes=None, ax=None, ls="-", legend=True, colors=None +# ): +# if len(self.accuracies) != len(self.sorter_cases): +# self.compute_accuracies() + +# n = len(self.sorter_cases) + +# if "depth" in mode: +# # gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filenames['drifting'], time_vector=np.array([0., 1.])) +# # unit_depth = gt_unit_positions[0, :] + +# template_locations = np.array(mr.load_recordings(self.mearec_filenames["drifting"]).template_locations) +# assert len(template_locations.shape) == 3 +# mid = template_locations.shape[1] // 2 +# unit_depth = template_locations[:, mid, 2] + +# chan_locations = self.recordings["drifting"].get_channel_locations() + +# if mode == "ordered_accuracy": +# if ax is None: +# fig, ax = plt.subplots(figsize=figsize) +# else: +# fig = ax.figure + +# order = None +# for i, case in enumerate(self.sorter_cases): +# color = colors[i] if colors is not None else None +# label = case["label"] +# # comp = self.comparisons[label] +# acc = accuracies[label] +# order = np.argsort(acc)[::-1] +# acc = acc[order] +# ax.plot(acc, label=label, ls=ls, color=color) +# if legend: +# ax.legend() +# ax.set_ylabel("accuracy") +# ax.set_xlabel("units ordered by accuracy") + +# elif mode == "depth_snr": +# if axes is None: +# fig, axs = plt.subplots(nrows=n, figsize=figsize, sharey=True, sharex=True) +# else: +# fig = axes[0].figure +# axs = axes + +# metrics = compute_quality_metrics(self.waveforms["static"], metric_names=["snr"], load_if_exists=True) +# snr = metrics["snr"].values + +# for i, case in enumerate(self.sorter_cases): +# ax = axs[i] +# label = case["label"] +# acc = accuracies[label] + +# points = ax.scatter(unit_depth, snr, c=acc) +# points.set_clim(0.0, 1.0) +# ax.set_title(label) +# ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") +# ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") +# ax.set_ylabel("snr") +# ax.set_xlabel("depth") + +# cbar = fig.colorbar(points, ax=axs[:], location="right", shrink=0.6) +# cbar.ax.set_ylabel("accuracy") + +# elif mode == "snr": +# fig, ax = plt.subplots(figsize=figsize) + +# metrics = compute_quality_metrics(self.waveforms["static"], metric_names=["snr"], load_if_exists=True) +# snr = metrics["snr"].values + +# for i, case in enumerate(self.sorter_cases): +# label = case["label"] +# acc = self.accuracies[label] +# ax.scatter(snr, acc, label=label) +# ax.set_xlabel("snr") +# ax.set_ylabel("accuracy") + +# ax.legend() + +# elif mode == "depth": +# fig, ax = plt.subplots(figsize=figsize) + +# for i, case in enumerate(self.sorter_cases): +# label = case["label"] +# acc = accuracies[label] + +# ax.scatter(unit_depth, acc, label=label) +# ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") +# ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") +# ax.legend() +# ax.set_xlabel("depth") +# ax.set_ylabel("accuracy") + +# return fig + +# def plot_sortings_accuracy(self, **kwargs): +# if len(self.accuracies) != len(self.sorter_cases): +# self.compute_accuracies() + +# return self._plot_accuracy(self.accuracies, ls="-", **kwargs) + +# def plot_best_merges_accuracy(self, **kwargs): +# return self._plot_accuracy(self.merged_accuracies, **kwargs, ls="--") + +# def plot_sorting_units_categories(self): +# if len(self.accuracies) != len(self.sorter_cases): +# self.compute_accuracies() + +# for i, case in enumerate(self.sorter_cases): +# label = case["label"] +# comp = self.comparisons[label] +# count = comp.count_units_categories() +# if i == 0: +# df = pd.DataFrame(columns=count.index) +# df.loc[label, :] = count +# df.plot.bar() + +# def find_best_merges(self, merging_score=0.2): +# # this find best merges having the ground truth + +# self.merged_sortings = {} +# self.merged_comparisons = {} +# self.merged_accuracies = {} +# self.units_to_merge = {} +# for i, case in enumerate(self.sorter_cases): +# label = case["label"] +# # print() +# # print(label) +# gt_unit_ids = self.sorting_gt.unit_ids +# sorting = self.sortings[label] +# unit_ids = sorting.unit_ids + +# comp = self.comparisons[label] +# scores = comp.agreement_scores + +# to_merge = [] +# for gt_unit_id in gt_unit_ids: +# (inds,) = np.nonzero(scores.loc[gt_unit_id, :].values > merging_score) +# merge_ids = unit_ids[inds] +# if merge_ids.size > 1: +# to_merge.append(list(merge_ids)) + +# self.units_to_merge[label] = to_merge +# merged_sporting = MergeUnitsSorting(sorting, to_merge) +# comp_merged = GroundTruthComparison(self.sorting_gt, merged_sporting, exhaustive_gt=True) + +# self.merged_sortings[label] = merged_sporting +# self.merged_comparisons[label] = comp_merged +# self.merged_accuracies[label] = comp_merged.get_performance()["accuracy"].values + + +# def plot_distances_to_static(benchmarks, metric="cosine", figsize=(15, 10)): +# fig = plt.figure(figsize=figsize) +# gs = fig.add_gridspec(4, 2) + +# ax = fig.add_subplot(gs[0:2, 0]) +# for count, bench in enumerate(benchmarks): +# distances = bench.compute_distances_to_static(force=False) +# print(distances.keys()) +# ax.scatter( +# distances["drifting"][f"template_{metric}"], +# distances["corrected"][f"template_{metric}"], +# c=f"C{count}", +# alpha=0.5, +# label=bench.title, +# ) + +# ax.legend() + +# xmin, xmax = ax.get_xlim() +# ax.plot([xmin, xmax], [xmin, xmax], "k--") +# _simpleaxis(ax) +# if metric == "euclidean": +# ax.set_xlabel(r"$\|drift - static\|_2$") +# ax.set_ylabel(r"$\|corrected - static\|_2$") +# elif metric == "cosine": +# ax.set_xlabel(r"$cosine(drift, static)$") +# ax.set_ylabel(r"$cosine(corrected, static)$") + +# recgen = mr.load_recordings(benchmarks[0].mearec_filenames["static"]) +# nb_templates, nb_versions, _ = recgen.template_locations.shape +# template_positions = recgen.template_locations[:, nb_versions // 2, 1:3] +# distances_to_center = template_positions[:, 1] + +# ax_1 = fig.add_subplot(gs[0, 1]) +# ax_2 = fig.add_subplot(gs[1, 1]) +# ax_3 = fig.add_subplot(gs[2:, 1]) +# ax_4 = fig.add_subplot(gs[2:, 0]) + +# for count, bench in enumerate(benchmarks): +# # results = bench._compute_snippets_variability(metric=metric, num_channels=num_channels) +# distances = bench.compute_distances_to_static(force=False) + +# m_differences = distances["corrected"][f"wf_{metric}_mean"] / distances["static"][f"wf_{metric}_mean"] +# s_differences = distances["corrected"][f"wf_{metric}_std"] / distances["static"][f"wf_{metric}_std"] + +# ax_3.bar([count], [m_differences.mean()], yerr=[m_differences.std()], color=f"C{count}") +# ax_4.bar([count], [s_differences.mean()], yerr=[s_differences.std()], color=f"C{count}") +# idx = np.argsort(distances_to_center) +# ax_1.scatter(distances_to_center[idx], m_differences[idx], color=f"C{count}") +# ax_2.scatter(distances_to_center[idx], s_differences[idx], color=f"C{count}") + +# for a in [ax_1, ax_2, ax_3, ax_4]: +# _simpleaxis(a) + +# if metric == "euclidean": +# ax_1.set_ylabel(r"$\Delta mean(\|~\|_2)$ (% static)") +# ax_2.set_ylabel(r"$\Delta std(\|~\|_2)$ (% static)") +# ax_3.set_ylabel(r"$\Delta mean(\|~\|_2)$ (% static)") +# ax_4.set_ylabel(r"$\Delta std(\|~\|_2)$ (% static)") +# elif metric == "cosine": +# ax_1.set_ylabel(r"$\Delta mean(cosine)$ (% static)") +# ax_2.set_ylabel(r"$\Delta std(cosine)$ (% static)") +# ax_3.set_ylabel(r"$\Delta mean(cosine)$ (% static)") +# ax_4.set_ylabel(r"$\Delta std(cosine)$ (% static)") +# ax_3.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) +# ax_4.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) +# xmin, xmax = ax_3.get_xlim() +# ax_3.plot([xmin, xmax], [1, 1], "k--") +# ax_4.plot([xmin, xmax], [1, 1], "k--") +# ax_1.set_xticks([]) +# ax_2.set_xlabel("depth (um)") + +# xmin, xmax = ax_1.get_xlim() +# ax_1.plot([xmin, xmax], [1, 1], "k--") +# ax_2.plot([xmin, xmax], [1, 1], "k--") +# plt.tight_layout() + + +# def plot_snr_decrease(benchmarks, figsize=(15, 10)): +# fig, axes = plt.subplots(2, 2, figsize=figsize, squeeze=False) + +# recgen = mr.load_recordings(benchmarks[0].mearec_filenames["static"]) +# nb_templates, nb_versions, _ = recgen.template_locations.shape +# template_positions = recgen.template_locations[:, nb_versions // 2, 1:3] +# distances_to_center = template_positions[:, 1] +# idx = np.argsort(distances_to_center) +# _simpleaxis(axes[0, 0]) + +# snr_static = compute_quality_metrics(benchmarks[0].waveforms["static"], metric_names=["snr"], load_if_exists=True) +# snr_drifting = compute_quality_metrics( +# benchmarks[0].waveforms["drifting"], metric_names=["snr"], load_if_exists=True +# ) + +# m = np.max(snr_static) +# axes[0, 0].scatter(snr_static.values, snr_drifting.values, c="0.5") +# axes[0, 0].plot([0, m], [0, m], color="k") + +# axes[0, 0].set_ylabel("units SNR for drifting") +# _simpleaxis(axes[0, 0]) + +# axes[0, 1].plot(distances_to_center[idx], (snr_drifting.values / snr_static.values)[idx], c="0.5") +# axes[0, 1].plot(distances_to_center[idx], np.ones(len(idx)), "k--") +# _simpleaxis(axes[0, 1]) +# axes[0, 1].set_xticks([]) +# axes[0, 0].set_xticks([]) + +# for count, bench in enumerate(benchmarks): +# snr_corrected = compute_quality_metrics(bench.waveforms["corrected"], metric_names=["snr"], load_if_exists=True) +# axes[1, 0].scatter(snr_static.values, snr_corrected.values, label=bench.title) +# axes[1, 0].plot([0, m], [0, m], color="k") + +# axes[1, 1].plot(distances_to_center[idx], (snr_corrected.values / snr_static.values)[idx], c=f"C{count}") + +# axes[1, 0].set_xlabel("units SNR for static") +# axes[1, 0].set_ylabel("units SNR for corrected") +# axes[1, 1].plot(distances_to_center[idx], np.ones(len(idx)), "k--") +# axes[1, 0].legend() +# _simpleaxis(axes[1, 0]) +# _simpleaxis(axes[1, 1]) +# axes[1, 1].set_ylabel(r"$\Delta(SNR)$") +# axes[0, 1].set_ylabel(r"$\Delta(SNR)$") + +# axes[1, 1].set_xlabel("depth (um)") + + +# def plot_residuals_comparisons(benchmarks): +# fig, axes = plt.subplots(1, 3, figsize=(15, 5)) +# for count, bench in enumerate(benchmarks): +# residuals, (t_start, t_stop) = bench.compute_residuals(force=False) +# time_axis = np.arange(t_start, t_stop) +# axes[0].plot(time_axis, residuals["corrected"].mean(0), label=bench.title) +# axes[0].legend() +# axes[0].set_xlabel("time (s)") +# axes[0].set_ylabel(r"$|S_{corrected} - S_{static}|$") +# _simpleaxis(axes[0]) + +# channel_positions = benchmarks[0].recordings["static"].get_channel_locations() +# distances_to_center = channel_positions[:, 1] +# idx = np.argsort(distances_to_center) + +# for count, bench in enumerate(benchmarks): +# residuals, (t_start, t_stop) = bench.compute_residuals(force=False) +# time_axis = np.arange(t_start, t_stop) +# axes[1].plot( +# distances_to_center[idx], residuals["corrected"].mean(1)[idx], label=bench.title, lw=2, c=f"C{count}" +# ) +# axes[1].fill_between( +# distances_to_center[idx], +# residuals["corrected"].mean(1)[idx] - residuals["corrected"].std(1)[idx], +# residuals["corrected"].mean(1)[idx] + residuals["corrected"].std(1)[idx], +# color=f"C{count}", +# alpha=0.25, +# ) +# axes[1].set_xlabel("depth (um)") +# _simpleaxis(axes[1]) + +# for count, bench in enumerate(benchmarks): +# residuals, (t_start, t_stop) = bench.compute_residuals(force=False) +# axes[2].bar([count], [residuals["corrected"].mean()], yerr=[residuals["corrected"].std()], color=f"C{count}") + +# _simpleaxis(axes[2]) +# axes[2].set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) + + +# from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment + + +# class ResidualRecording(BasePreprocessor): +# name = "residual_recording" + +# def __init__(self, recording_1, recording_2): +# assert recording_1.get_num_segments() == recording_2.get_num_segments() +# BasePreprocessor.__init__(self, recording_1) + +# for parent_recording_segment_1, parent_recording_segment_2 in zip( +# recording_1._recording_segments, recording_2._recording_segments +# ): +# rec_segment = DifferenceRecordingSegment(parent_recording_segment_1, parent_recording_segment_2) +# self.add_recording_segment(rec_segment) + +# self._kwargs = dict(recording_1=recording_1, recording_2=recording_2) + + +# class DifferenceRecordingSegment(BasePreprocessorSegment): +# def __init__(self, parent_recording_segment_1, parent_recording_segment_2): +# BasePreprocessorSegment.__init__(self, parent_recording_segment_1) +# self.parent_recording_segment_1 = parent_recording_segment_1 +# self.parent_recording_segment_2 = parent_recording_segment_2 + +# def get_traces(self, start_frame, end_frame, channel_indices): +# traces_1 = self.parent_recording_segment_1.get_traces(start_frame, end_frame, channel_indices) +# traces_2 = self.parent_recording_segment_2.get_traces(start_frame, end_frame, channel_indices) - return traces_2 - traces_1 +# return traces_2 - traces_1 -colors = {"static": "C0", "drifting": "C1", "corrected": "C2"} +# colors = {"static": "C0", "drifting": "C1", "corrected": "C2"} From c17d694ce61c69919f82b04746b97e70bba80dd2 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:35:21 -0500 Subject: [PATCH 121/192] make sure `compute` includes return value --- doc/modules/postprocessing.rst | 14 +++++++------- doc/modules/qualitymetrics.rst | 10 ++++++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 1b0162aa36..8fbbaf4d86 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -86,7 +86,7 @@ For dense waveforms, sparsity can also be passed as an argument. .. code-block:: python - sorting_analyzer.compute(input="principal_components", + pc = sorting_analyzer.compute(input="principal_components", n_components=3, mode="by_channel_local") @@ -103,7 +103,7 @@ and is not well suited for high-density probes. .. code-block:: python - sorting_analyzer.compute(input="template_similarity", method='cosine_similarity') + similarity = sorting_analyzer.compute(input="template_similarity", method='cosine_similarity') For more information, see :py:func:`~spikeinterface.postprocessing.compute_template_similarity` @@ -121,7 +121,7 @@ each spike. .. code-block:: python - sorting_analyzer.compute(input="spike_amplitudes", + amplitudes = sorting_analyzer.compute(input="spike_amplitudes", peak_sign="neg", outputs="concatenated") @@ -141,7 +141,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), .. code-block:: python - sorting_analyzer.compute(input="spike_locations", + spike_locations = sorting_analyzer.compute(input="spike_locations", ms_before=0.5, ms_after=0.5, spike_retriever_kwargs=dict( @@ -166,7 +166,7 @@ based on individual waveforms, it calculates at the unit level using templates. .. code-block:: python - sorting_analyzer.compute(input="unit_locations", method="monopolar_triangulation") + unit_locations = sorting_analyzer.compute(input="unit_locations", method="monopolar_triangulation") For more information, see :py:func:`~spikeinterface.postprocessing.compute_unit_locations` @@ -209,7 +209,7 @@ with shape (num_units, num_units, num_bins) with all correlograms for each pair .. code-block:: python - sorting_analyer.compute(input="correlograms", + ccg = sorting_analyzer.compute(input="correlograms", window_ms=50.0, bin_ms=1.0, method="auto") @@ -226,7 +226,7 @@ This extension computes the histograms of inter-spike-intervals. The computed ou .. code-block:: python - sorting_analyer.compute(input="isi_histograms" + isi = sorting_analyer.compute(input="isi_histograms" window_ms=50.0, bin_ms=1.0, method="auto") diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 42db0e645f..d3eb26e35a 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -51,16 +51,18 @@ This code snippet shows how to compute quality metrics (with or without principa sorting_analyzer = si.load_sorting_analyzer(folder='waveforms') # start from a sorting_analyzer # without PC (depends on "waveforms", "templates", and "noise_levels") - sorting_analyzer.compute(input="quality_metrics", metric_names=['snr'], skip_pc_metrics=False) - metrics = sorting_analyzer.get_extension(extension_name="quality_metrics") + qm_ext = sorting_analyzer.compute(input="quality_metrics", metric_names=['snr'], skip_pc_metrics=True) + metrics = qm_ext.get_data() assert 'snr' in metrics.columns # with PCs (depends on "pca" in addition to the above metrics) - sorting_analyzer.compute(input={"pca": dict(n_components=5, mode="by_channel_local"), + qm_ext = sorting_analyzer.compute(input={"pca": dict(n_components=5, mode="by_channel_local"), "quality_metrics": dict(skip_pc_metrics=False)}) + metrics = qm_ext.get_data() + assert 'isolation_distance' in metrics.columns - + For more information about quality metrics, check out this excellent `documentation `_ From 3a63faf57f543bf8b0b6f43f30536055550820eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Feb 2024 18:37:05 +0000 Subject: [PATCH 122/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/modules/qualitymetrics.rst | 2 +- src/spikeinterface/postprocessing/isi.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index d3eb26e35a..f5f3581c31 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -62,7 +62,7 @@ This code snippet shows how to compute quality metrics (with or without principa metrics = qm_ext.get_data() assert 'isolation_distance' in metrics.columns - + For more information about quality metrics, check out this excellent `documentation `_ diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 2b39a376e0..a99a677d65 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -103,7 +103,7 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - bins = np.arange(0, window_size + bin_size, bin_size)# * 1e3 / fs + bins = np.arange(0, window_size + bin_size, bin_size) # * 1e3 / fs ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) # TODO: There might be a better way than a double for loop? @@ -137,7 +137,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - bins = np.arange(0, window_size + bin_size, bin_size)# * 1e3 / fs + bins = np.arange(0, window_size + bin_size, bin_size) # * 1e3 / fs spikes = sorting.to_spike_vector(concatenated=False) ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) @@ -159,7 +159,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float if HAVE_NUMBA: @numba.jit( - (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.int64[::1]), + (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.int64[::1]), nopython=True, nogil=True, cache=True, From 0fc81fc674ff291efa1b78b961e707a5131a8d8e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 22 Feb 2024 09:10:48 +0100 Subject: [PATCH 123/192] oups --- .../sortingcomponents/benchmark/benchmark_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 6a95bb4b0b..1bdf53d414 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -272,10 +272,10 @@ def _save_keys(self, saved_keys, folder): raise ValueError(f"Save error {k} {format}") def save_run(self, folder): - self._save_keys(self._run_key_saved) + self._save_keys(self._run_key_saved, folder) def save_result(self, folder): - self._save_keys(self._result_key_saved) + self._save_keys(self._result_key_saved, folder) @classmethod def load_folder(cls, folder): From 2fc311c63b2c387738d4157c3b1c61c5ebc77338 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 22 Feb 2024 09:17:44 +0100 Subject: [PATCH 124/192] WIP --- .../benchmark/benchmark_peak_selection.py | 1269 +++++++++-------- 1 file changed, 679 insertions(+), 590 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index c5844a9c88..16d5725b82 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -28,609 +28,698 @@ import os import numpy as np +from .benchmark_tools import BenchmarkStudy, Benchmark +from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -class BenchmarkPeakSelection: - def __init__(self, recording, gt_sorting, exhaustive_gt=True, job_kwargs={}, tmp_folder=None, verbose=True): - self.verbose = verbose + +class PeakSelectionBenchmark(Benchmark): + + def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.recording = recording self.gt_sorting = gt_sorting - self.job_kwargs = job_kwargs + self.indices = indices + + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates') + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + + peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + if self.indices is None: + self.indices = np.arange(len(peaks)) + self.peaks = peaks[self.indices] + self.params = params self.exhaustive_gt = exhaustive_gt - self.sampling_rate = self.recording.get_sampling_frequency() + self.method = params['method'] + self.method_kwargs = params['method_kwargs'] + self.result = {} + + def run(self, **job_kwargs): + labels, peak_labels = find_cluster_from_peaks( + self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs + ) + self.result['peak_labels'] = peak_labels - self.tmp_folder = tmp_folder - if self.tmp_folder is None: - self.tmp_folder = os.path.join(".", "".join(random.choices(string.ascii_uppercase + string.digits, k=8))) + def compute_result(self, **result_params): + self.noise = self.result['peak_labels'] < 0 - self._peaks = None - self._positions = None - self._gt_positions = None - self.gt_peaks = None + spikes = self.gt_sorting.to_spike_vector() + self.result['sliced_gt_sorting'] = NumpySorting(spikes[self.indices], + self.recording.sampling_frequency, + self.gt_sorting.unit_ids) - self.waveforms = {} - self.pcas = {} - self.templates = {} + data = spikes[self.indices][~self.noise] + data["unit_index"] = self.result['peak_labels'][~self.noise] - def __del__(self): - import shutil + self.result['clustering'] = NumpySorting.from_times_labels(data["sample_index"], + self.result['peak_labels'][~self.noise], + self.recording.sampling_frequency) + + self.result['gt_comparison'] = GroundTruthComparison(self.result['sliced_gt_sorting'], + self.result['clustering'], + exhaustive_gt=self.exhaustive_gt) - shutil.rmtree(self.tmp_folder) + sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates') + self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") - def set_peaks(self, peaks): - self._peaks = peaks + sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) + sorting_analyzer.select_random_spikes() + ext = sorting_analyzer.compute('fast_templates') + self.result['clustering_templates'] = ext.get_data(outputs="Templates") - def set_positions(self, positions): - self._positions = positions + _run_key_saved = [ + ("peak_labels", "npy") + ] - @property - def peaks(self): - if self._peaks is None: - self.detect_peaks() - return self._peaks + _result_key_saved = [ + ("gt_comparison", "pickle"), + ("sliced_gt_sorting", "sorting"), + ("clustering", "sorting"), + ("sliced_gt_templates", "zarr_templates"), + ("clustering_templates", "zarr_templates") + ] - @property - def positions(self): - if self._positions is None: - self.localize_peaks() - return self._positions - @property - def gt_positions(self): - if self._gt_positions is None: - self.localize_gt_peaks() - return self._gt_positions +class PeakSelectionStudy(BenchmarkStudy): - def detect_peaks(self, method_kwargs={"method": "locally_exclusive"}): - from spikeinterface.sortingcomponents.peak_detection import detect_peaks + benchmark_class = PeakSelectionBenchmark - if self.verbose: - method = method_kwargs["method"] - print(f"Detecting peaks with method {method}") - self._peaks = detect_peaks(self.recording, **method_kwargs, **self.job_kwargs) - - def localize_peaks(self, method_kwargs={"method": "center_of_mass"}): - from spikeinterface.sortingcomponents.peak_localization import localize_peaks - - if self.verbose: - method = method_kwargs["method"] - print(f"Localizing peaks with method {method}") - self._positions = localize_peaks(self.recording, self.peaks, **method_kwargs, **self.job_kwargs) - - def localize_gt_peaks(self, method_kwargs={"method": "center_of_mass"}): - from spikeinterface.sortingcomponents.peak_localization import localize_peaks - - if self.verbose: - method = method_kwargs["method"] - print(f"Localizing gt peaks with method {method}") - self._gt_positions = localize_peaks(self.recording, self.gt_peaks, **method_kwargs, **self.job_kwargs) - - def run(self, peaks=None, positions=None, delta=0.2): - t_start = time.time() - - if peaks is not None: - self._peaks = peaks - - nb_peaks = len(self.peaks) - - if positions is not None: - self._positions = positions - - spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0]["sample_index"] - times2 = self.peaks["sample_index"] - - print("The gt recording has {} peaks and {} have been detected".format(len(times1[0]), len(times2))) - - matches = make_matching_events(spikes1["sample_index"], times2, int(delta * self.sampling_rate / 1000)) - self.matches = matches - - self.deltas = {"labels": [], "delta": matches["delta_frame"]} - self.deltas["labels"] = spikes1["unit_index"][matches["index1"]] - - gt_matches = matches["index1"] - self.sliced_gt_sorting = NumpySorting(spikes1[gt_matches], self.sampling_rate, self.gt_sorting.unit_ids) - - ratio = 100 * len(gt_matches) / len(spikes1) - print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) - - matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) - self.good_matches = matches["index1"] - - garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) - garbage_channels = self.peaks["channel_index"][garbage_matches] - garbage_peaks = times2[garbage_matches] - nb_garbage = len(garbage_peaks) - - ratio = 100 * len(garbage_peaks) / len(times2) - self.garbage_sorting = NumpySorting.from_times_labels(garbage_peaks, garbage_channels, self.sampling_rate) - - print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) - - self.comp = GroundTruthComparison(self.gt_sorting, self.sliced_gt_sorting, exhaustive_gt=self.exhaustive_gt) - - for label, sorting in zip( - ["gt", "full_gt", "garbage"], [self.sliced_gt_sorting, self.gt_sorting, self.garbage_sorting] - ): - tmp_folder = os.path.join(self.tmp_folder, label) - if os.path.exists(tmp_folder): - import shutil - - shutil.rmtree(tmp_folder) - - if not (label == "full_gt" and label in self.waveforms): - if self.verbose: - print(f"Extracting waveforms for {label}") - - self.waveforms[label] = extract_waveforms( - self.recording, - sorting, - tmp_folder, - load_if_exists=True, - ms_before=2.5, - ms_after=3.5, - max_spikes_per_unit=500, - return_scaled=False, - **self.job_kwargs, - ) - - self.templates[label] = self.waveforms[label].get_all_templates(mode="median") - - if self.gt_peaks is None: - if self.verbose: - print("Computing gt peaks") - gt_peaks_ = self.gt_sorting.to_spike_vector() - self.gt_peaks = np.zeros( - gt_peaks_.size, - dtype=[ - ("sample_index", " Date: Thu, 22 Feb 2024 09:21:32 +0100 Subject: [PATCH 125/192] WIP --- .../benchmark/benchmark_clustering.py | 6 ------ .../benchmark/benchmark_matching.py | 11 ----------- .../benchmark/benchmark_peak_localization.py | 12 +----------- .../benchmark/benchmark_peak_selection.py | 1 - 4 files changed, 1 insertion(+), 29 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index f1c04c23d9..4a418037b3 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -1,12 +1,7 @@ from __future__ import annotations -from spikeinterface.core import extract_waveforms -from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks, clustering_methods -from spikeinterface.preprocessing import bandpass_filter, common_reference from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks -from spikeinterface.extractors import read_mearec from spikeinterface.core import NumpySorting -from spikeinterface.qualitymetrics import compute_quality_metrics from spikeinterface.comparison import GroundTruthComparison from spikeinterface.widgets import ( plot_probe_map, @@ -15,7 +10,6 @@ plot_unit_templates, plot_unit_waveforms, ) -from spikeinterface.postprocessing import compute_principal_components from spikeinterface.comparison.comparisontools import make_matching_events #from spikeinterface.postprocessing import get_template_extremum_channel from spikeinterface.core import get_noise_levels diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 15c637a7cb..23a9f0459c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -1,31 +1,20 @@ from __future__ import annotations -from spikeinterface.preprocessing import bandpass_filter, common_reference from spikeinterface.postprocessing import compute_template_similarity from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.core.template import Templates from spikeinterface.core import NumpySorting -from spikeinterface.qualitymetrics import compute_quality_metrics -from spikeinterface import load_extractor from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth from spikeinterface.widgets import ( plot_agreement_matrix, plot_comparison_collision_by_similarity, - plot_unit_waveforms, ) -import time -import os -import pickle from pathlib import Path -import string, random import pylab as plt import matplotlib.patches as mpatches import numpy as np import pandas as pd -import shutil -import copy -from tqdm.auto import tqdm from .benchmark_tools import BenchmarkStudy, Benchmark from spikeinterface.core.basesorting import minimum_spike_dtype diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index e818d418fa..a4bd6283a0 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -1,11 +1,8 @@ from __future__ import annotations -from spikeinterface.core import extract_waveforms -from spikeinterface.preprocessing import bandpass_filter, common_reference -from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.core import NumpySorting -from spikeinterface.qualitymetrics import compute_quality_metrics, compute_snrs +from spikeinterface.qualitymetrics import compute_snrs from spikeinterface.widgets import ( plot_probe_map, plot_agreement_matrix, @@ -13,19 +10,12 @@ plot_unit_templates, plot_unit_waveforms, ) -from spikeinterface.postprocessing import compute_spike_locations from spikeinterface.postprocessing.unit_localization import compute_center_of_mass, compute_monopolar_triangulation, compute_grid_convolution from spikeinterface.core import get_noise_levels -import time -import string, random import pylab as plt -from spikeinterface.core.template import Templates -import os import numpy as np -import pickle from .benchmark_tools import BenchmarkStudy, Benchmark -from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sortinganalyzer import create_sorting_analyzer diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 16d5725b82..0e9f2b9052 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -1,6 +1,5 @@ from __future__ import annotations -from spikeinterface.core import extract_waveforms from spikeinterface.preprocessing import bandpass_filter, common_reference from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.core import NumpySorting From 898f8533ac621f2d047a7d2a4474834d1ec978bd Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 22 Feb 2024 09:59:03 +0100 Subject: [PATCH 126/192] Fix --- .../benchmark/benchmark_peak_localization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index a4bd6283a0..6f75a57a78 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -165,7 +165,7 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): class UnitLocalizationBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, gt_positions, params): + def __init__(self, recording, gt_sorting, params, gt_positions): self.recording = recording self.gt_sorting = gt_sorting self.gt_positions = gt_positions @@ -220,9 +220,9 @@ class UnitLocalizationStudy(BenchmarkStudy): def create_benchmark(self, key): dataset_key = self.cases[key]["dataset"] recording, gt_sorting = self.datasets[dataset_key] - gt_positions = self.cases[key]["gt_positions"] + init_kwargs = self.cases[key]["init_kwargs"] params = self.cases[key]["params"] - benchmark = UnitLocalizationBenchmark(recording, gt_sorting, gt_positions, params) + benchmark = UnitLocalizationBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark def plot_template_errors(self, case_keys=None): From a46c1cacaa82294e43234a0429b1ced05c37c8a7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 22 Feb 2024 20:07:13 +0100 Subject: [PATCH 127/192] Make return_scaled=True more easy when float32 and no gain_to_uV property --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- src/spikeinterface/core/baserecording.py | 13 +++++++++---- .../postprocessing/spike_amplitudes.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index f6d7399c4e..d952c93f75 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -98,7 +98,7 @@ def _set_params( if return_scaled: # check if has scaled values: - if not recording.has_scaled(): + if not recording.has_scaled() and recording.get_dtype().kind == 'i': print("Setting 'return_scaled' to False") return_scaled = False diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b65409e033..b834cbac96 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -310,10 +310,15 @@ def get_traces( warnings.warn(message) if not self.has_scaled(): - raise ValueError( - "This recording does not support return_scaled=True (need gain_to_uV and offset_" - "to_uV properties)" - ) + if self._dtype.kind == 'f': + # here we do not truely have scale but we assume this is scaled + # this helps a lot for simulated data + pass + else: + raise ValueError( + "This recording does not support return_scaled=True (need gain_to_uV and offset_" + "to_uV properties)" + ) else: gains = self.get_property("gain_to_uV") offsets = self.get_property("offset_to_uV") diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 1899951dc5..68894e3646 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -98,7 +98,7 @@ def _get_pipeline_nodes(self): if return_scaled: # check if has scaled values: - if not recording.has_scaled_traces(): + if not recording.has_scaled_traces() and recording.get_dtype().kind == 'i': warnings.warn("Recording doesn't have scaled traces! Setting 'return_scaled' to False") return_scaled = False From 72418347f9b1aace126ac07dc6dbc0625d7932bd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 22 Feb 2024 20:07:48 +0100 Subject: [PATCH 128/192] more clena in MotionInterpolationBenchmark --- src/spikeinterface/core/sortinganalyzer.py | 3 +- .../benchmark_motion_interpolation.py | 775 ++++-------------- 2 files changed, 153 insertions(+), 625 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 5eaa165850..67a8e673e8 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -132,7 +132,6 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto"): The loaded SortingAnalyzer """ - return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format) @@ -577,12 +576,14 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> elif format == "binary_folder": # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" + folder = Path(folder) SortingAnalyzer.create_binary_folder(folder, sorting_provenance, recording, sparsity, self.rec_attributes) new_sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) new_sorting_analyzer.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" + folder = Path(folder) SortingAnalyzer.create_zarr(folder, sorting_provenance, recording, sparsity, self.rec_attributes) new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) new_sorting_analyzer.folder = folder diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index a3b3b4172f..d2b83f181a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -7,25 +7,18 @@ import shutil -# from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten from spikeinterface.sorters import run_sorter, read_sorter_folder -# from spikeinterface.comparison import GroundTruthComparison +from spikeinterface.comparison import GroundTruthComparison from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording -# from spikeinterface.qualitymetrics import compute_quality_metrics -# from spikeinterface.curation import MergeUnitsSorting -# from spikeinterface.core import get_template_extremum_channel +from spikeinterface.curation import MergeUnitsSorting -# import sklearn - from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis import matplotlib.pyplot as plt -# import MEArec as mr - class MotionInterpolationBenchmark(Benchmark): @@ -71,14 +64,41 @@ def run(self, **job_kwargs): self.result["sorting"] = sorting - def compute_result(self, **result_params): - pass + def compute_result(self, exhaustive_gt=True, merging_score=0.2): + sorting = self.result["sorting"] # self.result[""] = + comparison = GroundTruthComparison(self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt) + self.result["comparison"] = comparison + self.result["accuracy"] = comparison.get_performance()["accuracy"].values.astype("float32") + + + gt_unit_ids = self.gt_sorting.unit_ids + unit_ids = sorting.unit_ids + + # find best merges + scores = comparison.agreement_scores + to_merge = [] + for gt_unit_id in gt_unit_ids: + (inds,) = np.nonzero(scores.loc[gt_unit_id, :].values > merging_score) + merge_ids = unit_ids[inds] + if merge_ids.size > 1: + to_merge.append(list(merge_ids)) + + merged_sporting = MergeUnitsSorting(sorting, to_merge) + comparison_merged = GroundTruthComparison(self.gt_sorting, merged_sporting, exhaustive_gt=True) + + self.result["comparison_merged"] = comparison_merged + self.result["accuracy_merged"] = comparison_merged.get_performance()["accuracy"].values.astype("float32") _run_key_saved = [ + ("sorting", "sorting"), ] _result_key_saved = [ + ("comparison", "pickle"), + ("accuracy", "npy"), + ("comparison_merged", "pickle"), + ("accuracy_merged", "npy"), ] @@ -98,619 +118,126 @@ def create_benchmark(self, key): return benchmark + def plot_sorting_accuracy(self, case_keys=None, mode="ordered_accuracy", legend=True, colors=None, + mode_best_merge=False, figsize=(10, 5), ax=None, axes=None): + if case_keys is None: + case_keys = list(self.cases.keys()) -# class BenchmarkMotionInterpolationMearec(BenchmarkBase): -# _array_names = ("gt_motion", "estimated_motion", "temporal_bins", "spatial_bins") -# _waveform_names = ("static", "drifting", "corrected_gt", "corrected_estimated") -# _sorting_names = () - -# _array_names_from_parent = () -# _waveform_names_from_parent = ("static", "drifting") -# _sorting_names_from_parent = ("static", "drifting") - -# def __init__( -# self, -# mearec_filename_drifting, -# mearec_filename_static, -# gt_motion, -# estimated_motion, -# temporal_bins, -# spatial_bins, -# do_preprocessing=True, -# correct_motion_kwargs={}, -# waveforms_kwargs=dict( -# ms_before=1.0, -# ms_after=3.0, -# max_spikes_per_unit=500, -# ), -# sparse_kwargs=dict( -# method="radius", -# radius_um=100.0, -# ), -# sorter_cases={}, -# folder=None, -# title="", -# job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, -# overwrite=False, -# delete_output_folder=True, -# parent_benchmark=None, -# ): -# BenchmarkBase.__init__( -# self, -# folder=folder, -# title=title, -# overwrite=overwrite, -# job_kwargs=job_kwargs, -# parent_benchmark=parent_benchmark, -# ) - -# self._args.extend([str(mearec_filename_drifting), str(mearec_filename_static), None, None, None, None]) - -# self.sorter_cases = sorter_cases.copy() -# self.mearec_filenames = {} -# self.keys = ["static", "drifting", "corrected_gt", "corrected_estimated"] -# self.mearec_filenames["drifting"] = mearec_filename_drifting -# self.mearec_filenames["static"] = mearec_filename_static - -# self.temporal_bins = temporal_bins -# self.spatial_bins = spatial_bins -# self.gt_motion = gt_motion -# self.estimated_motion = estimated_motion -# self.do_preprocessing = do_preprocessing -# self.delete_output_folder = delete_output_folder - -# self._recordings = None -# _, self.sorting_gt = read_mearec(self.mearec_filenames["static"]) - -# self.correct_motion_kwargs = correct_motion_kwargs.copy() -# self.sparse_kwargs = sparse_kwargs.copy() -# self.waveforms_kwargs = waveforms_kwargs.copy() -# self.comparisons = {} -# self.accuracies = {} - -# self._kwargs.update( -# dict( -# correct_motion_kwargs=self.correct_motion_kwargs, -# sorter_cases=self.sorter_cases, -# do_preprocessing=do_preprocessing, -# delete_output_folder=delete_output_folder, -# waveforms_kwargs=waveforms_kwargs, -# sparse_kwargs=sparse_kwargs, -# ) -# ) - -# @property -# def recordings(self): -# if self._recordings is None: -# self._recordings = {} - -# for key in ( -# "drifting", -# "static", -# ): -# rec, _ = read_mearec(self.mearec_filenames[key]) -# self._recordings["raw_" + key] = rec - -# if self.do_preprocessing: -# # this processing chain is the same as the kilosort2.5 -# # this is important if we want to skip the kilosort preprocessing -# # * all computation are done in float32 -# # * 150um is more or less 30 channels for the whittening -# # * the lastet gain step is super important it is what KS2.5 is doing because the whiten traces -# # have magnitude around 1 so a factor (200) is needed to go back to int16 -# rec = common_reference(rec, dtype="float32") -# rec = highpass_filter(rec, freq_min=150.0) -# rec = whiten(rec, mode="local", radius_um=150.0, num_chunks_per_segment=40, chunk_size=32000) -# rec = scale(rec, gain=200, dtype="int16") -# self._recordings[key] = rec - -# rec = self._recordings["drifting"] -# self._recordings["corrected_gt"] = InterpolateMotionRecording( -# rec, self.gt_motion, self.temporal_bins, self.spatial_bins, **self.correct_motion_kwargs -# ) - -# self._recordings["corrected_estimated"] = InterpolateMotionRecording( -# rec, self.estimated_motion, self.temporal_bins, self.spatial_bins, **self.correct_motion_kwargs -# ) - -# return self._recordings - -# def run(self): -# self.extract_waveforms() -# self.save_to_folder() -# self.run_sorters() -# self.save_to_folder() - -# def extract_waveforms(self): -# # the sparsity is estimated on the static recording and propagated to all of then -# if self.parent_benchmark is None: -# wf_kwargs = self.waveforms_kwargs.copy() -# wf_kwargs.pop("max_spikes_per_unit", None) -# sparsity = precompute_sparsity( -# self.recordings["static"], -# self.sorting_gt, -# num_spikes_for_sparsity=200.0, -# unit_batch_size=10000, -# **wf_kwargs, -# **self.sparse_kwargs, -# **self.job_kwargs, -# ) -# else: -# sparsity = self.waveforms["static"].sparsity - -# for key in self.keys: -# if self.parent_benchmark is not None and key in self._waveform_names_from_parent: -# continue - -# waveforms_folder = self.folder / "waveforms" / key -# we = WaveformExtractor.create( -# self.recordings[key], -# self.sorting_gt, -# waveforms_folder, -# mode="folder", -# sparsity=sparsity, -# remove_if_exists=True, -# ) -# we.set_params(**self.waveforms_kwargs, return_scaled=True) -# we.run_extract_waveforms(seed=22051977, **self.job_kwargs) -# self.waveforms[key] = we - -# def run_sorters(self, skip_already_done=True): -# for case in self.sorter_cases: -# label = case["label"] -# print("run sorter", label) -# sorter_name = case["sorter_name"] -# sorter_params = case["sorter_params"] -# recording = self.recordings[case["recording"]] -# output_folder = self.folder / f"tmp_sortings_{label}" -# if output_folder.exists() and skip_already_done: -# print("already done") -# sorting = read_sorter_folder(output_folder) -# else: -# sorting = run_sorter( -# sorter_name, -# recording, -# output_folder, -# **sorter_params, -# delete_output_folder=self.delete_output_folder, -# ) -# self.sortings[label] = sorting - -# def compute_distances_to_static(self, force=False): -# if hasattr(self, "distances") and not force: -# return self.distances - -# self.distances = {} - -# n = len(self.waveforms["static"].unit_ids) - -# sparsity = self.waveforms["static"].sparsity - -# ref_templates = self.waveforms["static"].get_all_templates() - -# for key in self.keys: -# if self.parent_benchmark is not None and key in ("drifting", "static"): -# continue - -# print(key) -# dist = self.distances[key] = { -# "std": np.zeros(n), -# "norm_std": np.zeros(n), -# "template_norm_distance": np.zeros(n), -# "template_cosine": np.zeros(n), -# } - -# templates = self.waveforms[key].get_all_templates() - -# extremum_channel = get_template_extremum_channel(self.waveforms["static"], outputs="index") - -# for unit_ind, unit_id in enumerate(self.waveforms[key].sorting.unit_ids): -# mask = sparsity.mask[unit_ind, :] -# ref_template = ref_templates[unit_ind][:, mask] -# template = templates[unit_ind][:, mask] - -# max_chan = extremum_channel[unit_id] -# max_chan - -# max_chan_sparse = list(np.nonzero(mask)[0]).index(max_chan) - -# # this is already sparse -# wfs = self.waveforms[key].get_waveforms(unit_id) -# ref_wfs = self.waveforms["static"].get_waveforms(unit_id) - -# rms = np.sqrt(np.mean(template**2)) -# ref_rms = np.sqrt(np.mean(ref_template**2)) -# if rms == 0: -# print(key, unit_id, unit_ind, rms, ref_rms) - -# dist["std"][unit_ind] = np.mean(np.std(wfs, axis=0), axis=(0, 1)) -# dist["norm_std"][unit_ind] = np.mean(np.std(wfs, axis=0), axis=(0, 1)) / rms -# dist["template_norm_distance"][unit_ind] = np.sum((ref_template - template) ** 2) / ref_rms -# dist["template_cosine"][unit_ind] = sklearn.metrics.pairwise.cosine_similarity( -# ref_template.reshape(1, -1), template.reshape(1, -1) -# )[0] - -# return self.distances - -# def compute_residuals(self, force=True): -# fr = int(self.recordings["static"].get_sampling_frequency()) -# duration = int(self.recordings["static"].get_total_duration()) - -# t_start = 0 -# t_stop = duration - -# if hasattr(self, "residuals") and not force: -# return self.residuals, (t_start, t_stop) - -# self.residuals = {} - -# for key in ["corrected"]: -# difference = ResidualRecording(self.recordings["static"], self.recordings[key]) -# self.residuals[key] = np.zeros((self.recordings["static"].get_num_channels(), 0)) - -# for i in np.arange(t_start * fr, t_stop * fr, fr): -# data = np.linalg.norm(difference.get_traces(start_frame=i, end_frame=i + fr), axis=0) / np.sqrt(fr) -# self.residuals[key] = np.hstack((self.residuals[key], data[:, np.newaxis])) - -# return self.residuals, (t_start, t_stop) - -# def compute_accuracies(self): -# for case in self.sorter_cases: -# label = case["label"] -# sorting = self.sortings[label] -# if label not in self.comparisons: -# comp = GroundTruthComparison(self.sorting_gt, sorting, exhaustive_gt=True) -# self.comparisons[label] = comp -# self.accuracies[label] = comp.get_performance()["accuracy"].values - -# def _plot_accuracy( -# self, accuracies, mode="ordered_accuracy", figsize=(15, 5), axes=None, ax=None, ls="-", legend=True, colors=None -# ): -# if len(self.accuracies) != len(self.sorter_cases): -# self.compute_accuracies() - -# n = len(self.sorter_cases) - -# if "depth" in mode: -# # gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filenames['drifting'], time_vector=np.array([0., 1.])) -# # unit_depth = gt_unit_positions[0, :] - -# template_locations = np.array(mr.load_recordings(self.mearec_filenames["drifting"]).template_locations) -# assert len(template_locations.shape) == 3 -# mid = template_locations.shape[1] // 2 -# unit_depth = template_locations[:, mid, 2] - -# chan_locations = self.recordings["drifting"].get_channel_locations() - -# if mode == "ordered_accuracy": -# if ax is None: -# fig, ax = plt.subplots(figsize=figsize) -# else: -# fig = ax.figure - -# order = None -# for i, case in enumerate(self.sorter_cases): -# color = colors[i] if colors is not None else None -# label = case["label"] -# # comp = self.comparisons[label] -# acc = accuracies[label] -# order = np.argsort(acc)[::-1] -# acc = acc[order] -# ax.plot(acc, label=label, ls=ls, color=color) -# if legend: -# ax.legend() -# ax.set_ylabel("accuracy") -# ax.set_xlabel("units ordered by accuracy") - -# elif mode == "depth_snr": -# if axes is None: -# fig, axs = plt.subplots(nrows=n, figsize=figsize, sharey=True, sharex=True) -# else: -# fig = axes[0].figure -# axs = axes - -# metrics = compute_quality_metrics(self.waveforms["static"], metric_names=["snr"], load_if_exists=True) -# snr = metrics["snr"].values - -# for i, case in enumerate(self.sorter_cases): -# ax = axs[i] -# label = case["label"] -# acc = accuracies[label] - -# points = ax.scatter(unit_depth, snr, c=acc) -# points.set_clim(0.0, 1.0) -# ax.set_title(label) -# ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") -# ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") -# ax.set_ylabel("snr") -# ax.set_xlabel("depth") - -# cbar = fig.colorbar(points, ax=axs[:], location="right", shrink=0.6) -# cbar.ax.set_ylabel("accuracy") - -# elif mode == "snr": -# fig, ax = plt.subplots(figsize=figsize) - -# metrics = compute_quality_metrics(self.waveforms["static"], metric_names=["snr"], load_if_exists=True) -# snr = metrics["snr"].values - -# for i, case in enumerate(self.sorter_cases): -# label = case["label"] -# acc = self.accuracies[label] -# ax.scatter(snr, acc, label=label) -# ax.set_xlabel("snr") -# ax.set_ylabel("accuracy") - -# ax.legend() - -# elif mode == "depth": -# fig, ax = plt.subplots(figsize=figsize) - -# for i, case in enumerate(self.sorter_cases): -# label = case["label"] -# acc = accuracies[label] - -# ax.scatter(unit_depth, acc, label=label) -# ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") -# ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") -# ax.legend() -# ax.set_xlabel("depth") -# ax.set_ylabel("accuracy") - -# return fig - -# def plot_sortings_accuracy(self, **kwargs): -# if len(self.accuracies) != len(self.sorter_cases): -# self.compute_accuracies() - -# return self._plot_accuracy(self.accuracies, ls="-", **kwargs) - -# def plot_best_merges_accuracy(self, **kwargs): -# return self._plot_accuracy(self.merged_accuracies, **kwargs, ls="--") - -# def plot_sorting_units_categories(self): -# if len(self.accuracies) != len(self.sorter_cases): -# self.compute_accuracies() - -# for i, case in enumerate(self.sorter_cases): -# label = case["label"] -# comp = self.comparisons[label] -# count = comp.count_units_categories() -# if i == 0: -# df = pd.DataFrame(columns=count.index) -# df.loc[label, :] = count -# df.plot.bar() - -# def find_best_merges(self, merging_score=0.2): -# # this find best merges having the ground truth - -# self.merged_sortings = {} -# self.merged_comparisons = {} -# self.merged_accuracies = {} -# self.units_to_merge = {} -# for i, case in enumerate(self.sorter_cases): -# label = case["label"] -# # print() -# # print(label) -# gt_unit_ids = self.sorting_gt.unit_ids -# sorting = self.sortings[label] -# unit_ids = sorting.unit_ids - -# comp = self.comparisons[label] -# scores = comp.agreement_scores - -# to_merge = [] -# for gt_unit_id in gt_unit_ids: -# (inds,) = np.nonzero(scores.loc[gt_unit_id, :].values > merging_score) -# merge_ids = unit_ids[inds] -# if merge_ids.size > 1: -# to_merge.append(list(merge_ids)) - -# self.units_to_merge[label] = to_merge -# merged_sporting = MergeUnitsSorting(sorting, to_merge) -# comp_merged = GroundTruthComparison(self.sorting_gt, merged_sporting, exhaustive_gt=True) - -# self.merged_sortings[label] = merged_sporting -# self.merged_comparisons[label] = comp_merged -# self.merged_accuracies[label] = comp_merged.get_performance()["accuracy"].values - - -# def plot_distances_to_static(benchmarks, metric="cosine", figsize=(15, 10)): -# fig = plt.figure(figsize=figsize) -# gs = fig.add_gridspec(4, 2) - -# ax = fig.add_subplot(gs[0:2, 0]) -# for count, bench in enumerate(benchmarks): -# distances = bench.compute_distances_to_static(force=False) -# print(distances.keys()) -# ax.scatter( -# distances["drifting"][f"template_{metric}"], -# distances["corrected"][f"template_{metric}"], -# c=f"C{count}", -# alpha=0.5, -# label=bench.title, -# ) - -# ax.legend() - -# xmin, xmax = ax.get_xlim() -# ax.plot([xmin, xmax], [xmin, xmax], "k--") -# _simpleaxis(ax) -# if metric == "euclidean": -# ax.set_xlabel(r"$\|drift - static\|_2$") -# ax.set_ylabel(r"$\|corrected - static\|_2$") -# elif metric == "cosine": -# ax.set_xlabel(r"$cosine(drift, static)$") -# ax.set_ylabel(r"$cosine(corrected, static)$") - -# recgen = mr.load_recordings(benchmarks[0].mearec_filenames["static"]) -# nb_templates, nb_versions, _ = recgen.template_locations.shape -# template_positions = recgen.template_locations[:, nb_versions // 2, 1:3] -# distances_to_center = template_positions[:, 1] - -# ax_1 = fig.add_subplot(gs[0, 1]) -# ax_2 = fig.add_subplot(gs[1, 1]) -# ax_3 = fig.add_subplot(gs[2:, 1]) -# ax_4 = fig.add_subplot(gs[2:, 0]) - -# for count, bench in enumerate(benchmarks): -# # results = bench._compute_snippets_variability(metric=metric, num_channels=num_channels) -# distances = bench.compute_distances_to_static(force=False) - -# m_differences = distances["corrected"][f"wf_{metric}_mean"] / distances["static"][f"wf_{metric}_mean"] -# s_differences = distances["corrected"][f"wf_{metric}_std"] / distances["static"][f"wf_{metric}_std"] - -# ax_3.bar([count], [m_differences.mean()], yerr=[m_differences.std()], color=f"C{count}") -# ax_4.bar([count], [s_differences.mean()], yerr=[s_differences.std()], color=f"C{count}") -# idx = np.argsort(distances_to_center) -# ax_1.scatter(distances_to_center[idx], m_differences[idx], color=f"C{count}") -# ax_2.scatter(distances_to_center[idx], s_differences[idx], color=f"C{count}") - -# for a in [ax_1, ax_2, ax_3, ax_4]: -# _simpleaxis(a) - -# if metric == "euclidean": -# ax_1.set_ylabel(r"$\Delta mean(\|~\|_2)$ (% static)") -# ax_2.set_ylabel(r"$\Delta std(\|~\|_2)$ (% static)") -# ax_3.set_ylabel(r"$\Delta mean(\|~\|_2)$ (% static)") -# ax_4.set_ylabel(r"$\Delta std(\|~\|_2)$ (% static)") -# elif metric == "cosine": -# ax_1.set_ylabel(r"$\Delta mean(cosine)$ (% static)") -# ax_2.set_ylabel(r"$\Delta std(cosine)$ (% static)") -# ax_3.set_ylabel(r"$\Delta mean(cosine)$ (% static)") -# ax_4.set_ylabel(r"$\Delta std(cosine)$ (% static)") -# ax_3.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) -# ax_4.set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) -# xmin, xmax = ax_3.get_xlim() -# ax_3.plot([xmin, xmax], [1, 1], "k--") -# ax_4.plot([xmin, xmax], [1, 1], "k--") -# ax_1.set_xticks([]) -# ax_2.set_xlabel("depth (um)") - -# xmin, xmax = ax_1.get_xlim() -# ax_1.plot([xmin, xmax], [1, 1], "k--") -# ax_2.plot([xmin, xmax], [1, 1], "k--") -# plt.tight_layout() - - -# def plot_snr_decrease(benchmarks, figsize=(15, 10)): -# fig, axes = plt.subplots(2, 2, figsize=figsize, squeeze=False) - -# recgen = mr.load_recordings(benchmarks[0].mearec_filenames["static"]) -# nb_templates, nb_versions, _ = recgen.template_locations.shape -# template_positions = recgen.template_locations[:, nb_versions // 2, 1:3] -# distances_to_center = template_positions[:, 1] -# idx = np.argsort(distances_to_center) -# _simpleaxis(axes[0, 0]) - -# snr_static = compute_quality_metrics(benchmarks[0].waveforms["static"], metric_names=["snr"], load_if_exists=True) -# snr_drifting = compute_quality_metrics( -# benchmarks[0].waveforms["drifting"], metric_names=["snr"], load_if_exists=True -# ) - -# m = np.max(snr_static) -# axes[0, 0].scatter(snr_static.values, snr_drifting.values, c="0.5") -# axes[0, 0].plot([0, m], [0, m], color="k") - -# axes[0, 0].set_ylabel("units SNR for drifting") -# _simpleaxis(axes[0, 0]) - -# axes[0, 1].plot(distances_to_center[idx], (snr_drifting.values / snr_static.values)[idx], c="0.5") -# axes[0, 1].plot(distances_to_center[idx], np.ones(len(idx)), "k--") -# _simpleaxis(axes[0, 1]) -# axes[0, 1].set_xticks([]) -# axes[0, 0].set_xticks([]) - -# for count, bench in enumerate(benchmarks): -# snr_corrected = compute_quality_metrics(bench.waveforms["corrected"], metric_names=["snr"], load_if_exists=True) -# axes[1, 0].scatter(snr_static.values, snr_corrected.values, label=bench.title) -# axes[1, 0].plot([0, m], [0, m], color="k") - -# axes[1, 1].plot(distances_to_center[idx], (snr_corrected.values / snr_static.values)[idx], c=f"C{count}") - -# axes[1, 0].set_xlabel("units SNR for static") -# axes[1, 0].set_ylabel("units SNR for corrected") -# axes[1, 1].plot(distances_to_center[idx], np.ones(len(idx)), "k--") -# axes[1, 0].legend() -# _simpleaxis(axes[1, 0]) -# _simpleaxis(axes[1, 1]) -# axes[1, 1].set_ylabel(r"$\Delta(SNR)$") -# axes[0, 1].set_ylabel(r"$\Delta(SNR)$") - -# axes[1, 1].set_xlabel("depth (um)") - - -# def plot_residuals_comparisons(benchmarks): -# fig, axes = plt.subplots(1, 3, figsize=(15, 5)) -# for count, bench in enumerate(benchmarks): -# residuals, (t_start, t_stop) = bench.compute_residuals(force=False) -# time_axis = np.arange(t_start, t_stop) -# axes[0].plot(time_axis, residuals["corrected"].mean(0), label=bench.title) -# axes[0].legend() -# axes[0].set_xlabel("time (s)") -# axes[0].set_ylabel(r"$|S_{corrected} - S_{static}|$") -# _simpleaxis(axes[0]) - -# channel_positions = benchmarks[0].recordings["static"].get_channel_locations() -# distances_to_center = channel_positions[:, 1] -# idx = np.argsort(distances_to_center) - -# for count, bench in enumerate(benchmarks): -# residuals, (t_start, t_stop) = bench.compute_residuals(force=False) -# time_axis = np.arange(t_start, t_stop) -# axes[1].plot( -# distances_to_center[idx], residuals["corrected"].mean(1)[idx], label=bench.title, lw=2, c=f"C{count}" -# ) -# axes[1].fill_between( -# distances_to_center[idx], -# residuals["corrected"].mean(1)[idx] - residuals["corrected"].std(1)[idx], -# residuals["corrected"].mean(1)[idx] + residuals["corrected"].std(1)[idx], -# color=f"C{count}", -# alpha=0.25, -# ) -# axes[1].set_xlabel("depth (um)") -# _simpleaxis(axes[1]) - -# for count, bench in enumerate(benchmarks): -# residuals, (t_start, t_stop) = bench.compute_residuals(force=False) -# axes[2].bar([count], [residuals["corrected"].mean()], yerr=[residuals["corrected"].std()], color=f"C{count}") - -# _simpleaxis(axes[2]) -# axes[2].set_xticks(np.arange(len(benchmarks)), [i.title for i in benchmarks]) - - -# from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment - - -# class ResidualRecording(BasePreprocessor): -# name = "residual_recording" - -# def __init__(self, recording_1, recording_2): -# assert recording_1.get_num_segments() == recording_2.get_num_segments() -# BasePreprocessor.__init__(self, recording_1) - -# for parent_recording_segment_1, parent_recording_segment_2 in zip( -# recording_1._recording_segments, recording_2._recording_segments -# ): -# rec_segment = DifferenceRecordingSegment(parent_recording_segment_1, parent_recording_segment_2) -# self.add_recording_segment(rec_segment) - -# self._kwargs = dict(recording_1=recording_1, recording_2=recording_2) - - -# class DifferenceRecordingSegment(BasePreprocessorSegment): -# def __init__(self, parent_recording_segment_1, parent_recording_segment_2): -# BasePreprocessorSegment.__init__(self, parent_recording_segment_1) -# self.parent_recording_segment_1 = parent_recording_segment_1 -# self.parent_recording_segment_2 = parent_recording_segment_2 - -# def get_traces(self, start_frame, end_frame, channel_indices): -# traces_1 = self.parent_recording_segment_1.get_traces(start_frame, end_frame, channel_indices) -# traces_2 = self.parent_recording_segment_2.get_traces(start_frame, end_frame, channel_indices) - -# return traces_2 - traces_1 - - -# colors = {"static": "C0", "drifting": "C1", "corrected": "C2"} + if not mode_best_merge: + ls = '-' + else: + ls = '--' + + if mode == "ordered_accuracy": + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = ax.figure + + order = None + for i, key in enumerate(case_keys): + result = self.get_result(key) + if not mode_best_merge: + accuracy = result["accuracy"] + else: + accuracy = result["accuracy_merged"] + label = self.cases[key]["label"] + color = colors[i] if colors is not None else None + order = np.argsort(accuracy)[::-1] + accuracy = accuracy[order] + ax.plot(accuracy, label=label, ls=ls, color=color) + if legend: + ax.legend() + ax.set_ylabel("accuracy") + ax.set_xlabel("units ordered by accuracy") + + elif mode == "depth_snr": + if axes is None: + fig, axs = plt.subplots(nrows=len(case_keys), figsize=figsize, sharey=True, sharex=True) + else: + fig = axes[0].figure + axs = axes + + for i, key in enumerate(case_keys): + ax = axs[i] + result = self.get_result(key) + if not mode_best_merge: + accuracy = result["accuracy"] + else: + accuracy = result["accuracy_merged"] + label = self.cases[key]["label"] + + analyzer = self.get_sorting_analyzer(key) + ext = analyzer.get_extension("unit_locations") + if ext is None: + ext = analyzer.compute("unit_locations") + unit_locations = ext.get_data() + unit_depth = unit_locations[:, 1] + + snr= analyzer.get_extension("quality_metrics").get_data()["snr"].values + + points = ax.scatter(unit_depth, snr, c=accuracy) + points.set_clim(0.0, 1.0) + ax.set_title(label) + + chan_locations = analyzer.get_channel_locations() + + ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") + ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") + ax.set_ylabel("snr") + ax.set_xlabel("depth") + + cbar = fig.colorbar(points, ax=axs[:], location="right", shrink=0.6) + cbar.ax.set_ylabel("accuracy") + + elif mode == "snr": + fig, ax = plt.subplots(figsize=figsize) + + for i, key in enumerate(case_keys): + result = self.get_result(key) + label = self.cases[key]["label"] + if not mode_best_merge: + accuracy = result["accuracy"] + else: + accuracy = result["accuracy_merged"] + + analyzer = self.get_sorting_analyzer(key) + snr= analyzer.get_extension("quality_metrics").get_data()["snr"].values + + ax.scatter(snr, accuracy, label=label) + ax.set_xlabel("snr") + ax.set_ylabel("accuracy") + + ax.legend() + + elif mode == "depth": + fig, ax = plt.subplots(figsize=figsize) + + for i, key in enumerate(case_keys): + result = self.get_result(key) + label = self.cases[key]["label"] + if not mode_best_merge: + accuracy = result["accuracy"] + else: + accuracy = result["accuracy_merged"] + analyzer = self.get_sorting_analyzer(key) + + ext = analyzer.get_extension("unit_locations") + if ext is None: + ext = analyzer.compute("unit_locations") + unit_locations = ext.get_data() + unit_depth = unit_locations[:, 1] + + ax.scatter(unit_depth, accuracy, label=label) + + chan_locations = analyzer.get_channel_locations() + + ax.axvline(np.min(chan_locations[:, 1]), ls="--", color="k") + ax.axvline(np.max(chan_locations[:, 1]), ls="--", color="k") + ax.legend() + ax.set_xlabel("depth") + ax.set_ylabel("accuracy") + + return fig From 3ad54a8154948fa1f2e6896290a0d1f342bf447b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 22 Feb 2024 20:08:11 +0100 Subject: [PATCH 129/192] Implement run_times in BenchmarkStudy. --- .../benchmark/benchmark_tools.py | 210 +++++------------- 1 file changed, 56 insertions(+), 154 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 868ecd58db..798523241d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -4,6 +4,9 @@ import shutil import json import numpy as np +import pandas as pd + +import time import os @@ -16,7 +19,17 @@ class BenchmarkStudy: """ - Manage a list of Benchmark + Generic study for sorting components. + This manage a list of Benchmark. + This manage a dict of "cases" every case is one Benchmark. + + Benchmark is responsible for run() and compute_result() + BenchmarkStudy is the main API for: + * running (re-running) some cases + * save (run + compute_result) in results dict + * make some plots in inherited classes. + + """ benchmark_class = None def __init__(self, study_folder): @@ -156,15 +169,42 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): for key in job_keys: benchmark = self.create_benchmark(key) + t0 = time.perf_counter() benchmark.run() + t1 = time.perf_counter() self.benchmarks[key] = benchmark bench_folder = self.folder / "results" / self.key_to_str(key) bench_folder.mkdir(exist_ok=True) benchmark.save_run(bench_folder) + benchmark.result["run_time"] = float(t1 - t0) + benchmark.save_main(bench_folder) + def get_run_times(self, case_keys=None): + if case_keys is None: + case_keys = list(self.cases.keys()) + + run_times = {} + for key in case_keys: + benchmark = self.benchmarks[key] + assert benchmark is not None + run_times[key] = benchmark.result["run_time"] + + df = pd.DataFrame(dict(run_times=run_times)) + df.index.names = self.levels + return df + + def plot_run_times(self, case_keys=None): + if case_keys is None: + case_keys = list(self.cases.keys()) + run_times = self.get_run_times(case_keys=case_keys) + + run_times.plot(kind='bar') + + + def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: - case_keys = self.cases.keys() + case_keys = list(self.cases.keys()) job_keys = [] for key in case_keys: @@ -173,7 +213,7 @@ def compute_results(self, case_keys=None, verbose=False, **result_params): benchmark.compute_result(**result_params) benchmark.save_result(self.folder / "results" / self.key_to_str(key)) - def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): + def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, **kwargs): if case_keys is None: case_keys = self.cases.keys() @@ -190,9 +230,9 @@ def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): recording, gt_sorting = self.datasets[dataset_key] sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) sorting_analyzer.select_random_spikes(**select_params) - sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("waveforms", return_scaled=return_scaled, **job_kwargs) sorting_analyzer.compute("templates") - sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("noise_levels", return_scaled=return_scaled) def get_sorting_analyzer(self, case_key=None, dataset_key=None): if case_key is not None: @@ -254,10 +294,16 @@ def get_result(self, key): class Benchmark: """ + Responsible to make a unique run() and compute_result() for one case. """ def __init__(self): self.result = {} + # this must not be changed in inherited + _main_key_saved = [ + ("run_time", "pickle"), + ] + # this must be updated in hirerited _run_key_saved = [] _result_key_saved = [] @@ -277,6 +323,10 @@ def _save_keys(self, saved_keys, folder): else: raise ValueError(f"Save error {k} {format}") + def save_main(self, folder): + # used for run time + self._save_keys(self._main_key_saved, folder) + def save_run(self, folder): self._save_keys(self._run_key_saved, folder) @@ -286,7 +336,7 @@ def save_result(self, folder): @classmethod def load_folder(cls, folder): result = {} - for k, format in cls._run_key_saved + cls._result_key_saved: + for k, format in cls._run_key_saved + cls._result_key_saved + cls._main_key_saved: if format == "npy": file = folder / f"{k}.npy" if file.exists(): @@ -314,156 +364,8 @@ def compute_result(self): raise NotImplementedError - - - - def _simpleaxis(ax): ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.get_xaxis().tick_bottom() ax.get_yaxis().tick_left() - - -# class BenchmarkBaseOld: -# _array_names = () -# _waveform_names = () -# _sorting_names = () - -# _array_names_from_parent = () -# _waveform_names_from_parent = () -# _sorting_names_from_parent = () - -# def __init__( -# self, -# folder=None, -# title="", -# overwrite=None, -# job_kwargs={"chunk_duration": "1s", "n_jobs": -1, "progress_bar": True, "verbose": True}, -# parent_benchmark=None, -# ): -# self.folder = Path(folder) -# self.title = title -# self.overwrite = overwrite -# self.job_kwargs = job_kwargs -# self.run_times = None - -# self._args = [] -# self._kwargs = dict(title=title, overwrite=overwrite, job_kwargs=job_kwargs) - -# self.waveforms = {} -# self.sortings = {} - -# self.parent_benchmark = parent_benchmark - -# if self.parent_benchmark is not None: -# for name in self._array_names_from_parent: -# setattr(self, name, getattr(parent_benchmark, name)) - -# for name in self._waveform_names_from_parent: -# self.waveforms[name] = parent_benchmark.waveforms[name] - -# for key in parent_benchmark.sortings.keys(): -# if isinstance(key, str) and key in self._sorting_names_from_parent: -# self.sortings[key] = parent_benchmark.sortings[key] -# elif isinstance(key, tuple) and key[0] in self._sorting_names_from_parent: -# self.sortings[key] = parent_benchmark.sortings[key] - -# def save_to_folder(self): -# if self.folder.exists(): -# import glob, os - -# pattern = "*.*" -# files = self.folder.glob(pattern) -# for file in files: -# if file.is_file(): -# os.remove(file) -# else: -# self.folder.mkdir(parents=True) - -# if self.parent_benchmark is None: -# parent_folder = None -# else: -# parent_folder = str(self.parent_benchmark.folder) - -# info = { -# "args": self._args, -# "kwargs": self._kwargs, -# "parent_folder": parent_folder, -# } -# info = check_json(info) -# (self.folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - -# for name in self._array_names: -# if self.parent_benchmark is not None and name in self._array_names_from_parent: -# continue -# value = getattr(self, name) -# if value is not None: -# np.save(self.folder / f"{name}.npy", value) - -# if self.run_times is not None: -# run_times_filename = self.folder / "run_times.json" -# run_times_filename.write_text(json.dumps(self.run_times, indent=4), encoding="utf8") - -# for key, sorting in self.sortings.items(): -# (self.folder / "sortings").mkdir(exist_ok=True) -# if isinstance(key, str): -# npz_file = self.folder / "sortings" / (str(key) + ".npz") -# elif isinstance(key, tuple): -# npz_file = self.folder / "sortings" / ("_###_".join(key) + ".npz") -# NpzSortingExtractor.write_sorting(sorting, npz_file) - -# @classmethod -# def load_from_folder(cls, folder, parent_benchmark=None): -# folder = Path(folder) -# assert folder.exists() - -# with open(folder / "info.json", "r") as f: -# info = json.load(f) -# args = info["args"] -# kwargs = info["kwargs"] - -# if info["parent_folder"] is None: -# parent_benchmark = None -# else: -# if parent_benchmark is None: -# parent_benchmark = cls.load_from_folder(info["parent_folder"]) - -# import os - -# kwargs["folder"] = folder - -# bench = cls(*args, **kwargs, parent_benchmark=parent_benchmark) - -# for name in cls._array_names: -# filename = folder / f"{name}.npy" -# if filename.exists(): -# arr = np.load(filename) -# else: -# arr = None -# setattr(bench, name, arr) - -# if (folder / "run_times.json").exists(): -# with open(folder / "run_times.json", "r") as f: -# bench.run_times = json.load(f) -# else: -# bench.run_times = None - -# for key in bench._waveform_names: -# if parent_benchmark is not None and key in bench._waveform_names_from_parent: -# continue -# waveforms_folder = folder / "waveforms" / key -# if waveforms_folder.exists(): -# bench.waveforms[key] = load_waveforms(waveforms_folder, with_recording=True) - -# sorting_folder = folder / "sortings" -# if sorting_folder.exists(): -# for npz_file in sorting_folder.glob("*.npz"): -# name = npz_file.stem -# if "_###_" in name: -# key = tuple(name.split("_###_")) -# else: -# key = name -# bench.sortings[key] = NpzSortingExtractor(npz_file) - -# return bench From 568dcde9fb9e567b788762ed7d538bce25cb2614 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 22 Feb 2024 20:53:15 +0100 Subject: [PATCH 130/192] various fixes --- .../core/waveforms_extractor_backwards_compatibility.py | 2 +- .../sortingcomponents/benchmark/benchmark_tools.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index ecb87c967e..afda9b3967 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -331,7 +331,7 @@ def load_waveforms( folder = Path(folder) assert folder.is_dir(), "Waveform folder does not exists" - if (folder / "spikeinterface_info.json").exists: + if (folder / "spikeinterface_info.json").exists(): with open(folder / "spikeinterface_info.json", mode="r") as f: info = json.load(f) if info.get("object", None) == "SortingAnalyzer": diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 798523241d..e92c6af724 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -190,7 +190,8 @@ def get_run_times(self, case_keys=None): run_times[key] = benchmark.result["run_time"] df = pd.DataFrame(dict(run_times=run_times)) - df.index.names = self.levels + if not isinstance(self.levels, str): + df.index.names = self.levels return df def plot_run_times(self, case_keys=None): From 96d7a301a7cda82d41d5760df9bef3fe5a363032 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 23 Feb 2024 12:50:51 +0100 Subject: [PATCH 131/192] Simple tests for benchmark components. --- .../benchmark/benchmark_matching.py | 8 +- .../tests/common_benchmark_testing.py | 245 ++++++++++++++++++ .../tests/test_benchmark_clustering.py | 71 +++++ .../tests/test_benchmark_matching.py | 70 +++++ .../tests/test_benchmark_motion_estimation.py | 81 ++++++ .../test_benchmark_motion_interpolation.py | 21 ++ .../tests/test_benchmark_peak_localization.py | 21 ++ .../tests/test_benchmark_peak_selection.py | 21 ++ .../tests/test_benchmark_matching.py | 179 ------------- 9 files changed, 534 insertions(+), 183 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py create mode 100644 src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py delete mode 100644 src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 23a9f0459c..4a5221e16d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -72,7 +72,7 @@ def create_benchmark(self,key): benchmark = MatchingBenchmark(recording, gt_sorting, params) return benchmark - def plot_agreements(self, case_keys=None, figsize=(15,15)): + def plot_agreements(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -83,7 +83,7 @@ def plot_agreements(self, case_keys=None, figsize=(15,15)): ax.set_title(self.cases[key]['label']) plot_agreement_matrix(self.get_result(key)['gt_comparison'], ax=ax) - def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): + def plot_performances_vs_snr(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -105,7 +105,7 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): if count == 2: ax.legend() - def plot_collisions(self, case_keys=None, figsize=(15,15)): + def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -122,7 +122,7 @@ def plot_comparison_matching(self, case_keys=None, performance_names=["accuracy", "recall", "precision"], colors=["g", "b", "r"], ylim=(-0.1, 1.1), - figsize=(15,15) + figsize=None ): if case_keys is None: diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py new file mode 100644 index 0000000000..091ab0820e --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py @@ -0,0 +1,245 @@ +""" +Important : this benchmark machinery is very heavy. +This is not tested on github because not relevant at all. +This only a local testing. +""" +import pytest +from pathlib import Path +import os + +import numpy as np + +from spikeinterface.core import ( + generate_ground_truth_recording, + generate_templates, + estimate_templates, + Templates, + generate_sorting, + NoiseGeneratorRecording, +) +from spikeinterface.core.generate import generate_unit_locations +from spikeinterface.generation import ( + DriftingTemplates, + make_linear_displacement, + InjectDriftingTemplatesRecording +) + + +from probeinterface import generate_multi_columns_probe + + +ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) + + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "sortingcomponents_benchmark" +else: + cache_folder = Path("cache_folder") / "sortingcomponents_benchmark" + + + +def make_dataset(): + recording, gt_sorting = generate_ground_truth_recording( + durations=[60.0], + sampling_frequency=30000.0, + num_channels=16, + num_units=10, + generate_probe_kwargs=dict( + num_columns=2, + xpitch=20, + ypitch=20, + contact_shapes="circle", + contact_shape_params={"radius": 6}, + ), + generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), + seed=2205, + ) + return recording, gt_sorting + +def compute_gt_templates(recording, gt_sorting, ms_before=2., ms_after=3., return_scaled=False, **job_kwargs): + spikes = gt_sorting.to_spike_vector()#[spike_indices] + fs = recording.sampling_frequency + nbefore = int(ms_before * fs / 1000) + nafter = int(ms_after * fs / 1000) + templates_array = estimate_templates( + recording, spikes, + gt_sorting.unit_ids, nbefore, nafter, return_scaled=return_scaled, + **job_kwargs, + ) + + gt_templates = Templates( + templates_array=templates_array, + sampling_frequency=fs, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=gt_sorting.unit_ids, + probe=recording.get_probe(), + ) + return gt_templates + + +def make_drifting_dataset(): + + num_units = 15 + duration = 125.5 + sampling_frequency = 30000. + ms_before = 1. + ms_after = 3. + displacement_sampling_frequency = 5. + + + probe = generate_multi_columns_probe( + num_columns=3, + num_contact_per_column=12, + xpitch=15, + ypitch=15, + contact_shapes="square", + contact_shape_params={"width": 10}, + ) + probe.set_device_channel_indices(np.arange(probe.contact_ids.size)) + + + + channel_locations = probe.contact_positions + + unit_locations = generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=20.0, + max_iteration=100, + distance_strict=False, + seed=None, + ) + + + + nbefore = int(sampling_frequency * ms_before / 1000.) + + generate_kwargs = dict( + sampling_frequency=sampling_frequency, + ms_before=ms_before, + ms_after=ms_after, + seed=2205, + unit_params=dict( + decay_power=np.ones(num_units) * 2, + repolarization_ms=np.ones(num_units) * 0.8, + ), + unit_params_range=dict( + alpha=(4_000., 8_000.), + depolarization_ms=(0.09, 0.16), + + ), + + + ) + templates_array = generate_templates(channel_locations, unit_locations, **generate_kwargs) + + templates = Templates( + templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + probe=probe, + ) + + drifting_templates = DriftingTemplates.from_static(templates) + channel_locations = probe.contact_positions + + start = np.array([0, -15.]) + stop = np.array([0, 12]) + displacements = make_linear_displacement(start, stop, num_step=29) + + + sorting = generate_sorting( + num_units=num_units, + sampling_frequency=sampling_frequency, + durations = [duration,], + firing_rates=25.) + sorting + + + + + times = np.arange(0, duration, 1 / displacement_sampling_frequency) + times + + # 2 rythm + mid = (start + stop) / 2 + freq0 = 0.1 + displacement_vector0 = np.sin(2 * np.pi * freq0 *times)[:, np.newaxis] * (start - stop) + mid + # freq1 = 0.01 + # displacement_vector1 = 0.2 * np.sin(2 * np.pi * freq1 *times)[:, np.newaxis] * (start - stop) + mid + + # print() + + displacement_vectors = displacement_vector0[:, :, np.newaxis] + + # TODO gradient + num_motion = displacement_vectors.shape[2] + displacement_unit_factor = np.zeros((num_units, num_motion)) + displacement_unit_factor[:, 0] = 1 + + + drifting_templates.precompute_displacements(displacements) + + direction = 1 + unit_displacements = np.zeros((displacement_vectors.shape[0], num_units)) + for i in range(displacement_vectors.shape[2]): + m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :] + unit_displacements[:, :] += m + + noise = NoiseGeneratorRecording( + num_channels=probe.contact_ids.size, + sampling_frequency=sampling_frequency, + durations=[duration], + noise_level=1., + dtype="float32", + ) + + drifting_rec = InjectDriftingTemplatesRecording( + sorting=sorting, + parent_recording=noise, + drifting_templates=drifting_templates, + displacement_vectors=[displacement_vectors], + displacement_sampling_frequency=displacement_sampling_frequency, + displacement_unit_factor=displacement_unit_factor, + num_samples=[int(duration*sampling_frequency)], + amplitude_factor=None, + ) + + static_rec = InjectDriftingTemplatesRecording( + sorting=sorting, + parent_recording=noise, + drifting_templates=drifting_templates, + displacement_vectors=[displacement_vectors], + displacement_sampling_frequency=displacement_sampling_frequency, + displacement_unit_factor=np.zeros_like(displacement_unit_factor), + num_samples=[int(duration*sampling_frequency)], + amplitude_factor=None, + ) + + my_dict = _variable_from_namespace([ + drifting_rec, + static_rec, + sorting, + displacement_vectors, + displacement_sampling_frequency, + unit_locations, displacement_unit_factor, + unit_displacements + ], locals()) + return my_dict + + +def _variable_from_namespace(objs, namespace): + d = dict() + for obj in objs: + for name in namespace: + if namespace[name] is obj: + d[name] = obj + return d + + diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py new file mode 100644 index 0000000000..b60fb963fd --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py @@ -0,0 +1,71 @@ +import pytest + +import spikeinterface.full as si +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np + +import shutil + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.benchmark_clustering import ClusteringStudy + + + +@pytest.mark.skip() +def test_benchmark_clustering(): + + job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") + + recording, gt_sorting = make_dataset() + + num_spikes = gt_sorting.to_spike_vector().size + spike_indices = np.arange(0, num_spikes, 5) + + + # create study + study_folder = cache_folder / 'study_clustering' + datasets = {"toy" : (recording, gt_sorting)} + cases = {} + for method in ['random_projections', 'circus']: + cases[method] = { + "label": f"{method} on toy", + "dataset": "toy", + "init_kwargs": {'indices' : spike_indices}, + "params" : {"method" : method, "method_kwargs" : {}}, + } + + if study_folder.exists(): + shutil.rmtree(study_folder) + study = ClusteringStudy.create(study_folder, datasets=datasets, cases=cases) + print(study) + + # this study needs analyzer + study.create_sorting_analyzer_gt(**job_kwargs) + study.compute_metrics() + + + study = ClusteringStudy(study_folder) + + # run and result + study.run(**job_kwargs) + study.compute_results() + + # load study to check persistency + study = ClusteringStudy(study_folder) + print(study) + + # plots + study.plot_performances_vs_snr() + # @pierre : This one has a bug + # study.plot_metrics_vs_snr('cosine') + study.homogeneity_score(ignore_noise=False) + plt.show() + + + +if __name__ == "__main__": + test_benchmark_clustering() + + diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py new file mode 100644 index 0000000000..2af8bff1e5 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -0,0 +1,70 @@ +import pytest + +import shutil + +import spikeinterface.full as si +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt + +from spikeinterface.core import ( + get_noise_levels, + compute_sparsity, +) + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder, compute_gt_templates +from spikeinterface.sortingcomponents.benchmark.benchmark_matching import MatchingStudy + + +@pytest.mark.skip() +def test_benchmark_matching(): + + job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") + + recording, gt_sorting = make_dataset() + + # templates sparse + gt_templates = compute_gt_templates(recording, gt_sorting, ms_before=2., ms_after=3., return_scaled=False, **job_kwargs) + noise_levels = get_noise_levels(recording) + sparsity = compute_sparsity(gt_templates, noise_levels, method='ptp', threshold=0.25) + gt_templates = gt_templates.to_sparse(sparsity) + + + # create study + study_folder = cache_folder / 'study_matching' + datasets = {"toy" : (recording, gt_sorting)} + cases = {} + for engine in ['wobble', 'circus-omp-svd',]: + cases[engine] = { + "label": f"{engine} on toy", + "dataset": "toy", + "params" : {"method" : engine, "method_kwargs" : {"templates" : gt_templates}}, + } + if study_folder.exists(): + shutil.rmtree(study_folder) + study = MatchingStudy.create(study_folder, datasets=datasets, cases=cases) + print(study) + + # this study needs analyzer + study.create_sorting_analyzer_gt(**job_kwargs) + study.compute_metrics() + + # run and result + study.run(**job_kwargs) + study.compute_results() + + # load study to check persistency + study = MatchingStudy(study_folder) + print(study) + + # plots + study.plot_performances_vs_snr() + study.plot_agreements() + study.plot_comparison_matching() + plt.show() + + +if __name__ == "__main__": + test_benchmark_matching() + + diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py new file mode 100644 index 0000000000..0f009afa9a --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -0,0 +1,81 @@ +import pytest + +import spikeinterface.full as si +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt + +import shutil + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_drifting_dataset, cache_folder + +from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import MotionEstimationStudy + +@pytest.mark.skip() +def test_benchmark_motion_estimaton(): + + job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") + + + data = make_drifting_dataset() + + datasets = { + "drifting_rec": (data["drifting_rec"], data["sorting"]), + } + + cases = {} + for label, loc_method, est_method in [ + ("COM + KS", "center_of_mass", "iterative_template"), + ("Grid + Dec", "grid_convolution", "decentralized"), + ]: + cases[label] = dict( + label = label, + dataset="drifting_rec", + init_kwargs=dict( + unit_locations=data["unit_locations"], + unit_displacements=data["unit_displacements"], + displacement_sampling_frequency=data["displacement_sampling_frequency"], + direction="y" + ), + params=dict( + detect_kwargs=dict(method="locally_exclusive", detect_threshold=10.), + select_kwargs=None, + localize_kwargs=dict(method=loc_method), + estimate_motion_kwargs=dict( + method=est_method, + bin_duration_s=1., + bin_um=5., + rigid=False, + win_step_um=50., + win_sigma_um=200., + ), + ) + ) + + study_folder = cache_folder / 'study_motion_estimation' + if study_folder.exists(): + shutil.rmtree(study_folder) + study = MotionEstimationStudy.create(study_folder, datasets, cases) + + + # run and result + study.run(**job_kwargs) + study.compute_results() + + # load study to check persistency + study = MotionEstimationStudy(study_folder) + print(study) + + # plots + study.plot_true_drift() + study.plot_errors() + study.plot_summary_errors() + + plt.show() + + +if __name__ == "__main__": + test_benchmark_motion_estimaton() + + + diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py new file mode 100644 index 0000000000..4e6006f539 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -0,0 +1,21 @@ +import pytest + +import spikeinterface.full as si +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder + + +@pytest.mark.skip() +def test_benchmark_motion_interpolation(): + pass + + + + +if __name__ == "__main__": + test_benchmark_motion_interpolation() + + diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py new file mode 100644 index 0000000000..d2e07b7a1b --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -0,0 +1,21 @@ +import pytest + +import spikeinterface.full as si +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder + + +@pytest.mark.skip() +def test_benchmark_peak_localization(): + pass + + + + +if __name__ == "__main__": + test_benchmark_peak_localization() + + diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py new file mode 100644 index 0000000000..78b59be489 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py @@ -0,0 +1,21 @@ +import pytest + +import spikeinterface.full as si +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder + + +@pytest.mark.skip() +def test_benchmark_peak_selection(): + pass + + + + +if __name__ == "__main__": + test_benchmark_peak_selection() + + diff --git a/src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py deleted file mode 100644 index 6e7f1c0d8a..0000000000 --- a/src/spikeinterface/sortingcomponents/tests/test_benchmark_matching.py +++ /dev/null @@ -1,179 +0,0 @@ -# import pytest -# import numpy as np -# import pandas as pd -# import shutil -# import os -# from pathlib import Path - -# import spikeinterface.core as sc -# import spikeinterface.extractors as se -# import spikeinterface.preprocessing as spre -# from spikeinterface.sortingcomponents.benchmark import benchmark_matching - - -# @pytest.fixture(scope="session") -# def benchmark_and_kwargs(tmp_path_factory): -# recording, sorting = se.toy_example(duration=1, num_channels=2, num_units=2, num_segments=1, firing_rate=10, seed=0) -# recording = spre.common_reference(recording, dtype="float32") -# we_path = tmp_path_factory.mktemp("waveforms") -# sort_path = tmp_path_factory.mktemp("sortings") / ("sorting.npz") -# se.NpzSortingExtractor.write_sorting(sorting, sort_path) -# sorting = se.NpzSortingExtractor(sort_path) -# we = sc.extract_waveforms(recording, sorting, we_path, overwrite=True) -# templates = we.get_all_templates() -# noise_levels = sc.get_noise_levels(recording, return_scaled=False) -# methods_kwargs = { -# "tridesclous": dict(waveform_extractor=we, noise_levels=noise_levels), -# "wobble": dict(templates=templates, nbefore=we.nbefore, nafter=we.nafter, parameters={"approx_rank": 2}), -# } -# methods = list(methods_kwargs.keys()) -# benchmark = benchmark_matching.BenchmarkMatching(recording, sorting, we, methods, methods_kwargs) -# return benchmark, methods_kwargs - - -# @pytest.mark.parametrize( -# "parameters, parameter_name", -# [ -# ([1, 10, 100], "num_spikes"), -# ([0, 0.5, 1], "fraction_misclassed"), -# ([0, 0.5, 1], "fraction_missing"), -# ], -# ) -# def test_run_matching_vary_parameter(benchmark_and_kwargs, parameters, parameter_name): -# # Arrange -# benchmark, methods_kwargs = benchmark_and_kwargs -# num_replicates = 2 - -# # Act -# with benchmark as bmk: -# matching_df = bmk.run_matching_vary_parameter(parameters, parameter_name, num_replicates=num_replicates) - -# # Assert -# assert matching_df.shape[0] == len(parameters) * num_replicates * len(methods_kwargs) -# assert matching_df.shape[1] == 6 - - -# @pytest.mark.parametrize( -# "parameter_name, num_replicates", -# [ -# ("invalid_parameter_name", 1), -# ("num_spikes", -1), -# ("num_spikes", 0.5), -# ], -# ) -# def test_run_matching_vary_parameter_invalid_inputs(benchmark_and_kwargs, parameter_name, num_replicates): -# parameters = [1, 2] -# benchmark, methods_kwargs = benchmark_and_kwargs -# with benchmark as bmk: -# with pytest.raises(ValueError): -# bmk.run_matching_vary_parameter(parameters, parameter_name, num_replicates=num_replicates) - - -# @pytest.mark.parametrize( -# "fraction_misclassed, min_similarity", -# [ -# (-1, -1), -# (2, -1), -# (0, 2), -# ], -# ) -# def test_run_matching_misclassed_invalid_inputs(benchmark_and_kwargs, fraction_misclassed, min_similarity): -# benchmark, methods_kwargs = benchmark_and_kwargs -# with benchmark as bmk: -# with pytest.raises(ValueError): -# bmk.run_matching_misclassed(fraction_misclassed, min_similarity=min_similarity) - - -# @pytest.mark.parametrize( -# "fraction_missing, snr_threshold", -# [ -# (-1, 0), -# (2, 0), -# (0, -1), -# ], -# ) -# def test_run_matching_missing_units_invalid_inputs(benchmark_and_kwargs, fraction_missing, snr_threshold): -# benchmark, methods_kwargs = benchmark_and_kwargs -# with benchmark as bmk: -# with pytest.raises(ValueError): -# bmk.run_matching_missing_units(fraction_missing, snr_threshold=snr_threshold) - - -# def test_compare_all_sortings(benchmark_and_kwargs): -# # Arrange -# benchmark, methods_kwargs = benchmark_and_kwargs -# parameter_name = "num_spikes" -# num_replicates = 2 -# num_spikes = [1, 10, 100] -# rng = np.random.default_rng(0) -# sortings, gt_sortings, parameter_values, parameter_names, iter_nums, methods = [], [], [], [], [], [] -# for replicate in range(num_replicates): -# for spike_num in num_spikes: -# for method in list(methods_kwargs.keys()): -# len_spike_train = 100 -# spike_time_inds = rng.choice(benchmark.recording.get_num_frames(), len_spike_train, replace=False) -# unit_ids = rng.choice(benchmark.gt_sorting.get_unit_ids(), len_spike_train, replace=True) -# sort_index = np.argsort(spike_time_inds) -# spike_time_inds = spike_time_inds[sort_index] -# unit_ids = unit_ids[sort_index] -# sorting = sc.NumpySorting.from_times_labels( -# spike_time_inds, unit_ids, benchmark.recording.sampling_frequency -# ) -# spike_time_inds = rng.choice(benchmark.recording.get_num_frames(), len_spike_train, replace=False) -# unit_ids = rng.choice(benchmark.gt_sorting.get_unit_ids(), len_spike_train, replace=True) -# sort_index = np.argsort(spike_time_inds) -# spike_time_inds = spike_time_inds[sort_index] -# unit_ids = unit_ids[sort_index] -# gt_sorting = sc.NumpySorting.from_times_labels( -# spike_time_inds, unit_ids, benchmark.recording.sampling_frequency -# ) -# sortings.append(sorting) -# gt_sortings.append(gt_sorting) -# parameter_values.append(spike_num) -# parameter_names.append(parameter_name) -# iter_nums.append(replicate) -# methods.append(method) -# matching_df = pd.DataFrame( -# { -# "sorting": sortings, -# "gt_sorting": gt_sortings, -# "parameter_value": parameter_values, -# "parameter_name": parameter_names, -# "iter_num": iter_nums, -# "method": methods, -# } -# ) -# comparison_from_df = matching_df.copy() -# comparison_from_self = matching_df.copy() -# comparison_collision = matching_df.copy() - -# # Act -# benchmark.compare_all_sortings(comparison_from_df, ground_truth="from_df") -# benchmark.compare_all_sortings(comparison_from_self, ground_truth="from_self") -# benchmark.compare_all_sortings(comparison_collision, collision=True) - -# # Assert -# for comparison in [comparison_from_df, comparison_from_self, comparison_collision]: -# assert comparison.shape[0] == len(num_spikes) * num_replicates * len(methods_kwargs) -# assert comparison.shape[1] == 7 -# for comp, sorting in zip(comparison["comparison"], comparison["sorting"]): -# comp.sorting2 == sorting -# for comp, gt_sorting in zip(comparison_from_df["comparison"], comparison["gt_sorting"]): -# comp.sorting1 == gt_sorting -# for comp in comparison_from_self["comparison"]: -# comp.sorting1 == benchmark.gt_sorting - - -# def test_compare_all_sortings_invalid_inputs(benchmark_and_kwargs): -# benchmark, methods_kwargs = benchmark_and_kwargs -# with pytest.raises(ValueError): -# benchmark.compare_all_sortings(pd.DataFrame(), ground_truth="invalid") - - -# if __name__ == "__main__": -# test_run_matching_vary_parameter(benchmark_and_kwargs) -# test_run_matching_vary_parameter_invalid_inputs(benchmark_and_kwargs) -# test_run_matching_misclassed_invalid_inputs(benchmark_and_kwargs) -# test_run_matching_missing_units_invalid_inputs(benchmark_and_kwargs) -# test_compare_all_sortings(benchmark_and_kwargs) -# test_compare_all_sortings_invalid_inputs(benchmark_and_kwargs) From 2467535c62ee8e79d2301a01878ed47c0c20415c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 23 Feb 2024 14:58:32 +0100 Subject: [PATCH 132/192] Wip tests benchmark --- .../test_benchmark_motion_interpolation.py | 126 +++++++++++++++++- .../tests/test_benchmark_peak_localization.py | 93 ++++++++++++- 2 files changed, 214 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index 4e6006f539..cb8cc50b68 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -5,12 +5,134 @@ from pathlib import Path import matplotlib.pyplot as plt -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder +import numpy as np + +import shutil + + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_drifting_dataset, cache_folder + +from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy +from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import get_unit_disclacement, get_gt_motion_from_unit_discplacement @pytest.mark.skip() def test_benchmark_motion_interpolation(): - pass + + job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") + + + data = make_drifting_dataset() + + datasets = { + "data_static": (data["static_rec"], data["sorting"]), + } + + duration = data["drifting_rec"].get_duration() + channel_locations = data["drifting_rec"].get_channel_locations() + + + + unit_displacements = get_unit_disclacement(data["displacement_vectors"], data["displacement_unit_factor"], direction_dim=1) + + bin_s = 1 + temporal_bins = np.arange(0, duration, bin_s) + spatial_bins = np.linspace(np.min(channel_locations[:, 1]), + np.max(channel_locations[:, 1]), + 10 + ) + print(spatial_bins) + gt_motion = get_gt_motion_from_unit_discplacement( + unit_displacements, data["displacement_sampling_frequency"], + data["unit_locations"], + temporal_bins, spatial_bins, + direction_dim=1 + ) + # fig, ax = plt.subplots() + # ax.imshow(gt_motion.T) + # plt.show() + + + cases = {} + bin_duration_s = 1. + + cases["static_SC2"] = dict( + label = "No drift - no correction - SC2", + dataset="data_static", + init_kwargs=dict( + drifting_recording=data["drifting_rec"], + motion=gt_motion, + temporal_bins=temporal_bins, + spatial_bins=spatial_bins, + ), + params=dict( + recording_source="static", + sorter_name="spykingcircus2", + sorter_params=dict(), + ) + ) + + cases["drifting_SC2"] = dict( + label = "Drift - no correction - SC2", + dataset="data_static", + init_kwargs=dict( + drifting_recording=data["drifting_rec"], + motion=gt_motion, + temporal_bins=temporal_bins, + spatial_bins=spatial_bins, + ), + params=dict( + recording_source="drifting", + sorter_name="spykingcircus2", + sorter_params=dict(), + ) + ) + + cases["drifting_SC2"] = dict( + label = "Drift - correction with GT - SC2", + dataset="data_static", + init_kwargs=dict( + drifting_recording=data["drifting_rec"], + motion=gt_motion, + temporal_bins=temporal_bins, + spatial_bins=spatial_bins, + ), + params=dict( + recording_source="corrected", + sorter_name="spykingcircus2", + sorter_params=dict(), + correct_motion_kwargs=dict(spatial_interpolation_method="kriging"), + ) + ) + + study_folder = cache_folder / 'study_motion_interpolation' + if study_folder.exists(): + shutil.rmtree(study_folder) + study = MotionInterpolationStudy.create(study_folder, datasets, cases) + + # this study needs analyzer + study.create_sorting_analyzer_gt(**job_kwargs) + study.compute_metrics() + + + # run and result + study.run(**job_kwargs) + study.compute_results() + + # load study to check persistency + study = MotionInterpolationStudy(study_folder) + print(study) + + # plots + study.plot_sorting_accuracy(mode="ordered_accuracy", mode_best_merge=False) + study.plot_sorting_accuracy(mode="ordered_accuracy", mode_best_merge=True) + study.plot_sorting_accuracy(mode="depth_snr") + study.plot_sorting_accuracy(mode="snr", mode_best_merge=False) + study.plot_sorting_accuracy(mode="snr", mode_best_merge=True) + study.plot_sorting_accuracy(mode="depth", mode_best_merge=False) + study.plot_sorting_accuracy(mode="depth", mode_best_merge=True) + + plt.show() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py index d2e07b7a1b..baa756d521 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -1,21 +1,108 @@ import pytest +import shutil + import spikeinterface.full as si import pandas as pd from pathlib import Path import matplotlib.pyplot as plt +import numpy as np + from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import PeakLocalizationStudy + +from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import UnitLocalizationStudy + @pytest.mark.skip() def test_benchmark_peak_localization(): - pass + job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") + recording, gt_sorting = make_dataset() + # create study + study_folder = cache_folder / 'study_peak_localization' + datasets = {"toy" : (recording, gt_sorting)} + cases = {} + for method in ['center_of_mass', 'grid_convolution', 'monopolar_triangulation']: + cases[method] = { + "label": f"{method} on toy", + "dataset": "toy", + "init_kwargs": {"gt_positions" : gt_sorting.get_property('gt_unit_locations')}, + "params" : {"ms_before" : 2, + "method" : method, + "method_kwargs" : {}, + "spike_retriver_kwargs" : {"channel_from_template" : False}} + } + + if study_folder.exists(): + shutil.rmtree(study_folder) + study = PeakLocalizationStudy.create(study_folder, datasets=datasets, cases=cases) + print(study) -if __name__ == "__main__": - test_benchmark_peak_localization() + # this study needs analyzer + study.create_sorting_analyzer_gt(**job_kwargs) + study.compute_metrics() + + # run and result + study.run(**job_kwargs) + study.compute_results() + + # load study to check persistency + study = PeakLocalizationStudy(study_folder) + study.plot_comparison_positions(smoothing_factor=31) + study.plot_run_times() + + plt.show() + + +@pytest.mark.skip() +def test_benchmark_unit_localization(): + job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") + + recording, gt_sorting = make_dataset() + + # create study + study_folder = cache_folder / 'study_unit_localization' + datasets = {"toy" : (recording, gt_sorting)} + cases = {} + for method in ['center_of_mass', 'grid_convolution', 'monopolar_triangulation']: + cases[method] = { + "label": f"{method} on toy", + "dataset": "toy", + "init_kwargs": {"gt_positions" : gt_sorting.get_property('gt_unit_locations')}, + "params" : {"ms_before" : 2, + "method" : method, + "method_kwargs" : {}, + "spike_retriver_kwargs" : {"channel_from_template" : False}} + } + + if study_folder.exists(): + shutil.rmtree(study_folder) + study = UnitLocalizationStudy.create(study_folder, datasets=datasets, cases=cases) + print(study) + + # this study needs analyzer + study.create_sorting_analyzer_gt(**job_kwargs) + study.compute_metrics() + + # run and result + study.run(**job_kwargs) + study.compute_results() + # load study to check persistency + study = UnitLocalizationStudy(study_folder) + study.plot_comparison_positions(smoothing_factor=31) + study.plot_run_times() + + plt.show() + + + +if __name__ == "__main__": + # test_benchmark_peak_localization() + test_benchmark_unit_localization() From 2819ac72af5d30bc55927d202c18885a07e9a26b Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 26 Feb 2024 09:22:54 +0100 Subject: [PATCH 133/192] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- .../sortingcomponents/benchmark/benchmark_clustering.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 52ed56b52e..8848570ef6 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -29,13 +29,13 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 5}, + "sparsity": {"method": "ptp", "threshold": 0.25}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { "method": "smart_sampling_amplitudes", "n_peaks_per_channel": 5000, - "min_n_peaks": 20000, + "min_n_peaks": 100000, "select_per_channel": False, }, "clustering": {"legacy": False}, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 4a418037b3..21ff2cfa33 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -14,10 +14,7 @@ #from spikeinterface.postprocessing import get_template_extremum_channel from spikeinterface.core import get_noise_levels -import time -import string, random import pylab as plt -import os import numpy as np From 2371758562b1f81ff3df6716fe622bbff2411e3c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 11:59:57 +0100 Subject: [PATCH 134/192] Change sorting_analyzer.select_random_spikes() into sorting_analyzer.compute("random_spikes") --- .../comparison/groundtruthstudy.py | 8 +- .../tests/test_templatecomparison.py | 2 +- .../core/analyzer_extension_core.py | 135 ++++++++++++++++-- src/spikeinterface/core/sortinganalyzer.py | 85 +---------- .../tests/test_analyzer_extension_core.py | 47 ++++-- .../core/tests/test_node_pipeline.py | 2 +- .../core/tests/test_sortinganalyzer.py | 10 -- .../core/tests/test_sparsity.py | 2 +- .../core/tests/test_template_tools.py | 2 +- ...forms_extractor_backwards_compatibility.py | 4 +- ...forms_extractor_backwards_compatibility.py | 16 +-- src/spikeinterface/curation/tests/common.py | 2 +- .../curation/tests/test_auto_merge.py | 2 +- .../curation/tests/test_remove_redundant.py | 2 +- .../tests/test_sortingview_curation.py | 2 +- src/spikeinterface/exporters/tests/common.py | 2 +- .../postprocessing/principal_component.py | 16 +-- .../tests/common_extension_tests.py | 2 +- .../tests/test_principal_component.py | 8 +- .../tests/test_template_similarity.py | 2 +- .../tests/test_metrics_functions.py | 2 +- .../qualitymetrics/tests/test_pca_metrics.py | 2 +- .../tests/test_quality_metric_calculator.py | 4 +- .../benchmark/benchmark_clustering.py | 6 +- .../benchmark/benchmark_peak_localization.py | 5 +- .../benchmark/benchmark_peak_selection.py | 6 +- .../benchmark/benchmark_tools.py | 6 +- .../tests/test_template_matching.py | 2 +- .../waveforms/temporal_pca.py | 2 +- .../widgets/tests/test_widgets.py | 4 +- 30 files changed, 207 insertions(+), 183 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 112e859dad..3a510dd522 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -284,12 +284,10 @@ def get_run_times(self, case_keys=None): return pd.Series(run_times, name="run_time") - def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): + def create_sorting_analyzer_gt(self, case_keys=None, random_params={}, waveforms_params={}, **job_kwargs): if case_keys is None: case_keys = self.cases.keys() - select_params, job_kwargs = split_job_kwargs(kwargs) - base_folder = self.folder / "sorting_analyzer" base_folder.mkdir(exist_ok=True) @@ -300,8 +298,8 @@ def create_sorting_analyzer_gt(self, case_keys=None, **kwargs): folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) - sorting_analyzer.select_random_spikes(**select_params) - sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("random_spikes", **random_params) + sorting_analyzer.compute("waveforms", **waveforms_params, **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("noise_levels") diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index 6e4c2d1714..d361e3ed36 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -46,7 +46,7 @@ def test_compare_multiple_templates(): sorting_analyzer_3 = create_sorting_analyzer(sort3, rec3, format="memory") for sorting_analyzer in (sorting_analyzer_1, sorting_analyzer_2, sorting_analyzer_3): - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("fast_templates") # paired comparison diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index d952c93f75..0d482f5369 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -15,6 +15,102 @@ from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_average from .recording_tools import get_noise_levels from .template import Templates +from .sorting_tools import random_spikes_selection + + +class SelectRandomSpikes(ResultExtension): + """ + ResultExtension that select some random spikes. + + This will be used by "compute_waveforms" and so "compute_templates" or "compute_fast_templates" + + This internally use `random_spikes_selection()` parameters are the same. + + Parameters + ---------- + unit_ids: list or None + Unit ids to retrieve waveforms for + mode: "average" | "median" | "std" | "percentile", default: "average" + The mode to compute the templates + percentile: float, default: None + Percentile to use for mode="percentile" + save: bool, default True + In case, the operator is not computed yet it can be saved to folder or zarr. + + Returns + ------- + + """ + extension_name = "random_spikes" + depend_on = [] + need_recording = False + use_nodepipeline = False + need_job_kwargs = False + + def _run(self, + ): + self.data["random_spikes_indices"] = random_spikes_selection( + self.sorting_analyzer.sorting, num_samples=self.sorting_analyzer.rec_attributes["num_samples"], + **self.params) + + def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None): + params = dict( + method=method, + max_spikes_per_unit=max_spikes_per_unit, + margin_size=margin_size, + seed=seed) + return params + + def _select_extension_data(self, unit_ids): + random_spikes_indices = self.data["random_spikes_indices"] + + spikes = self.sorting_analyzer.sorting.to_spike_vector() + + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) + keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + + selected_mask = np.zeros(spikes.size, dtype=bool) + selected_mask[random_spikes_indices] = True + + new_data = dict() + new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_spike_mask]) + return new_data + + + def _get_data(self): + return self.data["random_spikes_indices"] + + def some_spikes(self): + # utils to get the some_spikes vector + # use internal cache + if not hasattr(self, "_some_spikes"): + spikes = self.sorting_analyzer.sorting.to_spike_vector() + self._some_spikes = spikes[self.data["random_spikes_indices"]] + return self._some_spikes + + + def get_selected_indices_in_spike_train(self, unit_id, segment_index): + # usefull for Waveforms extractor backwars compatibility + # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain + sorting = self.sorting_analyzer.sorting + random_spikes_indices = self.data["random_spikes_indices"] + + unit_index = sorting.id_to_index(unit_id) + spikes = sorting.to_spike_vector() + spike_indices_in_seg = np.flatnonzero( + (spikes["segment_index"] == segment_index) & (spikes["unit_index"] == unit_index) + ) + common_element, inds_left, inds_right = np.intersect1d( + spike_indices_in_seg, random_spikes_indices, return_indices=True + ) + selected_spikes_in_spike_train = inds_left + return selected_spikes_in_spike_train + + + +register_result_extension(SelectRandomSpikes) + + class ComputeWaveforms(ResultExtension): @@ -25,7 +121,7 @@ class ComputeWaveforms(ResultExtension): """ extension_name = "waveforms" - depend_on = [] + depend_on = ["random_spikes"] need_recording = True use_nodepipeline = False need_job_kwargs = True @@ -41,16 +137,19 @@ def nafter(self): def _run(self, **job_kwargs): self.data.clear() - if self.sorting_analyzer.random_spikes_indices is None: - raise ValueError("compute_waveforms need SortingAnalyzer.select_random_spikes() need to be run first") + # if self.sorting_analyzer.random_spikes_indices is None: + # raise ValueError("compute_waveforms need SortingAnalyzer.select_random_spikes() need to be run first") + + # random_spikes_indices = self.sorting_analyzer.get_extension("random_spikes").get_data() recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting unit_ids = sorting.unit_ids # retrieve spike vector and the sampling - spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + # spikes = sorting.to_spike_vector() + # some_spikes = spikes[random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() if self.format == "binary_folder": # in that case waveforms are extacted directly in files @@ -116,9 +215,12 @@ def _set_params( return params def _select_extension_data(self, unit_ids): + # random_spikes_indices = self.sorting_analyzer.get_extension("random_spikes").get_data() + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) spikes = self.sorting_analyzer.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + # some_spikes = spikes[random_spikes_indices] keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices) new_data = dict() @@ -133,8 +235,9 @@ def get_waveforms_one_unit( ): sorting = self.sorting_analyzer.sorting unit_index = sorting.id_to_index(unit_id) - spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + # spikes = sorting.to_spike_vector() + # some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() spike_mask = some_spikes["unit_index"] == unit_index wfs = self.data["waveforms"][spike_mask, :, :] @@ -219,8 +322,9 @@ def _compute_and_append(self, operators): raise ValueError(f"ComputeTemplates: wrong operator {operator}") self.data[key] = np.zeros((unit_ids.size, num_samples, channel_ids.size)) - spikes = self.sorting_analyzer.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + # spikes = self.sorting_analyzer.sorting.to_spike_vector() + # some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index wfs = waveforms[spike_mask, :, :] @@ -348,7 +452,7 @@ class ComputeFastTemplates(ResultExtension): """ extension_name = "fast_templates" - depend_on = [] + depend_on = ["random_spikes"] need_recording = True use_nodepipeline = False need_job_kwargs = True @@ -364,16 +468,17 @@ def nafter(self): def _run(self, **job_kwargs): self.data.clear() - if self.sorting_analyzer.random_spikes_indices is None: - raise ValueError("compute_waveforms need SortingAnalyzer.select_random_spikes() need to be run first") + # if self.sorting_analyzer.random_spikes_indices is None: + # raise ValueError("compute_waveforms need SortingAnalyzer.select_random_spikes() need to be run first") recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting unit_ids = sorting.unit_ids # retrieve spike vector and the sampling - spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + # spikes = sorting.to_spike_vector() + # some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() return_scaled = self.params["return_scaled"] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 67a8e673e8..a743310327 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -20,7 +20,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes -from .sorting_tools import random_spikes_selection +# from .sorting_tools import random_spikes_selection from .core_tools import check_json from .job_tools import split_job_kwargs from .numpyextractors import SharedMemorySorting @@ -30,9 +30,6 @@ from .node_pipeline import run_node_pipeline -# TODO make some_spikes a method of SortingAnalyzer - - # high level function def create_sorting_analyzer( sorting, recording, format="memory", folder=None, sparse=True, sparsity=None, **sparsity_kwargs @@ -158,7 +155,7 @@ class SortingAnalyzer: """ def __init__( - self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None, random_spikes_indices=None + self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -167,7 +164,6 @@ def __init__( self.rec_attributes = rec_attributes self.format = format self.sparsity = sparsity - self.random_spikes_indices = random_spikes_indices # extensions are not loaded at init self.extensions = dict() @@ -323,8 +319,6 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, rec_attribut if sparsity is not None: np.save(folder / "sparsity_mask.npy", sparsity.mask) - # with open(folder / "sparsity.json", mode="w") as f: - # json.dump(check_json(sparsity.to_dict()), f) @classmethod def load_from_binary_folder(cls, folder, recording=None): @@ -365,29 +359,19 @@ def load_from_binary_folder(cls, folder, recording=None): rec_attributes["probegroup"] = None # sparsity - # sparsity_file = folder / "sparsity.json" sparsity_file = folder / "sparsity_mask.npy" if sparsity_file.is_file(): sparsity_mask = np.load(sparsity_file) - # with open(sparsity_file, mode="r") as f: - # sparsity = ChannelSparsity.from_dict(json.load(f)) sparsity = ChannelSparsity(sparsity_mask, sorting.unit_ids, rec_attributes["channel_ids"]) else: sparsity = None - selected_spike_file = folder / "random_spikes_indices.npy" - if selected_spike_file.is_file(): - random_spikes_indices = np.load(selected_spike_file) - else: - random_spikes_indices = None - sorting_analyzer = SortingAnalyzer( sorting=sorting, recording=recording, rec_attributes=rec_attributes, format="binary_folder", sparsity=sparsity, - random_spikes_indices=random_spikes_indices, ) return sorting_analyzer @@ -434,11 +418,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # sorting provenance sort_dict = sorting.to_dict(relative_to=folder, recursive=True) if sorting.check_serializability("json"): - # zarr_root.attrs["sorting_provenance"] = check_json(sort_dict) zarr_sort = np.array([sort_dict], dtype=object) zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON()) elif sorting.check_serializability("pickle"): - # zarr_root.create_dataset("sorting_provenance", data=sort_dict, object_codec=numcodecs.Pickle()) zarr_sort = np.array([sort_dict], dtype=object) zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) @@ -456,15 +438,11 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): probegroup = rec_attributes.pop("probegroup") recording_info.attrs["recording_attributes"] = check_json(rec_attributes) - # recording_info.create_dataset("recording_attributes", data=check_json(rec_attributes), object_codec=numcodecs.JSON()) if probegroup is not None: recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) - # recording_info.create_dataset("probegroup", data=check_json(probegroup.to_dict()), object_codec=numcodecs.JSON()) if sparsity is not None: - # zarr_root.attrs["sparsity"] = check_json(sparsity.to_dict()) - # zarr_root.create_dataset("sparsity", data=check_json(sparsity.to_dict()), object_codec=numcodecs.JSON()) zarr_root.create_dataset("sparsity_mask", data=sparsity.mask) # write sorting copy @@ -507,10 +485,8 @@ def load_from_zarr(cls, folder, recording=None): # recording attributes rec_attributes = zarr_root["recording_info"].attrs["recording_attributes"] - # rec_attributes = zarr_root["recording_info"]["recording_attributes"] if "probegroup" in zarr_root["recording_info"].attrs: probegroup_dict = zarr_root["recording_info"].attrs["probegroup"] - # probegroup_dict = zarr_root["recording_info"]["probegroup"] rec_attributes["probegroup"] = probeinterface.ProbeGroup.from_dict(probegroup_dict) else: rec_attributes["probegroup"] = None @@ -522,18 +498,12 @@ def load_from_zarr(cls, folder, recording=None): else: sparsity = None - if "random_spikes_indices" in zarr_root.keys(): - random_spikes_indices = zarr_root["random_spikes_indices"] - else: - random_spikes_indices = None - sorting_analyzer = SortingAnalyzer( sorting=sorting, recording=recording, rec_attributes=rec_attributes, format="zarr", sparsity=sparsity, - random_spikes_indices=random_spikes_indices, ) return sorting_analyzer @@ -590,25 +560,6 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") - # propagate random_spikes_indices is already done - if self.random_spikes_indices is not None: - if unit_ids is None: - new_sorting_analyzer.random_spikes_indices = self.random_spikes_indices.copy() - else: - # more tricky - spikes = self.sorting.to_spike_vector() - - keep_unit_indices = np.flatnonzero(np.isin(self.unit_ids, unit_ids)) - keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - - selected_mask = np.zeros(spikes.size, dtype=bool) - selected_mask[self.random_spikes_indices] = True - - new_sorting_analyzer.random_spikes_indices = np.flatnonzero(selected_mask[keep_spike_mask]) - - # save it - new_sorting_analyzer._save_random_spikes_indices() - # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing for extension_name, extension in self.extensions.items(): @@ -1054,38 +1005,6 @@ def has_extension(self, extension_name: str) -> bool: else: return False - ## random_spikes_selection zone - def select_random_spikes(self, **random_kwargs): - # random_spikes_indices is a vector that refer to the spike vector of the sorting in absolut index - assert self.random_spikes_indices is None, "select random spikes is already computed" - - self.random_spikes_indices = random_spikes_selection( - self.sorting, self.rec_attributes["num_samples"], **random_kwargs - ) - self._save_random_spikes_indices() - - def _save_random_spikes_indices(self): - if self.format == "binary_folder": - np.save(self.folder / "random_spikes_indices.npy", self.random_spikes_indices) - elif self.format == "zarr": - zarr_root = self._get_zarr_root() - zarr_root.create_dataset("random_spikes_indices", data=self.random_spikes_indices) - - def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # usefull for Waveforms extractor backwars compatibility - # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain - assert self.random_spikes_indices is not None, "random spikes selection is not computed" - unit_index = self.sorting.id_to_index(unit_id) - spikes = self.sorting.to_spike_vector() - spike_indices_in_seg = np.flatnonzero( - (spikes["segment_index"] == segment_index) & (spikes["unit_index"] == unit_index) - ) - common_element, inds_left, inds_right = np.intersect1d( - spike_indices_in_seg, self.random_spikes_indices, return_indices=True - ) - selected_spikes_in_spike_train = inds_left - return selected_spikes_in_spike_train - global _possible_extensions _possible_extensions = [] diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 482963ffe1..f94226110f 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -73,13 +73,26 @@ def _check_result_extension(sorting_analyzer, extension_name): # print(k, arr.shape) +@pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) +@pytest.mark.parametrize("sparse", [False, ]) +def test_SelectRandomSpikes(format, sparse): + sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) + + ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205) + indices = ext.data["random_spikes_indices"] + assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size + # print(indices) + + _check_result_extension(sorting_analyzer, "random_spikes") + + @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) @pytest.mark.parametrize("sparse", [True, False]) def test_ComputeWaveforms(format, sparse): sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) - sorting_analyzer.select_random_spikes(max_spikes_per_unit=50, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) ext = sorting_analyzer.compute("waveforms", **job_kwargs) wfs = ext.data["waveforms"] _check_result_extension(sorting_analyzer, "waveforms") @@ -90,7 +103,7 @@ def test_ComputeWaveforms(format, sparse): def test_ComputeTemplates(format, sparse): sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) - sorting_analyzer.select_random_spikes(max_spikes_per_unit=20, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=20, seed=2205) with pytest.raises(AssertionError): # This require "waveforms first and should trig an error @@ -145,14 +158,15 @@ def test_ComputeFastTemplates(format, sparse): ms_before = 1.0 ms_after = 2.5 - sorting_analyzer.select_random_spikes(max_spikes_per_unit=20, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=20, seed=2205) + sorting_analyzer.compute("fast_templates", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs) _check_result_extension(sorting_analyzer, "fast_templates") # compare ComputeTemplates with dense and ComputeFastTemplates: should give the same on "average" other_sorting_analyzer = get_sorting_analyzer(format=format, sparse=False) - other_sorting_analyzer.select_random_spikes(max_spikes_per_unit=20, seed=2205) + other_sorting_analyzer.compute("random_spikes", max_spikes_per_unit=20, seed=2205) other_sorting_analyzer.compute( "waveforms", ms_before=ms_before, ms_after=ms_after, return_scaled=True, **job_kwargs ) @@ -191,18 +205,21 @@ def test_ComputeNoiseLevels(format, sparse): if __name__ == "__main__": - # test_ComputeWaveforms(format="memory", sparse=True) - # test_ComputeWaveforms(format="memory", sparse=False) - # test_ComputeWaveforms(format="binary_folder", sparse=True) - # test_ComputeWaveforms(format="binary_folder", sparse=False) - # test_ComputeWaveforms(format="zarr", sparse=True) - # test_ComputeWaveforms(format="zarr", sparse=False) + + test_SelectRandomSpikes(format="memory", sparse=True) + + test_ComputeWaveforms(format="memory", sparse=True) + test_ComputeWaveforms(format="memory", sparse=False) + test_ComputeWaveforms(format="binary_folder", sparse=True) + test_ComputeWaveforms(format="binary_folder", sparse=False) + test_ComputeWaveforms(format="zarr", sparse=True) + test_ComputeWaveforms(format="zarr", sparse=False) test_ComputeTemplates(format="memory", sparse=True) - # test_ComputeTemplates(format="memory", sparse=False) - # test_ComputeTemplates(format="binary_folder", sparse=True) - # test_ComputeTemplates(format="zarr", sparse=True) + test_ComputeTemplates(format="memory", sparse=False) + test_ComputeTemplates(format="binary_folder", sparse=True) + test_ComputeTemplates(format="zarr", sparse=True) - # test_ComputeFastTemplates(format="memory", sparse=True) + test_ComputeFastTemplates(format="memory", sparse=True) - # test_ComputeNoiseLevels(format="memory", sparse=False) + test_ComputeNoiseLevels(format="memory", sparse=False) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index effd116d44..fc23927be7 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -78,7 +78,7 @@ def test_run_node_pipeline(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("fast_templates") extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 3cd1286afb..03c18c2f43 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -96,11 +96,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting): assert sorting_analyzer.has_recording() - if sorting_analyzer.random_spikes_indices is None: - sorting_analyzer.select_random_spikes(max_spikes_per_unit=10, seed=2205) - assert sorting_analyzer.random_spikes_indices is not None - assert sorting_analyzer.random_spikes_indices.size == 10 * sorting_analyzer.sorting.unit_ids.size - # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -140,11 +135,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting): keep_unit_ids = original_sorting.unit_ids[::2] sorting_analyzer2 = sorting_analyzer.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) - # check that random_spikes_indices are remmaped - assert sorting_analyzer2.random_spikes_indices is not None - some_spikes = sorting_analyzer2.sorting.to_spike_vector()[sorting_analyzer2.random_spikes_indices] - assert np.array_equal(np.unique(some_spikes["unit_index"]), np.arange(keep_unit_ids.size)) - # check propagation of result data and correct sligin assert np.array_equal(keep_unit_ids, sorting_analyzer2.unit_ids) data = sorting_analyzer2.get_extension("dummy").data diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index d650932162..ff92ccbddc 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -200,7 +200,7 @@ def test_compute_sparsity(): recording, sorting = get_dataset() sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("fast_templates", return_scaled=True) sorting_analyzer.compute("noise_levels", return_scaled=True) # this is needed for method="energy" diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index d936674ed5..0ef80d7b08 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -26,7 +26,7 @@ def get_sorting_analyzer(): sorting.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("fast_templates") return sorting_analyzer diff --git a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py index dcf16bb804..d122723d85 100644 --- a/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py @@ -105,5 +105,5 @@ def test_read_old_waveforms_extractor_binary(): if __name__ == "__main__": - # test_extract_waveforms() - test_read_old_waveforms_extractor_binary() + test_extract_waveforms() + # test_read_old_waveforms_extractor_binary() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index afda9b3967..56dc17817b 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -23,7 +23,7 @@ from .sparsity import ChannelSparsity from .sortinganalyzer import SortingAnalyzer, load_sorting_analyzer from .base import load_extractor -from .analyzer_extension_core import ComputeWaveforms, ComputeTemplates +from .analyzer_extension_core import SelectRandomSpikes, ComputeWaveforms, ComputeTemplates _backwards_compatibility_msg = """#### # extract_waveforms() and WaveformExtractor() have been replace by SortingAnalyzer since version 0.101 @@ -96,9 +96,7 @@ def extract_waveforms( sorting, recording, format=format, folder=folder, sparse=sparse, sparsity=sparsity, **sparsity_kwargs ) - # TODO propagate job_kwargs - - sorting_analyzer.select_random_spikes(max_spikes_per_unit=max_spikes_per_unit, seed=seed) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=max_spikes_per_unit, seed=seed) waveforms_params = dict(ms_before=ms_before, ms_after=ms_after, return_scaled=return_scaled, dtype=dtype) sorting_analyzer.compute("waveforms", **waveforms_params, **job_kwargs) @@ -232,7 +230,8 @@ def get_sampled_indices(self, unit_id): # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) with a complex dtype as follow selected_spikes = [] for segment_index in range(self.get_num_segments()): - inds = self.sorting_analyzer.get_selected_indices_in_spike_train(unit_id, segment_index) + # inds = self.sorting_analyzer.get_selected_indices_in_spike_train(unit_id, segment_index) + inds = self.sorting_analyzer.get_extension("random_spikes").get_selected_indices_in_spike_train(unit_id, segment_index) sampled_index = np.zeros(inds.size, dtype=[("spike_index", "int64"), ("segment_index", "int64")]) sampled_index["spike_index"] = inds sampled_index["segment_index"][:] = segment_index @@ -251,8 +250,7 @@ def get_waveforms( # lazy and cache are ingnored ext = self.sorting_analyzer.get_extension("waveforms") unit_index = self.sorting.id_to_index(unit_id) - spikes = self.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() spike_mask = some_spikes["unit_index"] == unit_index wfs = ext.data["waveforms"][spike_mask, :, :] @@ -439,7 +437,9 @@ def _read_old_waveforms_extractor_binary(folder): mask = some_spikes["unit_index"] == unit_index waveforms[:, :, : wfs.shape[2]][mask, :, :] = wfs - sorting_analyzer.random_spikes_indices = random_spikes_indices + ext = SelectRandomSpikes(sorting_analyzer) + ext.params = dict() + ext.data = dict(random_spikes_indices=random_spikes_indices) ext = ComputeWaveforms(sorting_analyzer) ext.params = dict( diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 40a6e28e10..0d561227a4 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -27,7 +27,7 @@ def make_sorting_analyzer(sparse=True): ) sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("noise_levels") diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 4dd62a3178..f8dea5b270 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -36,7 +36,7 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation): job_kwargs = dict(n_jobs=-1) sorting_analyzer = create_sorting_analyzer(sorting_with_split, recording, format="memory") - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") diff --git a/src/spikeinterface/curation/tests/test_remove_redundant.py b/src/spikeinterface/curation/tests/test_remove_redundant.py index 2877442cef..9172979bfa 100644 --- a/src/spikeinterface/curation/tests/test_remove_redundant.py +++ b/src/spikeinterface/curation/tests/test_remove_redundant.py @@ -26,7 +26,7 @@ def test_remove_redundant_units(sorting_analyzer_for_curation): job_kwargs = dict(n_jobs=-1) sorting_analyzer = create_sorting_analyzer(sorting_with_dup, recording, format="memory") - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 5e2d47fb60..5ac82aab86 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -38,7 +38,7 @@ # recording, sorting = read_mearec(local_path) # sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="memory") -# sorting_analyzer.select_random_spikes() +# sorting_analyzer.compute("random_spikes") # sorting_analyzer.compute("waveforms") # sorting_analyzer.compute("templates") # sorting_analyzer.compute("noise_levels") diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index 2b5a813591..800947d033 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -44,7 +44,7 @@ def make_sorting_analyzer(sparse=True, with_group=False): else: sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates") sorting_analyzer.compute("noise_levels") diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index af41f95d87..4b9ab023cb 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -59,6 +59,7 @@ class ComputePrincipalComponents(ResultExtension): extension_name = "principal_components" depend_on = [ + "random_spikes", "waveforms", ] need_recording = False @@ -89,8 +90,7 @@ def _set_params( def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - spikes = self.sorting_analyzer.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices) new_data = dict() @@ -147,8 +147,7 @@ def get_projections_one_unit(self, unit_id, sparse=False): assert self.params["mode"] != "concatenated", "mode concatenated cannot retrieve sparse projection" assert sparsity is not None, "sparse projection need SortingAnalyzer to be sparse" - spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() unit_index = sorting.id_to_index(unit_id) spike_mask = some_spikes["unit_index"] == unit_index @@ -205,8 +204,7 @@ def get_some_projections(self, channel_ids=None, unit_ids=None): sparsity = self.sorting_analyzer.sparsity - spikes = sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() unit_indices = sorting.ids_to_indices(unit_ids) selected_inds = np.flatnonzero(np.isin(some_spikes["unit_index"], unit_indices)) @@ -288,8 +286,7 @@ def _run(self, **job_kwargs): # transform waveforms_ext = self.sorting_analyzer.get_extension("waveforms") some_waveforms = waveforms_ext.data["waveforms"] - spikes = self.sorting_analyzer.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() pca_projection = self._transform_waveforms(some_spikes, some_waveforms, pca_model, progress_bar) @@ -541,8 +538,7 @@ def _get_sparse_waveforms(self, unit_id): waveforms_ext = self.sorting_analyzer.get_extension("waveforms") some_waveforms = waveforms_ext.data["waveforms"] - spikes = self.sorting_analyzer.sorting.to_spike_vector() - some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() return self._get_slice_waveforms(unit_id, some_spikes, some_waveforms) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index a24e962e56..f7ab30bfec 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -91,7 +91,7 @@ def _prepare_sorting_analyzer(self, format, sparse): sorting_analyzer = get_sorting_analyzer( self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name ) - sorting_analyzer.select_random_spikes(max_spikes_per_unit=50, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) for dependency_name in self.extension_class.depend_on: if "|" in dependency_name: dependency_name = dependency_name.split("|")[0] diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index c7e9942f2d..c4f378c295 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -62,17 +62,19 @@ def test_get_projections(self): some_unit_ids = sorting_analyzer.unit_ids[::2] some_channel_ids = sorting_analyzer.channel_ids[::2] + random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() + # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == sorting_analyzer.random_spikes_indices.size + assert spike_unit_index.shape[0] == random_spikes_indices.size assert some_projections.shape[1] == n_components assert some_projections.shape[2] == num_chans # this should be some spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < sorting_analyzer.random_spikes_indices.size + assert spike_unit_index.shape[0] < random_spikes_indices.size assert some_projections.shape[1] == n_components assert some_projections.shape[2] == num_chans assert 1 not in spike_unit_index @@ -82,7 +84,7 @@ def test_get_projections(self): channel_ids=some_channel_ids, unit_ids=some_unit_ids ) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < sorting_analyzer.random_spikes_indices.size + assert spike_unit_index.shape[0] < random_spikes_indices.size assert some_projections.shape[1] == n_components assert some_projections.shape[2] == some_channel_ids.size assert 1 not in spike_unit_index diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index b8fc608d2e..746c45da08 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -21,7 +21,7 @@ def test_check_equal_template_with_distribution_overlap(): recording, sorting = get_dataset() sorting_analyzer = get_sorting_analyzer(recording, sorting, sparsity=None) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates") diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 3f3cec54fe..c97223dd70 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -62,7 +62,7 @@ def _sorting_analyzer_simple(): sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - sorting_analyzer.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) sorting_analyzer.compute("noise_levels") sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 6aa0ba73d6..526f506154 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -39,7 +39,7 @@ def _sorting_analyzer_simple(): sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - sorting_analyzer.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) sorting_analyzer.compute("noise_levels") sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates", operators=["average", "std", "median"]) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4f83bc8986..8e1be24753 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -53,7 +53,7 @@ def get_sorting_analyzer(seed=2205): sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - sorting_analyzer.select_random_spikes(max_spikes_per_unit=300, seed=seed) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed) sorting_analyzer.compute("noise_levels") sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") @@ -146,7 +146,7 @@ def test_empty_units(sorting_analyzer_simple): assert len(sorting_empty.get_empty_unit_ids()) == 3 sorting_analyzer_empty = create_sorting_analyzer(sorting_empty, sorting_analyzer.recording, format="memory") - sorting_analyzer_empty.select_random_spikes(max_spikes_per_unit=300, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) sorting_analyzer_empty.compute("noise_levels") sorting_analyzer_empty.compute("waveforms", **job_kwargs) sorting_analyzer_empty.compute("templates") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 4a418037b3..96b83c1fe4 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -36,7 +36,7 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.indices = indices sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") ext = sorting_analyzer.compute('fast_templates') extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") @@ -76,12 +76,12 @@ def compute_result(self, **result_params): exhaustive_gt=self.exhaustive_gt) sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") ext = sorting_analyzer.compute('fast_templates') self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") ext = sorting_analyzer.compute('fast_templates') self.result['clustering_templates'] = ext.get_data(outputs="Templates") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 6f75a57a78..61e1ce2098 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -37,7 +37,7 @@ def __init__(self, recording, gt_sorting, params, gt_positions): def run(self, **job_kwargs): sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") ext = sorting_analyzer.compute('fast_templates', **self.templates_params) templates = ext.get_data(outputs='Templates') ext = sorting_analyzer.compute("spike_locations", **self.params) @@ -181,8 +181,7 @@ def __init__(self, recording, gt_sorting, params, gt_positions): def run(self, **job_kwargs): sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory') - sorting_analyzer.select_random_spikes() - + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute('waveforms', **self.waveforms_params, **job_kwargs) ext = sorting_analyzer.compute('templates') templates = ext.get_data(outputs='Templates') diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 0e9f2b9052..1f97f0a0c6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -40,7 +40,7 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.indices = indices sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") ext = sorting_analyzer.compute('fast_templates') extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") @@ -80,12 +80,12 @@ def compute_result(self, **result_params): exhaustive_gt=self.exhaustive_gt) sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") ext = sorting_analyzer.compute('fast_templates') self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") ext = sorting_analyzer.compute('fast_templates') self.result['clustering_templates'] = ext.get_data(outputs="Templates") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index e92c6af724..8a239453e5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -214,12 +214,10 @@ def compute_results(self, case_keys=None, verbose=False, **result_params): benchmark.compute_result(**result_params) benchmark.save_result(self.folder / "results" / self.key_to_str(key)) - def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, **kwargs): + def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, random_params={}, **job_kwargs): if case_keys is None: case_keys = self.cases.keys() - select_params, job_kwargs = split_job_kwargs(kwargs) - base_folder = self.folder / "sorting_analyzer" base_folder.mkdir(exist_ok=True) @@ -230,7 +228,7 @@ def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, **kwarg folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) - sorting_analyzer.select_random_spikes(**select_params) + sorting_analyzer.compute("random_spikes", **random_params) sorting_analyzer.compute("waveforms", return_scaled=return_scaled, **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("noise_levels", return_scaled=return_scaled) diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index a73ce93b4c..0e065c5f3f 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -15,7 +15,7 @@ def get_sorting_analyzer(): recording, sorting = make_dataset() sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=False) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("fast_templates", **job_kwargs) sorting_analyzer.compute("noise_levels") return sorting_analyzer diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index fb9d1010f8..029a7f44b0 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -141,7 +141,7 @@ def fit( # TODO alessio, herberto : the fitting is done with a SortingAnalyzer which is a postprocessing object, I think we should not do this for a component sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True) - sorting_analyzer.select_random_spikes() + sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", ms_before=ms_before, ms_after=ms_after) sorting_analyzer.compute( "principal_components", n_components=n_components, mode="by_channel_global", whiten=whiten diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 141f73e881..8b5a67c796 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -77,7 +77,7 @@ def setUpClass(cls): # create dense cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) - cls.sorting_analyzer_dense.select_random_spikes() + cls.sorting_analyzer_dense.compute("random_spikes") cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) sw.set_default_plotter_backend("matplotlib") @@ -92,7 +92,7 @@ def setUpClass(cls): cls.sorting_analyzer_sparse = create_sorting_analyzer( cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius ) - cls.sorting_analyzer_sparse.select_random_spikes() + cls.sorting_analyzer_sparse.compute("random_spikes") cls.sorting_analyzer_sparse.compute(extensions_to_compute, **job_kwargs) cls.skip_backends = ["ipywidgets", "ephyviewer"] From 868e27d018cc738ec926b59d3db871b89b973762 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 12:47:47 +0100 Subject: [PATCH 135/192] estimate_sparsity() also handle probegroup case. --- src/spikeinterface/core/sparsity.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 48c4cdf1e0..ca7cf59380 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -2,9 +2,9 @@ import numpy as np + from .basesorting import BaseSorting from .baserecording import BaseRecording -from .recording_tools import get_noise_levels from .sorting_tools import random_spikes_selection from .job_tools import _shared_job_kwargs_doc from .waveform_tools import estimate_templates_average @@ -593,6 +593,15 @@ def estimate_sparsity( len(recording.get_probes()) == 1 ), "The 'radius' method of `estimate_sparsity()` can handle only one probe" + if recording.get_probes() == 1: + # standard case + probe = recording.get_probe() + else: + # if many probe or no probe then we use channel location and create a dummy probe with all channels + # note that get_channel_locations() is checking that channel are not spatialy overlapping so the radius method is OK. + chan_locs = recording.get_channel_locations() + probe = recording.create_dummy_probe_from_locations(chan_locs) + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) @@ -625,7 +634,7 @@ def estimate_sparsity( sparsity_mask=None, channel_ids=recording.channel_ids, unit_ids=sorting.unit_ids, - probe=recording.get_probe(), + probe=probe, ) sparsity = compute_sparsity( From 898440564f494044cb7e99526e4c939ec16835f0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 13:02:47 +0100 Subject: [PATCH 136/192] Feedback from Zach on SortingAnalyzer.compute(list) --- src/spikeinterface/core/sortinganalyzer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a743310327..9ad3bb6869 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -747,6 +747,14 @@ def compute(self, input, save=True, **kwargs): params_, job_kwargs = split_job_kwargs(kwargs) assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()" self.compute_several_extensions(extensions=input, save=save, **job_kwargs) + elif isinstance(input, list): + params_, job_kwargs = split_job_kwargs(kwargs) + assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()" + extensions = {k : {} for k in input} + self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs) + else: + raise ValueError("SortingAnalyzer.compute() need str, dict or list") + def compute_one_extension(self, extension_name, save=True, **kwargs): """ From 393e9278fe4a544cdce8c3d63df07b90b7cb3033 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 13:14:15 +0100 Subject: [PATCH 137/192] some clean --- .../comparison/tests/test_templatecomparison.py | 3 +-- src/spikeinterface/core/sortinganalyzer.py | 8 +++++++- src/spikeinterface/core/tests/test_node_pipeline.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index d361e3ed36..595820b00b 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -46,8 +46,7 @@ def test_compare_multiple_templates(): sorting_analyzer_3 = create_sorting_analyzer(sort3, rec3, format="memory") for sorting_analyzer in (sorting_analyzer_1, sorting_analyzer_2, sorting_analyzer_3): - sorting_analyzer.compute("random_spikes") - sorting_analyzer.compute("fast_templates") + sorting_analyzer.compute(["random_spikes", "fast_templates"]) # paired comparison temp_cmp = compare_templates(sorting_analyzer_1, sorting_analyzer_2) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 9ad3bb6869..a789729ced 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -20,7 +20,6 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes -# from .sorting_tools import random_spikes_selection from .core_tools import check_json from .job_tools import split_job_kwargs from .numpyextractors import SharedMemorySorting @@ -88,6 +87,13 @@ def create_sorting_analyzer( >>> # Can make a copy with a subset of units (extensions are propagated for the unit subset) >>> sorting_analyzer4 = sorting_analyzer.select_units(unit_ids=sorting.units_ids[:5], format="memory") >>> sorting_analyzer5 = sorting_analyzer.select_units(unit_ids=sorting.units_ids[:5], format="binary_folder", folder="/result_5units") + + Notes + ----- + + By default creating a SortingAnalyzer can be slow because the sparsity is estimated by default. + In some situation, sparsity is not needed, so to make it fast creation, you need to turn + sparsity off (or give external sparsity) like this. """ # handle sparsity diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index fc23927be7..a30f1d273c 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -79,7 +79,7 @@ def test_run_node_pipeline(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute("random_spikes") - sorting_analyzer.compute("fast_templates") + sorting_analyzer.compute("fast_templates", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) From 9733d895205334aea83ab9231e8f3f41263e697a Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 26 Feb 2024 14:38:14 +0100 Subject: [PATCH 138/192] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 52ed56b52e..d35fa63104 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -29,13 +29,13 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 5}, + "sparsity": {"method": "ptp", "threshold": 0.25}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { "method": "smart_sampling_amplitudes", "n_peaks_per_channel": 5000, - "min_n_peaks": 20000, + "min_n_peaks": 50000, "select_per_channel": False, }, "clustering": {"legacy": False}, From c957b77f58599ae27ecd5d94ebf44e3c06d047fd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 16:10:13 +0100 Subject: [PATCH 139/192] oups --- src/spikeinterface/core/sparsity.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index ca7cf59380..5d69464569 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -588,10 +588,6 @@ def estimate_sparsity( from .template import Templates assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channel'" - if method == "radius": - assert ( - len(recording.get_probes()) == 1 - ), "The 'radius' method of `estimate_sparsity()` can handle only one probe" if recording.get_probes() == 1: # standard case From abff9b2dfbfdf5660c764123866cbaac3fd470ad Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 16:11:31 +0100 Subject: [PATCH 140/192] oups --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a789729ced..e096d2bc8a 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -104,7 +104,7 @@ def create_sorting_analyzer( sorting.unit_ids, sparsity.unit_ids ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" assert np.array_equal( - recording.channel_ids, recording.channel_ids + recording.channel_ids, sparsity.channel_ids ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" elif sparse: sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs) From af9c0c61b7c6440a6aae74e4ff15dbcfac8feefc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 16:39:18 +0100 Subject: [PATCH 141/192] Run spikeinterface_gui using plot_sorting_summary(analyzer, backend="spikeinterface_gui") --- src/spikeinterface/widgets/base.py | 2 ++ src/spikeinterface/widgets/sorting_summary.py | 17 +++++++++++++++-- .../widgets/tests/test_widgets.py | 10 +++++----- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 3bfacfc370..7440c240ce 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -46,6 +46,7 @@ def set_default_plotter_backend(backend): # "controllers": "" }, "ephyviewer": {}, + "spikeinterface_gui": {}, } default_backend_kwargs = { @@ -53,6 +54,7 @@ def set_default_plotter_backend(backend): "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None}, "ephyviewer": {}, + "spikeinterface_gui": {}, } diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 78293757ec..c4fce76dad 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -16,7 +16,9 @@ class SortingSummaryWidget(BaseWidget): """ - Plots spike sorting summary + Plots spike sorting summary. + This is the main viewer to visualize the final result with several sub view. + This use sortingview (in a web browser) or spikeinterface-gui (with Qt). Parameters ---------- @@ -59,7 +61,7 @@ def __init__( **backend_kwargs, ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - self.check_extensions(sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "similarity"]) + self.check_extensions(sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"]) sorting = sorting_analyzer.sorting if unit_ids is None: @@ -177,3 +179,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) self.url = handle_display_and_url(self, self.view, **backend_kwargs) + + def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): + sorting_analyzer = data_plot["sorting_analyzer"] + + + import spikeinterface_gui + app = spikeinterface_gui.mkQApp() + win = spikeinterface_gui.MainWindow(sorting_analyzer) + win.show() + app.exec_() + diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 8b5a67c796..7f761190f4 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -95,7 +95,7 @@ def setUpClass(cls): cls.sorting_analyzer_sparse.compute("random_spikes") cls.sorting_analyzer_sparse.compute(extensions_to_compute, **job_kwargs) - cls.skip_backends = ["ipywidgets", "ephyviewer"] + cls.skip_backends = ["ipywidgets", "ephyviewer", "spikeinterface_gui"] # cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] if ON_GITHUB and not KACHERY_CLOUD_SET: @@ -103,7 +103,7 @@ def setUpClass(cls): print(f"Widgets tests: skipping backends - {cls.skip_backends}") - cls.backend_kwargs = {"matplotlib": {}, "sortingview": {}, "ipywidgets": {"display": False}} + cls.backend_kwargs = {"matplotlib": {}, "sortingview": {}, "ipywidgets": {"display": False}, "spikeinterface_gui": {}} cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) @@ -573,7 +573,7 @@ def test_plot_multicomparison(self): TestWidgets.setUpClass() mytest = TestWidgets() - mytest.test_plot_unit_waveforms_density_map() + # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() # mytest.test_plot_all_amplitudes_distributions() # mytest.test_plot_traces() @@ -598,7 +598,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() - # mytest.test_plot_sorting_summary() - plt.show() + mytest.test_plot_sorting_summary() + # plt.show() # TestWidgets.tearDownClass() From 57608e3d26a7bd1c2405d5ffaf10aedfbb48b7f3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 16:52:58 +0100 Subject: [PATCH 142/192] example rename --- .../{plot_4_waveform_extractor.py => plot_4_sorting_analyzer.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/modules_gallery/core/{plot_4_waveform_extractor.py => plot_4_sorting_analyzer.py} (100%) diff --git a/examples/modules_gallery/core/plot_4_waveform_extractor.py b/examples/modules_gallery/core/plot_4_sorting_analyzer.py similarity index 100% rename from examples/modules_gallery/core/plot_4_waveform_extractor.py rename to examples/modules_gallery/core/plot_4_sorting_analyzer.py From b385693bf516c994a31ded48a7a9baf53f437169 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 19:24:11 +0100 Subject: [PATCH 143/192] wip examples --- .../core/plot_4_sorting_analyzer.py | 236 ++++++------------ src/spikeinterface/core/sortinganalyzer.py | 4 +- 2 files changed, 85 insertions(+), 155 deletions(-) diff --git a/examples/modules_gallery/core/plot_4_sorting_analyzer.py b/examples/modules_gallery/core/plot_4_sorting_analyzer.py index bee8f4061b..593b1103e1 100644 --- a/examples/modules_gallery/core/plot_4_sorting_analyzer.py +++ b/examples/modules_gallery/core/plot_4_sorting_analyzer.py @@ -1,23 +1,33 @@ ''' -Waveform Extractor -================== +SortingAnalyzer +=============== -SpikeInterface provides an efficient mechanism to extract waveform snippets. +SpikeInterface provides an object to gather a Recording and a Sorting to make +analyzer and visualization of the sorting : :py:class:`~spikeinterface.core.SortingAnalyzer`. -The :py:class:`~spikeinterface.core.WaveformExtractor` class: +This :py:class:`~spikeinterface.core.SortingAnalyzer` class: + + * is the first step for all post post processing, quality metrics, and visualization. + * gather a recording and a sorting + * can be sparse or dense : all channel are used for all units or not. + * handle a list of "extensions" + * "core extensions" are the one to extract some waveforms to compute templates: + * "random_spikes" : select randomly a subset of spikes per unit + * "waveforms" : extract waveforms per unit + * "templates": compute template using average or median + * "noise_levels" : compute noise level from traces (usefull to get snr of units) + * can be in memory or persistent to disk (2 formats binary/npy or zarr) + +More extesions are available in `spikeinterface.postprocessing` like "principal_components", "spike_amplitudes", +"unit_lcations", ... - * randomly samples a subset spikes with max_spikes_per_unit - * extracts all waveforms snippets for each unit - * saves waveforms in a local folder - * can load stored waveforms - * retrieves template (average or median waveform) for each unit Here the how! ''' import matplotlib.pyplot as plt from spikeinterface import download_dataset -from spikeinterface import WaveformExtractor, extract_waveforms +from spikeinterface import create_sorting_analyzer, load_sorting_analyzer import spikeinterface.extractors as se ############################################################################## @@ -48,184 +58,102 @@ plot_probe(probe) ############################################################################### -# A :py:class:`~spikeinterface.core.WaveformExtractor` object can be created with the -# :py:func:`~spikeinterface.core.extract_waveforms` function (this defaults to a sparse -# representation of the waveforms): - -folder = 'waveform_folder' -we = extract_waveforms( - recording, - sorting, - folder, - ms_before=1.5, - ms_after=2., - max_spikes_per_unit=500, - overwrite=True -) -print(we) +# A :py:class:`~spikeinterface.core.SortingAnalyzer` object can be created with the +# :py:func:`~spikeinterface.core.create_sorting_analyzer` function (this defaults to a sparse +# representation of the waveforms) +# Here the format is "memory". + +analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") +print(analyzer) ############################################################################### -# Alternatively, the :py:class:`~spikeinterface.core.WaveformExtractor` object can be instantiated -# directly. In this case, we need to :py:func:`~spikeinterface.core.WaveformExtractor.set_params` to set the desired -# parameters: +# A :py:class:`~spikeinterface.core.SortingAnalyzer` object can be persistane to disk +# when using format="binary_folder" or format="zarr" -folder = 'waveform_folder2' -we = WaveformExtractor.create(recording, sorting, folder, remove_if_exists=True) -we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=1000) -we.run_extract_waveforms(n_jobs=1, chunk_size=30000, progress_bar=True) -print(we) +folder = "analyzer_folder" +analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="binary_folder", folder=folder) +print(analyzer) +# then it can be load back +analyzer = load_sorting_analyzer(folder) +print(analyzer) ############################################################################### -# To speed up computation, waveforms can also be extracted using parallel -# processing (recommended!). We can define some :code:`'job_kwargs'` to pass -# to the function as extra arguments: - -job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) - -folder = 'waveform_folder_parallel' -we = extract_waveforms( - recording, - sorting, - folder, - sparse=False, - ms_before=3., - ms_after=4., - max_spikes_per_unit=500, - overwrite=True, - **job_kwargs -) -print(we) +# No extension are computed yet. +# Lets compute the most basic ones : select some random spikes per units, +# extract waveforms (sparse in this examples) and compute templates. +# You can see that printing the object indicate which extension are computed yet. +analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500,) +analyzer.compute("waveforms", ms_before=1.0, ms_after=2.0, return_scaled=True) +analyzer.compute("templates", operators=["average", "median", "std"]) +print(analyzer) -############################################################################### -# The :code:`'waveform_folder'` folder contains: -# * the dumped recording (json) -# * the dumped sorting (json) -# * the parameters (json) -# * a subfolder with "waveforms_XXX.npy" and "sampled_index_XXX.npy" - -import os -print(os.listdir(folder)) -print(os.listdir(folder + '/waveforms')) ############################################################################### -# Now we can retrieve waveforms per unit on-the-fly. The waveforms shape -# is (num_spikes, num_sample, num_channel): - -unit_ids = sorting.unit_ids +# To speed up computation, some steps like ""waveforms" can also be extracted +# using parallel processing (recommended!). Like this -for unit_id in unit_ids: - wfs = we.get_waveforms(unit_id) - print(unit_id, ':', wfs.shape) - -############################################################################### -# We can also get the template for each units either using the median or the -# average: +analyzer.compute("waveforms", ms_before=1.0, ms_after=2.0, return_scaled=True, + n_jobs=8, chunk_duration="1s", progress_bar=True) -for unit_id in unit_ids[:3]: - fig, ax = plt.subplots() - template = we.get_template(unit_id=unit_id, mode='median') - print(template.shape) - ax.plot(template) - ax.set_title(f'{unit_id}') +# which is equivalent of this +job_kwargs = dict(n_jobs=8, chunk_duration="1s", progress_bar=True) +analyzer.compute("waveforms", ms_before=1.0, ms_after=2.0, return_scaled=True, **job_kwargs) ############################################################################### -# Or retrieve templates for all units at once: +# Each extension can retrieve some data +# For instance "waveforms" extension can retrieve wavfroms per units +# which is a numpy array of shape (num_spikes, num_sample, num_channel): -all_templates = we.get_all_templates() -print(all_templates.shape) - - -''' -Sparse Waveform Extractor -------------------------- +ext_wf = analyzer.get_extension("waveforms") +for unit_id in analyzer.unit_ids: + wfs = ext_wf.get_waveforms_one_unit(unit_id) + print(unit_id, ':', wfs.shape) -''' ############################################################################### -# For high-density probes, such as Neuropixels, we may want to work with sparse -# waveforms, i.e., waveforms computed on a subset of channels. To do so, we -# two options. -# -# Option 1) Save a dense waveform extractor to sparse: -# -# In this case, from an existing (dense) waveform extractor, we can first estimate a -# sparsity (which channels each unit is defined on) and then save to a new -# folder in sparse mode: - -from spikeinterface import compute_sparsity +# Same for the "templates" extension. Here we can get all templates at once +# with shape (num_units, num_sample, num_channel): +# For this extension, we can get the template for all units either using the median +# or the average -# define sparsity within a radius of 40um -sparsity = compute_sparsity(we, method="radius", radius_um=40) -print(sparsity) +ext_templates = analyzer.get_extension("templates") -# save sparse waveforms -folder = 'waveform_folder_sparse' -we_sparse = we.save(folder=folder, sparsity=sparsity, overwrite=True) +av_templates = ext_templates.get_data(operator="average") +print(av_templates.shape) -# we_sparse is a sparse WaveformExtractor -print(we_sparse) +median_templates = ext_templates.get_data(operator="median") +print(median_templates.shape) -wf_full = we.get_waveforms(we.sorting.unit_ids[0]) -print(f"Dense waveforms shape for unit {we.sorting.unit_ids[0]}: {wf_full.shape}") -wf_sparse = we_sparse.get_waveforms(we.sorting.unit_ids[0]) -print(f"Sparse waveforms shape for unit {we.sorting.unit_ids[0]}: {wf_sparse.shape}") ############################################################################### -# Option 2) Directly extract sparse waveforms (current spikeinterface default): -# -# We can also directly extract sparse waveforms. To do so, dense waveforms are -# extracted first using a small number of spikes (:code:`'num_spikes_for_sparsity'`) - -folder = 'waveform_folder_sparse_direct' -we_sparse_direct = extract_waveforms( - recording, - sorting, - folder, - ms_before=3., - ms_after=4., - max_spikes_per_unit=500, - overwrite=True, - sparse=True, - num_spikes_for_sparsity=100, - method="radius", - radius_um=40, - **job_kwargs -) -print(we_sparse_direct) - -template_full = we.get_template(we.sorting.unit_ids[0]) -print(f"Dense template shape for unit {we.sorting.unit_ids[0]}: {template_full.shape}") -template_sparse = we_sparse_direct.get_template(we.sorting.unit_ids[0]) -print(f"Sparse template shape for unit {we.sorting.unit_ids[0]}: {template_sparse.shape}") +# This can be plot easily. + +for unit_index, unit_id in enumerate(analyzer.unit_ids[:3]): + fig, ax = plt.subplots() + template = av_templates[unit_index] + ax.plot(template) + ax.set_title(f'{unit_id}') ############################################################################### -# As shown above, when retrieving waveforms/template for a unit from a sparse -# :code:`'WaveformExtractor'`, the waveforms are returned on a subset of channels. -# To retrieve which channels each unit is associated with, we can use the sparsity -# object: +# The SortingAnalyzer can be saved as to another format using save_as() +# So the computation can be done with format="memory" and -# retrive channel ids for first unit: -unit_ids = we_sparse.unit_ids -channel_ids_0 = we_sparse.sparsity.unit_id_to_channel_ids[unit_ids[0]] -print(f"Channel ids associated to {unit_ids[0]}: {channel_ids_0}") +analyzer.save_as(folder="analyzer.zarr", format="zarr") ############################################################################### -# However, when retrieving all templates, a dense shape is returned. This is -# because different channels might have a different number of sparse channels! -# In this case, values on channels not belonging to a unit are filled with 0s. +# The SortingAnalyzer offer also select_units() method wich allows to export +# only some relevant units for instance to a new SortingAnalyzer instance. + +analyzer_some_units = analyzer.select_units(unit_ids=analyzer.unit_ids[:5], + format="binary_folder", folder="analyzer_some_units") +print(analyzer_some_units) -all_sparse_templates = we_sparse.get_all_templates() -# this is a boolean mask with sparse channels for the 1st unit -mask0 = we_sparse.sparsity.mask[0] -# Let's plot values for the first 5 samples inside and outside sparsity mask -print("Values inside sparsity:\n", all_sparse_templates[0, :5, mask0]) -print("Values outside sparsity:\n", all_sparse_templates[0, :5, ~mask0]) plt.show() diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index e096d2bc8a..02d2f0e1d5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -409,12 +409,14 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # the recording rec_dict = recording.to_dict(relative_to=folder, recursive=True) - zarr_rec = np.array([rec_dict], dtype=object) + if recording.check_serializability("json"): # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) + zarr_rec = np.array([check_json(rec_dict)], dtype=object) zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) elif recording.check_serializability("pickle"): # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) + zarr_rec = np.array([rec_dict], dtype=object) zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) else: warnings.warn( From 9d60a1b0289bb090b85810a949c7e332b0495cd1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 20:36:14 +0100 Subject: [PATCH 144/192] wip analyzer in examples --- .../qualitymetrics/plot_3_quality_mertics.py | 58 ++++++++++--------- .../qualitymetrics/plot_4_curation.py | 37 +++++++----- src/spikeinterface/core/sortinganalyzer.py | 2 + 3 files changed, 53 insertions(+), 44 deletions(-) diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 7b2fa565b5..557fff229b 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -23,49 +23,51 @@ print(sorting) ############################################################################## -# Extract spike waveforms +# Create SortingAnalyzer # ----------------------- # -# For convenience, metrics are computed on the :code:`WaveformExtractor` object, -# because it contains a reference to the "Recording" and the "Sorting" objects: - -we = si.extract_waveforms(recording=recording, - sorting=sorting, - folder='waveforms_mearec', - sparse=False, - ms_before=1, - ms_after=2., - max_spikes_per_unit=500, - n_jobs=1, - chunk_durations='1s') -print(we) +# For quality metrics we need first to create a :code:`SortingAnalyzer`. + +analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") +print(analyzer) + +############################################################################## +# Depending on which metrics we want to compute we will need first to compute +# some necessary extensions. (if not computed an error message will be raised) + +analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205) +analyzer.compute("waveforms",ms_before=1.3, ms_after=2.6, n_jobs=2) +analyzer.compute("templates", operators=["average", "median", "std"]) +analyzer.compute("noise_levels") + +print(analyzer) + ############################################################################## # The :code:`spikeinterface.qualitymetrics` submodule has a set of functions that allow users to compute # metrics in a compact and easy way. To compute a single metric, one can simply run one of the # quality metric functions as shown below. Each function has a variety of adjustable parameters that can be tuned. -firing_rates = compute_firing_rates(we) +firing_rates = compute_firing_rates(analyzer) print(firing_rates) -isi_violation_ratio, isi_violations_count = compute_isi_violations(we) +isi_violation_ratio, isi_violations_count = compute_isi_violations(analyzer) print(isi_violation_ratio) -snrs = compute_snrs(we) +snrs = compute_snrs(analyzer) print(snrs) -############################################################################## -# Some metrics are based on the principal component scores, so they require a -# :code:`WaveformsPrincipalComponent` object as input: - -pc = compute_principal_components(waveform_extractor=we, load_if_exists=True, - n_components=3, mode='by_channel_local') -print(pc) - -pc_metrics = calculate_pc_metrics(pc, metric_names=['nearest_neighbor']) -print(pc_metrics) ############################################################################## # To compute more than one metric at once, we can use the :code:`compute_quality_metrics` function and indicate # which metrics we want to compute. This will return a pandas dataframe: -metrics = compute_quality_metrics(we) +metrics = compute_quality_metrics(analyzer, metric_names=["firing_rate", "snr", "amplitude_cutoff"]) +print(metrics) + +############################################################################## +# Some metrics are based on the principal component scores, so the exwtension +# need to be computed before. For instance: + +analyzer.compute("principal_components", n_components=3, mode="by_channel_global", whiten=True) + +metrics = compute_quality_metrics(analyzer, metric_names=["isolation_distance", "d_prime",]) print(metrics) diff --git a/examples/modules_gallery/qualitymetrics/plot_4_curation.py b/examples/modules_gallery/qualitymetrics/plot_4_curation.py index 2568452de3..da379f0789 100644 --- a/examples/modules_gallery/qualitymetrics/plot_4_curation.py +++ b/examples/modules_gallery/qualitymetrics/plot_4_curation.py @@ -28,26 +28,24 @@ print(sorting) ############################################################################## -# First, we extract waveforms (to be saved in the folder 'wfs_mearec') and -# compute their PC (principal component) scores: +# Create SortingAnalyzer +# ----------------------- +# +# For this example, we will need a :code:`SortingAnalyzer` and some extension +# to be computed fist + -we = si.extract_waveforms(recording=recording, - sorting=sorting, - folder='wfs_mearec', - ms_before=1, - ms_after=2., - max_spikes_per_unit=500, - n_jobs=1, - chunk_size=30000) -print(we) +analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") +analyzer.compute(["random_spikes", "waveforms", "templates", "noise_levels"]) -pc = compute_principal_components(we, load_if_exists=True, n_components=3, mode='by_channel_local') +analyzer.compute("principal_components", n_components=3, mode='by_channel_local') +print(analyzer) ############################################################################## # Then we compute some quality metrics: -metrics = compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'nearest_neighbor']) +metrics = compute_quality_metrics(analyzer, metric_names=['snr', 'isi_violation', 'nearest_neighbor']) print(metrics) ############################################################################## @@ -65,10 +63,17 @@ print(keep_unit_ids) ############################################################################## -# And now let's create a sorting that contains only curated units and save it, -# for example to an NPZ file. +# And now let's create a sorting that contains only curated units and save it. curated_sorting = sorting.select_units(keep_unit_ids) print(curated_sorting) -se.NpzSortingExtractor.write_sorting(sorting=curated_sorting, save_path='curated_sorting.npz') + +curated_sorting.save(folder='curated_sorting') + +############################################################################## +# We can also save the analyzer with only theses units + +clean_analyzer = analyzer.select_units(unit_ids=keep_unit_ids, format="zarr", folder="clean_analyzer") + +print(clean_analyzer) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 02d2f0e1d5..9af87c8b48 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -562,6 +562,8 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" folder = Path(folder) + if folder.suffix != ".zarr": + folder = folder.parent / f"{folder.stem}.zarr" SortingAnalyzer.create_zarr(folder, sorting_provenance, recording, sparsity, self.rec_attributes) new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) new_sorting_analyzer.folder = folder From 31c0ffac06beeb973bc55e440f649fb6a7aafcf4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 26 Feb 2024 21:38:01 +0100 Subject: [PATCH 145/192] ResultExtension > AnalyzerExtension --- doc/api.rst | 7 ++-- .../widgets/plot_3_waveforms_gallery.py | 35 ++++++++----------- src/spikeinterface/core/__init__.py | 4 +-- .../core/analyzer_extension_core.py | 24 ++++++------- src/spikeinterface/core/sortinganalyzer.py | 20 +++++------ .../core/tests/test_sortinganalyzer.py | 16 ++++----- .../postprocessing/amplitude_scalings.py | 6 ++-- .../postprocessing/correlograms.py | 6 ++-- src/spikeinterface/postprocessing/isi.py | 6 ++-- .../postprocessing/principal_component.py | 6 ++-- .../postprocessing/spike_amplitudes.py | 8 ++--- .../postprocessing/spike_locations.py | 6 ++-- .../postprocessing/template_metrics.py | 4 +-- .../postprocessing/template_similarity.py | 6 ++-- .../tests/common_extension_tests.py | 2 +- .../tests/test_amplitude_scalings.py | 4 +-- .../postprocessing/tests/test_correlograms.py | 4 +-- .../postprocessing/tests/test_isi.py | 4 +-- .../tests/test_principal_component.py | 4 +-- .../tests/test_spike_amplitudes.py | 4 +-- .../tests/test_spike_locations.py | 4 +-- .../tests/test_template_metrics.py | 4 +-- .../tests/test_template_similarity.py | 4 +-- .../tests/test_unit_localization.py | 4 +-- .../postprocessing/unit_localization.py | 6 ++-- .../quality_metric_calculator.py | 4 +-- 26 files changed, 99 insertions(+), 103 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 62ce3f889f..a7476cd62f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -14,11 +14,12 @@ spikeinterface.core :members: .. autoclass:: BaseEvent :members: - .. autoclass:: WaveformExtractor + .. autoclass:: SortingAnalyzer :members: - .. autofunction:: extract_waveforms - .. autofunction:: load_waveforms + .. autofunction:: create_sorting_analyzer + .. autofunction:: load_sorting_analyzer .. autofunction:: compute_sparsity + .. autofunction:: estimate_sparsity .. autoclass:: ChannelSparsity :members: .. autoclass:: BinaryRecordingExtractor diff --git a/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py b/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py index 1bc4d0afd7..a02b62bbd8 100644 --- a/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py +++ b/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py @@ -16,8 +16,7 @@ # from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') -recording = se.MEArecRecordingExtractor(local_path) -sorting = se.MEArecSortingExtractor(local_path) +recording, sorting = si.read_mearec(local_path) print(recording) print(sorting) @@ -28,17 +27,13 @@ # For convenience, metrics are computed on the WaveformExtractor object that gather recording/sorting and # extracted waveforms in a single object -folder = 'waveforms_mearec' -we = si.extract_waveforms(recording, sorting, folder, - load_if_exists=True, - ms_before=1, ms_after=2., max_spikes_per_unit=500, - n_jobs=1, chunk_size=30000) -# pre-compute postprocessing data -_ = spost.compute_spike_amplitudes(we) -_ = spost.compute_unit_locations(we) -_ = spost.compute_spike_locations(we) -_ = spost.compute_template_metrics(we) +analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") +# core extensions +analyzer.compute(["random_spikes", "waveforms", "templates", "noise_levels"]) + +# more extensions +analyzer.compute(["spike_amplitudes", "unit_locations", "spike_locations", "template_metrics"]) ############################################################################## @@ -47,7 +42,7 @@ unit_ids = sorting.unit_ids[:4] -sw.plot_unit_waveforms(we, unit_ids=unit_ids, figsize=(16,4)) +sw.plot_unit_waveforms(analyzer, unit_ids=unit_ids, figsize=(16,4)) ############################################################################## # plot_unit_templates() @@ -55,21 +50,21 @@ unit_ids = sorting.unit_ids -sw.plot_unit_templates(we, unit_ids=unit_ids, ncols=5, figsize=(16,8)) +sw.plot_unit_templates(analyzer, unit_ids=unit_ids, ncols=5, figsize=(16,8)) ############################################################################## # plot_amplitudes() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_amplitudes(we, plot_histograms=True, figsize=(12,8)) +sw.plot_amplitudes(analyzer, plot_histograms=True, figsize=(12,8)) ############################################################################## # plot_unit_locations() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_unit_locations(we, figsize=(4,8)) +sw.plot_unit_locations(analyzer, figsize=(4,8)) ############################################################################## @@ -79,7 +74,7 @@ # This is your best friend to check over merge unit_ids = sorting.unit_ids[:4] -sw.plot_unit_waveforms_density_map(we, unit_ids=unit_ids, figsize=(14,8)) +sw.plot_unit_waveforms_density_map(analyzer, unit_ids=unit_ids, figsize=(14,8)) @@ -87,13 +82,13 @@ # plot_amplitudes_distribution() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_all_amplitudes_distributions(we, figsize=(10,10)) +sw.plot_all_amplitudes_distributions(analyzer, figsize=(10,10)) ############################################################################## # plot_units_depths() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_unit_depths(we, figsize=(10,10)) +sw.plot_unit_depths(analyzer, figsize=(10,10)) ############################################################################## @@ -101,7 +96,7 @@ # ~~~~~~~~~~~~~~~~~~~~~ unit_ids = sorting.unit_ids[:4] -sw.plot_unit_probe_map(we, unit_ids=unit_ids, figsize=(20,8)) +sw.plot_unit_probe_map(analyzer, unit_ids=unit_ids, figsize=(20,8)) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index f5a712f247..b18676af92 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -143,8 +143,8 @@ from .template import Templates -# SortingAnalyzer and ResultExtension -from .sortinganalyzer import SortingAnalyzer, ResultExtension, create_sorting_analyzer, load_sorting_analyzer +# SortingAnalyzer and AnalyzerExtension +from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, create_sorting_analyzer, load_sorting_analyzer from .analyzer_extension_core import ( ComputeWaveforms, compute_waveforms, diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 0d482f5369..79fe6ab600 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1,5 +1,5 @@ """ -Implement ResultExtension that are essential and imported in core +Implement AnalyzerExtension that are essential and imported in core * ComputeWaveforms * ComputeTemplates Theses two classes replace the WaveformExtractor @@ -11,16 +11,16 @@ import numpy as np -from .sortinganalyzer import ResultExtension, register_result_extension +from .sortinganalyzer import AnalyzerExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_average from .recording_tools import get_noise_levels from .template import Templates from .sorting_tools import random_spikes_selection -class SelectRandomSpikes(ResultExtension): +class SelectRandomSpikes(AnalyzerExtension): """ - ResultExtension that select some random spikes. + AnalyzerExtension that select some random spikes. This will be used by "compute_waveforms" and so "compute_templates" or "compute_fast_templates" @@ -113,9 +113,9 @@ def get_selected_indices_in_spike_train(self, unit_id, segment_index): -class ComputeWaveforms(ResultExtension): +class ComputeWaveforms(AnalyzerExtension): """ - ResultExtension that extract some waveforms of each units. + AnalyzerExtension that extract some waveforms of each units. The sparsity is controlled by the SortingAnalyzer sparsity. """ @@ -260,9 +260,9 @@ def _get_data(self): register_result_extension(ComputeWaveforms) -class ComputeTemplates(ResultExtension): +class ComputeTemplates(AnalyzerExtension): """ - ResultExtension that compute templates (average, str, median, percentile, ...) + AnalyzerExtension that compute templates (average, str, median, percentile, ...) This must be run after "waveforms" extension (`SortingAnalyzer.compute("waveforms")`) @@ -445,9 +445,9 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save register_result_extension(ComputeTemplates) -class ComputeFastTemplates(ResultExtension): +class ComputeFastTemplates(AnalyzerExtension): """ - ResultExtension which is similar to the extension "templates" (ComputeTemplates) **but only for average**. + AnalyzerExtension which is similar to the extension "templates" (ComputeTemplates) **but only for average**. This is way faster because it do not need "waveforms" to be computed first. """ @@ -530,7 +530,7 @@ def _select_extension_data(self, unit_ids): register_result_extension(ComputeFastTemplates) -class ComputeNoiseLevels(ResultExtension): +class ComputeNoiseLevels(AnalyzerExtension): """ Computes the noise level associated to each recording channel. @@ -561,7 +561,7 @@ class ComputeNoiseLevels(ResultExtension): need_job_kwargs = False def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, return_scaled=True, seed=None): params = dict( diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 9af87c8b48..2ba9737ee5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -36,7 +36,7 @@ def create_sorting_analyzer( """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. - This object will handle a list of ResultExtension for all the post processing steps like: waveforms, + This object will handle a list of AnalyzerExtension for all the post processing steps like: waveforms, templates, unit locations, spike locations, quality metrics ... This object will be also use used for plotting purpose. @@ -785,7 +785,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): Returns ------- - result_extension: ResultExtension + result_extension: AnalyzerExtension Return the extension instance. Examples @@ -938,7 +938,7 @@ def get_saved_extension_names(self): def get_extension(self, extension_name: str): """ - Get a ResultExtension. + Get a AnalyzerExtension. If not loaded then load is automatic. Return None if the extension is not computed yet (this avoids the use of has_extension() and then get it) @@ -1040,7 +1040,7 @@ def register_result_extension(extension_class): import spikeinterface.postprocessing more extensions will be available """ - assert issubclass(extension_class, ResultExtension) + assert issubclass(extension_class, AnalyzerExtension) assert extension_class.extension_name is not None, "extension_name must not be None" global _possible_extensions @@ -1076,7 +1076,7 @@ def get_extension_class(extension_name: str): return ext_class -class ResultExtension: +class AnalyzerExtension: """ This the base class to extend the SortingAnalyzer. It can handle persistency to disk for any computations related to: @@ -1106,7 +1106,7 @@ class ResultExtension: The subclass must also hanle an attribute `data` which is a dict contain the results after the `run()`. - All ResultExtension will have a function associate for instance (this use the function_factory): + All AnalyzerExtension will have a function associate for instance (this use the function_factory): compute_unit_location(sorting_analyzer, ...) will be equivalent to sorting_analyzer.compute("unit_location", ...) @@ -1127,7 +1127,7 @@ def __init__(self, sorting_analyzer): ####### # This 3 methods must be implemented in the subclass!!! - # See DummyResultExtension in test_sortinganalyzer.py as a simple example + # See DummyAnalyzerExtension in test_sortinganalyzer.py as a simple example def _run(self, **kwargs): # must be implemented in subclass # must populate the self.data dictionary @@ -1193,8 +1193,8 @@ def __call__(self, sorting_analyzer, load_if_exists=None, *args, **kwargs): @property def sorting_analyzer(self): - # Important : to avoid the SortingAnalyzer referencing a ResultExtension - # and ResultExtension referencing a SortingAnalyzer we need a weakref. + # Important : to avoid the SortingAnalyzer referencing a AnalyzerExtension + # and AnalyzerExtension referencing a SortingAnalyzer we need a weakref. # Otherwise the garbage collector is not working properly. # and so the SortingAnalyzer + its recording are still alive even after deleting explicitly # the SortingAnalyzer which makes it impossible to delete the folder when using memmap. @@ -1455,7 +1455,7 @@ def _save_params(self): def get_pipeline_nodes(self): assert ( self.use_nodepipeline - ), "ResultExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True" + ), "AnalyzerExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True" return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 03c18c2f43..e0b3cfc31b 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -5,7 +5,7 @@ from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import SortingAnalyzer, create_sorting_analyzer, load_sorting_analyzer -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension import numpy as np @@ -70,7 +70,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting): print() print(sorting_analyzer) - register_result_extension(DummyResultExtension) + register_result_extension(DummyAnalyzerExtension) assert "channel_ids" in sorting_analyzer.rec_attributes assert "sampling_frequency" in sorting_analyzer.rec_attributes @@ -143,7 +143,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting): assert np.all(~np.isin(data["result_two"], [1, 3])) -class DummyResultExtension(ResultExtension): +class DummyAnalyzerExtension(AnalyzerExtension): extension_name = "dummy" depend_on = [] need_recording = False @@ -179,21 +179,21 @@ def _get_data(self): return self.data["result_one"] -compute_dummy = DummyResultExtension.function_factory() +compute_dummy = DummyAnalyzerExtension.function_factory() -class DummyResultExtension2(ResultExtension): +class DummyAnalyzerExtension2(AnalyzerExtension): extension_name = "dummy" def test_extension(): - register_result_extension(DummyResultExtension) + register_result_extension(DummyAnalyzerExtension) # can be register twice without error - register_result_extension(DummyResultExtension) + register_result_extension(DummyAnalyzerExtension) # other extension with same name should trigger an error with pytest.raises(AssertionError): - register_result_extension(DummyResultExtension2) + register_result_extension(DummyAnalyzerExtension2) if __name__ == "__main__": diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index a02c437483..9d4e766f4b 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -7,7 +7,7 @@ from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type @@ -19,7 +19,7 @@ # TODO extra sparsity and job_kwargs handling -class ComputeAmplitudeScalings(ResultExtension): +class ComputeAmplitudeScalings(AnalyzerExtension): """ Computes the amplitude scalings from a SortingAnalyzer. @@ -69,7 +69,7 @@ class ComputeAmplitudeScalings(ResultExtension): need_job_kwargs = True def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) self.collisions = None diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index bf6f8585d9..f826cf9e8d 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension, SortingAnalyzer +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension, SortingAnalyzer try: import numba @@ -12,7 +12,7 @@ HAVE_NUMBA = False -class ComputeCorrelograms(ResultExtension): +class ComputeCorrelograms(AnalyzerExtension): """ Compute auto and cross correlograms. @@ -53,7 +53,7 @@ class ComputeCorrelograms(ResultExtension): need_job_kwargs = False def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index a99a677d65..22aee972b9 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension try: import numba @@ -12,7 +12,7 @@ HAVE_NUMBA = False -class ComputeISIHistograms(ResultExtension): +class ComputeISIHistograms(AnalyzerExtension): """Compute ISI histograms. Parameters @@ -41,7 +41,7 @@ class ComputeISIHistograms(ResultExtension): need_job_kwargs = False def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 4b9ab023cb..fb3a367f9b 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -9,14 +9,14 @@ import numpy as np -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"] -class ComputePrincipalComponents(ResultExtension): +class ComputePrincipalComponents(AnalyzerExtension): """ Compute PC scores from waveform extractor. The PCA projections are pre-computed only on the sampled waveforms available from the extensions "waveforms". @@ -67,7 +67,7 @@ class ComputePrincipalComponents(ResultExtension): need_job_kwargs = True def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params( self, diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 68894e3646..30aeca4b4b 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -7,14 +7,14 @@ from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type from spikeinterface.core.sorting_tools import spike_vector_to_indices -class ComputeSpikeAmplitudes(ResultExtension): +class ComputeSpikeAmplitudes(AnalyzerExtension): """ - ResultExtension + AnalyzerExtension Computes the spike amplitudes. Need "templates" or "fast_templates" to be computed first. @@ -64,7 +64,7 @@ class ComputeSpikeAmplitudes(ResultExtension): need_job_kwargs = True def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) self._all_spikes = None diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 76602e2763..f5b6ca4fdc 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sorting_tools import spike_vector_to_indices @@ -11,7 +11,7 @@ from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline -class ComputeSpikeLocations(ResultExtension): +class ComputeSpikeLocations(AnalyzerExtension): """ Localize spikes in 2D or 3D with several methods given the template. @@ -59,7 +59,7 @@ class ComputeSpikeLocations(ResultExtension): need_job_kwargs = True def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") self.spikes = self.sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index e4fd456107..9d57e5364d 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -11,7 +11,7 @@ from typing import Optional from copy import deepcopy -from ..core.sortinganalyzer import register_result_extension, ResultExtension +from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.template_tools import _get_dense_templates_array @@ -31,7 +31,7 @@ def get_template_metric_names(): return get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() -class ComputeTemplateMetrics(ResultExtension): +class ComputeTemplateMetrics(AnalyzerExtension): """ Compute template metrics including: * peak_to_valley diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 99f804b124..18d7c868da 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -2,11 +2,11 @@ import numpy as np -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from ..core.template_tools import _get_dense_templates_array -class ComputeTemplateSimilarity(ResultExtension): +class ComputeTemplateSimilarity(AnalyzerExtension): """Compute similarity between templates with several methods. @@ -32,7 +32,7 @@ class ComputeTemplateSimilarity(ResultExtension): need_job_kwargs = False def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params(self, method="cosine_similarity"): params = dict(method=method) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index f7ab30bfec..29c1d0d499 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -61,7 +61,7 @@ def get_sorting_analyzer(recording, sorting, format="memory", sparsity=None, nam return sorting_analyzer -class ResultExtensionCommonTestSuite: +class AnalyzerExtensionCommonTestSuite: """ Common tests with class approach to compute extension on several cases (3 format x 2 sparsity) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index 40034b7363..b59aca16a8 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -2,12 +2,12 @@ import numpy as np -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeAmplitudeScalings -class AmplitudeScalingsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class AmplitudeScalingsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeAmplitudeScalings extension_function_params_list = [ dict(handle_collisions=True), diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index b9fbde18f8..6d727e6448 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -11,12 +11,12 @@ from spikeinterface import NumpySorting, generate_sorting -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeCorrelograms from spikeinterface.postprocessing.correlograms import compute_correlograms_on_sorting, _make_bins -class ComputeCorrelogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class ComputeCorrelogramsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeCorrelograms extension_function_params_list = [ dict(method="numpy"), diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 89ed1257bc..8626e56453 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -3,7 +3,7 @@ from typing import List -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import compute_isi_histograms, ComputeISIHistograms from spikeinterface.postprocessing.isi import _compute_isi_histograms @@ -16,7 +16,7 @@ HAVE_NUMBA = False -class ComputeISIHistogramsTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class ComputeISIHistogramsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeISIHistograms extension_function_params_list = [ dict(method="numpy"), diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index c4f378c295..d94d7ea586 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -5,13 +5,13 @@ import numpy as np from spikeinterface.postprocessing import ComputePrincipalComponents, compute_principal_components -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite, cache_folder +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite, cache_folder DEBUG = False -class PrincipalComponentsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class PrincipalComponentsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputePrincipalComponents extension_function_params_list = [ dict(mode="by_channel_local"), diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index bee4816e80..e02c981774 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -2,10 +2,10 @@ import numpy as np from spikeinterface.postprocessing import ComputeSpikeAmplitudes -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -class ComputeSpikeAmplitudesTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class ComputeSpikeAmplitudesTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeSpikeAmplitudes extension_function_params_list = [ dict(return_scaled=True), diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index c1f49bc849..d48ff3d84b 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -2,10 +2,10 @@ import numpy as np from spikeinterface.postprocessing import ComputeSpikeLocations -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -class SpikeLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class SpikeLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeSpikeLocations extension_function_params_list = [ dict( diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5954db646a..360f0f379f 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,11 +1,11 @@ import unittest -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics -class TemplateMetricsTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class TemplateMetricsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeTemplateMetrics extension_function_params_list = [ dict(), diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 746c45da08..534c909592 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,7 +1,7 @@ import unittest from spikeinterface.postprocessing.tests.common_extension_tests import ( - ResultExtensionCommonTestSuite, + AnalyzerExtensionCommonTestSuite, get_sorting_analyzer, get_dataset, ) @@ -9,7 +9,7 @@ from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -class SimilarityExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class SimilarityExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeTemplateSimilarity extension_function_params_list = [ dict(method="cosine_similarity"), diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index a46c743cb5..b23adf5868 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -1,9 +1,9 @@ import unittest -from spikeinterface.postprocessing.tests.common_extension_tests import ResultExtensionCommonTestSuite +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeUnitLocations -class UnitLocationsExtensionTest(ResultExtensionCommonTestSuite, unittest.TestCase): +class UnitLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): extension_class = ComputeUnitLocations extension_function_params_list = [ dict(method="center_of_mass", radius_um=100), diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index d0f9830091..eabec5e610 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -12,7 +12,7 @@ except ImportError: HAVE_NUMBA = False -from ..core.sortinganalyzer import register_result_extension, ResultExtension +from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension from ..core import compute_sparsity from ..core.template_tools import get_template_extremum_channel, _get_nbefore, _get_dense_templates_array @@ -27,7 +27,7 @@ possible_localization_methods = list(dtype_localize_by_method.keys()) -class ComputeUnitLocations(ResultExtension): +class ComputeUnitLocations(AnalyzerExtension): """ Localize units in 2D or 3D with several methods given the template. @@ -57,7 +57,7 @@ class ComputeUnitLocations(ResultExtension): need_job_kwargs = False def __init__(self, sorting_analyzer): - ResultExtension.__init__(self, sorting_analyzer) + AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params(self, method="monopolar_triangulation", **method_kwargs): params = dict(method=method, method_kwargs=method_kwargs) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index fb32280a3b..5deb08a7d2 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -9,7 +9,7 @@ import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.sortinganalyzer import register_result_extension, ResultExtension +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from .quality_metric_list import calculate_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names @@ -17,7 +17,7 @@ from .pca_metrics import _default_params as pca_metrics_params -class ComputeQualityMetrics(ResultExtension): +class ComputeQualityMetrics(AnalyzerExtension): """ Compute quality metrics on sorting_. From 9f3668efb080ceb3d05dd0a0e60ff266de195fe5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 26 Feb 2024 22:32:35 +0100 Subject: [PATCH 146/192] Fix test peak localization --- .../benchmark/benchmark_peak_localization.py | 21 +++++++++---------- .../tests/test_benchmark_peak_localization.py | 4 ++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 61e1ce2098..18b988953e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -88,7 +88,7 @@ def create_benchmark(self, key): benchmark = PeakLocalizationBenchmark(recording, gt_sorting, params, **init_kwargs) return benchmark - def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): + def plot_comparison_positions(self, case_keys=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -110,16 +110,15 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): zdx = np.argsort(distances_to_center) idx = np.argsort(norms) - from scipy.signal import savgol_filter wdx = np.argsort(snrs) data = result["medians_over_templates"] axs[0].plot( - snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]['label'] + snrs[wdx], data[wdx], lw=2, label=self.cases[key]['label'] ) - ymin = savgol_filter((data - result["mads_over_templates"])[wdx], smoothing_factor, 3) - ymax = savgol_filter((data + result["mads_over_templates"])[wdx], smoothing_factor, 3) + ymin = (data - result["mads_over_templates"])[wdx] + ymax = (data + result["mads_over_templates"])[wdx] axs[0].fill_between(snrs[wdx], ymin, ymax, alpha=0.5) axs[0].set_xlabel("snr") @@ -127,12 +126,12 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): axs[1].plot( distances_to_center[zdx], - savgol_filter(data[zdx], smoothing_factor, 3), + data[zdx], lw=2, label=self.cases[key]['label'], ) - ymin = savgol_filter((data - result["mads_over_templates"])[zdx], smoothing_factor, 3) - ymax = savgol_filter((data + result["mads_over_templates"])[zdx], smoothing_factor, 3) + ymin = (data - result["mads_over_templates"])[zdx] + ymax = (data + result["mads_over_templates"])[zdx] axs[1].fill_between(distances_to_center[zdx], ymin, ymax, alpha=0.5) axs[1].set_xlabel("distance to center (um)") @@ -245,7 +244,7 @@ def plot_template_errors(self, case_keys=None): axs.legend() - def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): + def plot_comparison_positions(self, case_keys=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -272,7 +271,7 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): data = result["errors"] axs[0].plot( - snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]['label'] + snrs[wdx], data[wdx], lw=2, label=self.cases[key]['label'] ) axs[0].set_xlabel("snr") @@ -280,7 +279,7 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): axs[1].plot( distances_to_center[zdx], - savgol_filter(data[zdx], smoothing_factor, 3), + data[zdx], lw=2, label=self.cases[key]['label'], ) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py index baa756d521..a1555b8dba 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -53,7 +53,7 @@ def test_benchmark_peak_localization(): # load study to check persistency study = PeakLocalizationStudy(study_folder) - study.plot_comparison_positions(smoothing_factor=31) + study.plot_comparison_positions() study.plot_run_times() plt.show() @@ -95,7 +95,7 @@ def test_benchmark_unit_localization(): # load study to check persistency study = UnitLocalizationStudy(study_folder) - study.plot_comparison_positions(smoothing_factor=31) + study.plot_comparison_positions() study.plot_run_times() plt.show() From 2b5c2e539088136d5048df7b062e159512838654 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 21:34:44 +0000 Subject: [PATCH 147/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/plot_4_sorting_analyzer.py | 4 +- .../qualitymetrics/plot_3_quality_mertics.py | 2 +- .../core/analyzer_extension_core.py | 25 ++-- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 14 +- .../tests/test_analyzer_extension_core.py | 7 +- ...forms_extractor_backwards_compatibility.py | 4 +- .../postprocessing/spike_amplitudes.py | 2 +- .../benchmark/benchmark_clustering.py | 124 +++++++++--------- .../benchmark/benchmark_matching.py | 76 +++++------ .../benchmark/benchmark_motion_estimation.py | 80 +++++------ .../benchmark_motion_interpolation.py | 61 +++++---- .../benchmark/benchmark_peak_localization.py | 106 +++++++-------- .../benchmark/benchmark_peak_selection.py | 64 ++++----- .../benchmark/benchmark_tools.py | 45 ++++--- .../tests/common_benchmark_testing.py | 105 +++++++-------- .../tests/test_benchmark_clustering.py | 20 +-- .../tests/test_benchmark_matching.py | 28 ++-- .../tests/test_benchmark_motion_estimation.py | 35 +++-- .../test_benchmark_motion_interpolation.py | 52 ++++---- .../tests/test_benchmark_peak_localization.py | 43 +++--- .../tests/test_benchmark_peak_selection.py | 4 - src/spikeinterface/widgets/sorting_summary.py | 9 +- .../widgets/tests/test_widgets.py | 7 +- 24 files changed, 464 insertions(+), 455 deletions(-) diff --git a/examples/modules_gallery/core/plot_4_sorting_analyzer.py b/examples/modules_gallery/core/plot_4_sorting_analyzer.py index 593b1103e1..98b2a49b30 100644 --- a/examples/modules_gallery/core/plot_4_sorting_analyzer.py +++ b/examples/modules_gallery/core/plot_4_sorting_analyzer.py @@ -18,7 +18,7 @@ * "noise_levels" : compute noise level from traces (usefull to get snr of units) * can be in memory or persistent to disk (2 formats binary/npy or zarr) -More extesions are available in `spikeinterface.postprocessing` like "principal_components", "spike_amplitudes", +More extesions are available in `spikeinterface.postprocessing` like "principal_components", "spike_amplitudes", "unit_lcations", ... @@ -92,7 +92,7 @@ ############################################################################### -# To speed up computation, some steps like ""waveforms" can also be extracted +# To speed up computation, some steps like ""waveforms" can also be extracted # using parallel processing (recommended!). Like this analyzer.compute("waveforms", ms_before=1.0, ms_after=2.0, return_scaled=True, diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 557fff229b..7e5eaaccbf 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -32,7 +32,7 @@ print(analyzer) ############################################################################## -# Depending on which metrics we want to compute we will need first to compute +# Depending on which metrics we want to compute we will need first to compute # some necessary extensions. (if not computed an error message will be raised) analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 79fe6ab600..268513dac8 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -41,24 +41,24 @@ class SelectRandomSpikes(AnalyzerExtension): ------- """ + extension_name = "random_spikes" depend_on = [] need_recording = False use_nodepipeline = False need_job_kwargs = False - def _run(self, + def _run( + self, ): - self.data["random_spikes_indices"] = random_spikes_selection( - self.sorting_analyzer.sorting, num_samples=self.sorting_analyzer.rec_attributes["num_samples"], - **self.params) + self.data["random_spikes_indices"] = random_spikes_selection( + self.sorting_analyzer.sorting, + num_samples=self.sorting_analyzer.rec_attributes["num_samples"], + **self.params, + ) def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None): - params = dict( - method=method, - max_spikes_per_unit=max_spikes_per_unit, - margin_size=margin_size, - seed=seed) + params = dict(method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed) return params def _select_extension_data(self, unit_ids): @@ -76,7 +76,6 @@ def _select_extension_data(self, unit_ids): new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_spike_mask]) return new_data - def _get_data(self): return self.data["random_spikes_indices"] @@ -88,7 +87,6 @@ def some_spikes(self): self._some_spikes = spikes[self.data["random_spikes_indices"]] return self._some_spikes - def get_selected_indices_in_spike_train(self, unit_id, segment_index): # usefull for Waveforms extractor backwars compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain @@ -107,12 +105,9 @@ def get_selected_indices_in_spike_train(self, unit_id, segment_index): return selected_spikes_in_spike_train - register_result_extension(SelectRandomSpikes) - - class ComputeWaveforms(AnalyzerExtension): """ AnalyzerExtension that extract some waveforms of each units. @@ -197,7 +192,7 @@ def _set_params( if return_scaled: # check if has scaled values: - if not recording.has_scaled() and recording.get_dtype().kind == 'i': + if not recording.has_scaled() and recording.get_dtype().kind == "i": print("Setting 'return_scaled' to False") return_scaled = False diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b834cbac96..74937d0861 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -310,7 +310,7 @@ def get_traces( warnings.warn(message) if not self.has_scaled(): - if self._dtype.kind == 'f': + if self._dtype.kind == "f": # here we do not truely have scale but we assume this is scaled # this helps a lot for simulated data pass diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 2ba9737ee5..f1858810fb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -161,7 +161,12 @@ class SortingAnalyzer: """ def __init__( - self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None, + self, + sorting=None, + recording=None, + rec_attributes=None, + format=None, + sparsity=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -409,7 +414,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # the recording rec_dict = recording.to_dict(relative_to=folder, recursive=True) - + if recording.check_serializability("json"): # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) zarr_rec = np.array([check_json(rec_dict)], dtype=object) @@ -760,12 +765,11 @@ def compute(self, input, save=True, **kwargs): elif isinstance(input, list): params_, job_kwargs = split_job_kwargs(kwargs) assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()" - extensions = {k : {} for k in input} - self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs) + extensions = {k: {} for k in input} + self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs) else: raise ValueError("SortingAnalyzer.compute() need str, dict or list") - def compute_one_extension(self, extension_name, save=True, **kwargs): """ Compute one extension diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index f94226110f..cb70b21d69 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -74,7 +74,12 @@ def _check_result_extension(sorting_analyzer, extension_name): @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) -@pytest.mark.parametrize("sparse", [False, ]) +@pytest.mark.parametrize( + "sparse", + [ + False, + ], +) def test_SelectRandomSpikes(format, sparse): sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 56dc17817b..f10454e085 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -231,7 +231,9 @@ def get_sampled_indices(self, unit_id): selected_spikes = [] for segment_index in range(self.get_num_segments()): # inds = self.sorting_analyzer.get_selected_indices_in_spike_train(unit_id, segment_index) - inds = self.sorting_analyzer.get_extension("random_spikes").get_selected_indices_in_spike_train(unit_id, segment_index) + inds = self.sorting_analyzer.get_extension("random_spikes").get_selected_indices_in_spike_train( + unit_id, segment_index + ) sampled_index = np.zeros(inds.size, dtype=[("spike_index", "int64"), ("segment_index", "int64")]) sampled_index["spike_index"] = inds sampled_index["segment_index"][:] = segment_index diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 30aeca4b4b..7362dfc4dd 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -98,7 +98,7 @@ def _get_pipeline_nodes(self): if return_scaled: # check if has scaled values: - if not recording.has_scaled_traces() and recording.get_dtype().kind == 'i': + if not recording.has_scaled_traces() and recording.get_dtype().kind == "i": warnings.warn("Recording doesn't have scaled traces! Setting 'return_scaled' to False") return_scaled = False diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index c01a10eb9d..373fe6b37b 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -11,7 +11,8 @@ plot_unit_waveforms, ) from spikeinterface.comparison.comparisontools import make_matching_events -#from spikeinterface.postprocessing import get_template_extremum_channel + +# from spikeinterface.postprocessing import get_template_extremum_channel from spikeinterface.core import get_noise_levels import pylab as plt @@ -32,9 +33,9 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.gt_sorting = gt_sorting self.indices = indices - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') + ext = sorting_analyzer.compute("fast_templates") extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) @@ -43,55 +44,57 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.peaks = peaks[self.indices] self.params = params self.exhaustive_gt = exhaustive_gt - self.method = params['method'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.method_kwargs = params["method_kwargs"] self.result = {} - - def run(self, **job_kwargs): + + def run(self, **job_kwargs): labels, peak_labels = find_cluster_from_peaks( - self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs + self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs ) - self.result['peak_labels'] = peak_labels + self.result["peak_labels"] = peak_labels def compute_result(self, **result_params): - self.noise = self.result['peak_labels'] < 0 + self.noise = self.result["peak_labels"] < 0 spikes = self.gt_sorting.to_spike_vector() - self.result['sliced_gt_sorting'] = NumpySorting(spikes[self.indices], - self.recording.sampling_frequency, - self.gt_sorting.unit_ids) + self.result["sliced_gt_sorting"] = NumpySorting( + spikes[self.indices], self.recording.sampling_frequency, self.gt_sorting.unit_ids + ) data = spikes[self.indices][~self.noise] - data["unit_index"] = self.result['peak_labels'][~self.noise] + data["unit_index"] = self.result["peak_labels"][~self.noise] + + self.result["clustering"] = NumpySorting.from_times_labels( + data["sample_index"], self.result["peak_labels"][~self.noise], self.recording.sampling_frequency + ) - self.result['clustering'] = NumpySorting.from_times_labels(data["sample_index"], - self.result['peak_labels'][~self.noise], - self.recording.sampling_frequency) - - self.result['gt_comparison'] = GroundTruthComparison(self.result['sliced_gt_sorting'], - self.result['clustering'], - exhaustive_gt=self.exhaustive_gt) + self.result["gt_comparison"] = GroundTruthComparison( + self.result["sliced_gt_sorting"], self.result["clustering"], exhaustive_gt=self.exhaustive_gt + ) - sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates") - sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["clustering"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['clustering_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["clustering_templates"] = ext.get_data(outputs="Templates") - _run_key_saved = [ - ("peak_labels", "npy") - ] + _run_key_saved = [("peak_labels", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), ("sliced_gt_sorting", "sorting"), ("clustering", "sorting"), ("sliced_gt_templates", "zarr_templates"), - ("clustering_templates", "zarr_templates") + ("clustering_templates", "zarr_templates"), ] @@ -108,23 +111,24 @@ def create_benchmark(self, key): return benchmark def homogeneity_score(self, ignore_noise=True, case_keys=None): - + if case_keys is None: case_keys = list(self.cases.keys()) - + for count, key in enumerate(case_keys): result = self.get_result(key) noise = result["peak_labels"] < 0 from sklearn.metrics import homogeneity_score + gt_labels = self.benchmarks[key].gt_sorting.to_spike_vector()["unit_index"] gt_labels = gt_labels[self.benchmarks[key].indices] - found_labels = result['peak_labels'] + found_labels = result["peak_labels"] if ignore_noise: gt_labels = gt_labels[~noise] found_labels = found_labels[~noise] - print(self.cases[key]['label'], homogeneity_score(gt_labels, found_labels), np.mean(noise)) + print(self.cases[key]["label"], homogeneity_score(gt_labels, found_labels), np.mean(noise)) - def plot_agreements(self, case_keys=None, figsize=(15,15)): + def plot_agreements(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -132,32 +136,32 @@ def plot_agreements(self, case_keys=None, figsize=(15,15)): for count, key in enumerate(case_keys): ax = axs[count] - ax.set_title(self.cases[key]['label']) - plot_agreement_matrix(self.get_result(key)['gt_comparison'], ax=ax) + ax.set_title(self.cases[key]["label"]) + plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): + def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) for count, k in enumerate(("accuracy", "recall", "precision")): - + ax = axs[count] for key in case_keys: label = self.cases[key]["label"] - + analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() x = metrics["snr"].values - y = self.get_result(key)['gt_comparison'].get_performance()[k].values + y = self.get_result(key)["gt_comparison"].get_performance()[k].values ax.scatter(x, y, marker=".", label=label) ax.set_title(k) if count == 2: ax.legend() - - def plot_error_metrics(self, metric='cosine', case_keys=None, figsize=(15,5)): + + def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -167,14 +171,14 @@ def plot_error_metrics(self, metric='cosine', case_keys=None, figsize=(15,5)): for count, key in enumerate(case_keys): result = self.get_result(key) - scores = result['gt_comparison'].get_ordered_agreement_scores() + scores = result["gt_comparison"].get_ordered_agreement_scores() unit_ids1 = scores.index.values unit_ids2 = scores.columns.values - inds_1 = result['gt_comparison'].sorting1.ids_to_indices(unit_ids1) - inds_2 = result['gt_comparison'].sorting2.ids_to_indices(unit_ids2) + inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1) + inds_2 = result["gt_comparison"].sorting2.ids_to_indices(unit_ids2) t1 = result["sliced_gt_templates"].templates_array - t2 = result['clustering_templates'].templates_array + t2 = result["clustering_templates"].templates_array a = t1.reshape(len(t1), -1)[inds_1] b = t2.reshape(len(t2), -1)[inds_2] @@ -191,8 +195,7 @@ def plot_error_metrics(self, metric='cosine', case_keys=None, figsize=(15,5)): label = self.cases[key]["label"] axs[count].set_title(label) - - def plot_metrics_vs_snr(self, metric='cosine', case_keys=None, figsize=(15,5)): + def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -202,17 +205,17 @@ def plot_metrics_vs_snr(self, metric='cosine', case_keys=None, figsize=(15,5)): for count, key in enumerate(case_keys): result = self.get_result(key) - scores = result['gt_comparison'].get_ordered_agreement_scores() + scores = result["gt_comparison"].get_ordered_agreement_scores() analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() - + metrics = analyzer.get_extension("quality_metrics").get_data() + unit_ids1 = scores.index.values unit_ids2 = scores.columns.values - inds_1 = result['gt_comparison'].sorting1.ids_to_indices(unit_ids1) - inds_2 = result['gt_comparison'].sorting2.ids_to_indices(unit_ids2) + inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1) + inds_2 = result["gt_comparison"].sorting2.ids_to_indices(unit_ids2) t1 = result["sliced_gt_templates"].templates_array - t2 = result['clustering_templates'].templates_array + t2 = result["clustering_templates"].templates_array a = t1.reshape(len(t1), -1) b = t2.reshape(len(t2), -1) @@ -222,19 +225,18 @@ def plot_metrics_vs_snr(self, metric='cosine', case_keys=None, figsize=(15,5)): distances = sklearn.metrics.pairwise.cosine_similarity(a, b) else: distances = sklearn.metrics.pairwise_distances(a, b, metric) - + snr = metrics["snr"][unit_ids1][inds_1[: len(inds_2)]] to_plot = [] for found, real in zip(inds_2, inds_1): to_plot += [distances[real, found]] - axs[count].plot(snr, to_plot, '.') - axs[count].set_xlabel('snr') + axs[count].plot(snr, to_plot, ".") + axs[count].set_xlabel("snr") axs[count].set_ylabel(metric) label = self.cases[key]["label"] axs[count].set_title(label) - # def _scatter_clusters( # self, # xs, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 4a5221e16d..ffecbe028f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -24,17 +24,14 @@ class MatchingBenchmark(Benchmark): def __init__(self, recording, gt_sorting, params): self.recording = recording self.gt_sorting = gt_sorting - self.method = params['method'] - self.templates = params["method_kwargs"]['templates'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.templates = params["method_kwargs"]["templates"] + self.method_kwargs = params["method_kwargs"] self.result = {} def run(self, **job_kwargs): spikes = find_spikes_from_templates( - self.recording, - method=self.method, - method_kwargs=self.method_kwargs, - **job_kwargs + self.recording, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs ) unit_ids = self.templates.unit_ids sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) @@ -42,36 +39,33 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - self.result = {'sorting' : sorting} - self.result['templates'] = self.templates + self.result = {"sorting": sorting} + self.result["templates"] = self.templates def compute_result(self, **result_params): - sorting = self.result['sorting'] + sorting = self.result["sorting"] comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) - self.result['gt_comparison'] = comp - self.result['gt_collision'] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) - + self.result["gt_comparison"] = comp + self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) + _run_key_saved = [ ("sorting", "sorting"), ("templates", "zarr_templates"), ] - _result_key_saved = [ - ("gt_collision", "pickle"), - ("gt_comparison", "pickle") - ] + _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] class MatchingStudy(BenchmarkStudy): benchmark_class = MatchingBenchmark - def create_benchmark(self,key): + def create_benchmark(self, key): dataset_key = self.cases[key]["dataset"] recording, gt_sorting = self.datasets[dataset_key] params = self.cases[key]["params"] benchmark = MatchingBenchmark(recording, gt_sorting, params) return benchmark - + def plot_agreements(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -80,9 +74,9 @@ def plot_agreements(self, case_keys=None, figsize=None): for count, key in enumerate(case_keys): ax = axs[count] - ax.set_title(self.cases[key]['label']) - plot_agreement_matrix(self.get_result(key)['gt_comparison'], ax=ax) - + ax.set_title(self.cases[key]["label"]) + plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + def plot_performances_vs_snr(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -90,15 +84,15 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=None): fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) for count, k in enumerate(("accuracy", "recall", "precision")): - + ax = axs[count] for key in case_keys: label = self.cases[key]["label"] - + analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() x = metrics["snr"].values - y = self.get_result(key)['gt_comparison'].get_performance()[k].values + y = self.get_result(key)["gt_comparison"].get_performance()[k].values ax.scatter(x, y, marker=".", label=label) ax.set_title(k) @@ -108,23 +102,29 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=None): def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) - + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) for count, key in enumerate(case_keys): - templates_array = self.get_result(key)['templates'].templates_array + templates_array = self.get_result(key)["templates"].templates_array plot_comparison_collision_by_similarity( - self.get_result(key)['gt_collision'], templates_array, ax=axs[count], - show_legend=True, mode="lines", good_only=False + self.get_result(key)["gt_collision"], + templates_array, + ax=axs[count], + show_legend=True, + mode="lines", + good_only=False, ) - def plot_comparison_matching(self, case_keys=None, + def plot_comparison_matching( + self, + case_keys=None, performance_names=["accuracy", "recall", "precision"], colors=["g", "b", "r"], ylim=(-0.1, 1.1), - figsize=None + figsize=None, ): - + if case_keys is None: case_keys = list(self.cases.keys()) @@ -136,8 +136,8 @@ def plot_comparison_matching(self, case_keys=None, ax = axs[i, j] else: ax = axs[j] - comp1 = self.get_result(key1)['gt_comparison'] - comp2 = self.get_result(key2)['gt_comparison'] + comp1 = self.get_result(key1)["gt_comparison"] + comp2 = self.get_result(key2)["gt_comparison"] if i <= j: for performance, color in zip(performance_names, colors): perf1 = comp1.get_performance()[performance] @@ -150,8 +150,8 @@ def plot_comparison_matching(self, case_keys=None, ax.spines[["right", "top"]].set_visible(False) ax.set_aspect("equal") - label1 = self.cases[key1]['label'] - label2 = self.cases[key2]['label'] + label1 = self.cases[key1]["label"] + label2 = self.cases[key2]["label"] if j == i: ax.set_ylabel(f"{label1}") else: @@ -172,4 +172,4 @@ def plot_comparison_matching(self, case_keys=None, ax.spines["right"].set_visible(False) ax.set_xticks([]) ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) \ No newline at end of file + plt.tight_layout(h_pad=0, w_pad=0) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index be048ede50..12f0ff7a4a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -28,9 +28,7 @@ # TODO : read from mearec - - -def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direction_dim = 1): +def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direction_dim=1): """ Get final displacement vector unit per units. @@ -57,7 +55,7 @@ def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direct """ - num_units = displacement_unit_factor.shape[0] + num_units = displacement_unit_factor.shape[0] unit_displacements = np.zeros((displacement_vectors.shape[0], num_units)) for i in range(displacement_vectors.shape[2]): m = displacement_vectors[:, direction_dim, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :] @@ -66,10 +64,14 @@ def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direct return unit_displacements -def get_gt_motion_from_unit_discplacement(unit_displacements, displacement_sampling_frequency, - unit_locations, - temporal_bins, spatial_bins, - direction_dim=1,): +def get_gt_motion_from_unit_discplacement( + unit_displacements, + displacement_sampling_frequency, + unit_locations, + temporal_bins, + spatial_bins, + direction_dim=1, +): times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) @@ -83,17 +85,25 @@ def get_gt_motion_from_unit_discplacement(unit_displacements, displacement_sampl # non rigid gt_motion = np.zeros((temporal_bins.size, spatial_bins.size)) for t in range(temporal_bins.shape[0]): - f = scipy.interpolate.interp1d(unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate") + f = scipy.interpolate.interp1d( + unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate" + ) gt_motion[t, :] = f(spatial_bins) return gt_motion - class MotionEstimationBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, params, - unit_locations, unit_displacements, displacement_sampling_frequency, - direction="y"): + def __init__( + self, + recording, + gt_sorting, + params, + unit_locations, + unit_displacements, + displacement_sampling_frequency, + direction="y", + ): Benchmark.__init__(self) self.recording = recording self.gt_sorting = gt_sorting @@ -110,9 +120,7 @@ def run(self, **job_kwargs): noise_levels = get_noise_levels(self.recording, return_scaled=False) t0 = time.perf_counter() - peaks = detect_peaks( - self.recording, noise_levels=noise_levels, **p["detect_kwargs"], **job_kwargs - ) + peaks = detect_peaks(self.recording, noise_levels=noise_levels, **p["detect_kwargs"], **job_kwargs) t1 = time.perf_counter() if p["select_kwargs"] is not None: selected_peaks = select_peaks(self.peaks, **p["select_kwargs"], **job_kwargs) @@ -120,9 +128,7 @@ def run(self, **job_kwargs): selected_peaks = peaks t2 = time.perf_counter() - peak_locations = localize_peaks( - self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs - ) + peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs) t3 = time.perf_counter() motion, temporal_bins, spatial_bins = estimate_motion( self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] @@ -159,7 +165,9 @@ def compute_result(self, **result_params): # non rigid gt_motion = np.zeros_like(raw_motion) for t in range(temporal_bins.shape[0]): - f = scipy.interpolate.interp1d(self.unit_locations[:, self.direction_dim], unit_displacements[t, :], fill_value="extrapolate") + f = scipy.interpolate.interp1d( + self.unit_locations[:, self.direction_dim], unit_displacements[t, :], fill_value="extrapolate" + ) gt_motion[t, :] = f(spatial_bins) # align globally gt_motion and motion to avoid offsets @@ -168,7 +176,6 @@ def compute_result(self, **result_params): self.result["gt_motion"] = gt_motion self.result["motion"] = motion - _run_key_saved = [ ("raw_motion", "npy"), ("temporal_bins", "npy"), @@ -176,14 +183,17 @@ def compute_result(self, **result_params): ("step_run_times", "pickle"), ] _result_key_saved = [ - ("gt_motion", "npy",), - ("motion", "npy",) + ( + "gt_motion", + "npy", + ), + ( + "motion", + "npy", + ), ] - - - class MotionEstimationStudy(BenchmarkStudy): benchmark_class = MotionEstimationBenchmark @@ -197,7 +207,6 @@ def create_benchmark(self, key): return benchmark def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): - if case_keys is None: case_keys = list(self.cases.keys()) @@ -218,7 +227,7 @@ def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): ax.set_ylabel("depth (um)") ax.set_xlabel(None) - ax.set_aspect('auto') + ax.set_aspect("auto") # dirft ax = ax1 = fig.add_subplot(gs[2:7]) @@ -227,7 +236,6 @@ def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): spatial_bins = bench.result["spatial_bins"] gt_motion = bench.result["gt_motion"] - # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") @@ -246,8 +254,7 @@ def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) - - ax = ax2= fig.add_subplot(gs[7]) + ax = ax2 = fig.add_subplot(gs[7]) ax2.sharey(ax0) _simpleaxis(ax) ax.hist(unit_locations[:, bench.direction_dim], bins=50, orientation="horizontal", color="0.5") @@ -274,7 +281,6 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): temporal_bins = bench.result["temporal_bins"] spatial_bins = bench.result["spatial_bins"] - fig = plt.figure(figsize=figsize) gs = fig.add_gridspec(2, 2) @@ -319,13 +325,13 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): if lim is not None: ax.set_ylim(0, lim) - def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, figsize=(15, 5)): + def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) fig, axes = plt.subplots(1, 3, figsize=figsize) - + for count, key in enumerate(case_keys): bench = self.benchmarks[key] @@ -336,9 +342,6 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fi temporal_bins = bench.result["temporal_bins"] spatial_bins = bench.result["spatial_bins"] - - - c = colors[count] if colors is not None else None errors = gt_motion - motion mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) @@ -384,9 +387,6 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fi # ax2.sharey(ax0) - - - # class BenchmarkMotionEstimationMearec(BenchmarkBase): # _array_names = ( # "noise_levels", diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index d2b83f181a..af45f7421f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -20,12 +20,18 @@ import matplotlib.pyplot as plt - class MotionInterpolationBenchmark(Benchmark): - def __init__(self, static_recording, gt_sorting, params, - sorter_folder, drifting_recording, - motion, temporal_bins, spatial_bins, - ): + def __init__( + self, + static_recording, + gt_sorting, + params, + sorter_folder, + drifting_recording, + motion, + temporal_bins, + spatial_bins, + ): Benchmark.__init__(self) self.static_recording = static_recording self.gt_sorting = gt_sorting @@ -37,14 +43,13 @@ def __init__(self, static_recording, gt_sorting, params, self.temporal_bins = temporal_bins self.spatial_bins = spatial_bins - def run(self, **job_kwargs): - if self.params["recording_source"] == 'static': + if self.params["recording_source"] == "static": recording = self.static_recording - elif self.params["recording_source"] == 'drifting': + elif self.params["recording_source"] == "drifting": recording = self.drifting_recording - elif self.params["recording_source"] == 'corrected': + elif self.params["recording_source"] == "corrected": correct_motion_kwargs = self.params["correct_motion_kwargs"] recording = InterpolateMotionRecording( self.drifting_recording, self.motion, self.temporal_bins, self.spatial_bins, **correct_motion_kwargs @@ -66,12 +71,11 @@ def run(self, **job_kwargs): def compute_result(self, exhaustive_gt=True, merging_score=0.2): sorting = self.result["sorting"] - # self.result[""] = + # self.result[""] = comparison = GroundTruthComparison(self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt) self.result["comparison"] = comparison self.result["accuracy"] = comparison.get_performance()["accuracy"].values.astype("float32") - gt_unit_ids = self.gt_sorting.unit_ids unit_ids = sorting.unit_ids @@ -90,7 +94,6 @@ def compute_result(self, exhaustive_gt=True, merging_score=0.2): self.result["comparison_merged"] = comparison_merged self.result["accuracy_merged"] = comparison_merged.get_performance()["accuracy"].values.astype("float32") - _run_key_saved = [ ("sorting", "sorting"), ] @@ -111,24 +114,32 @@ def create_benchmark(self, key): recording, gt_sorting = self.datasets[dataset_key] params = self.cases[key]["params"] init_kwargs = self.cases[key]["init_kwargs"] - sorter_folder = self.folder / "sorters" /self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) sorter_folder.parent.mkdir(exist_ok=True) - benchmark = MotionInterpolationBenchmark(recording, gt_sorting, params, - sorter_folder=sorter_folder, **init_kwargs) + benchmark = MotionInterpolationBenchmark( + recording, gt_sorting, params, sorter_folder=sorter_folder, **init_kwargs + ) return benchmark - - def plot_sorting_accuracy(self, case_keys=None, mode="ordered_accuracy", legend=True, colors=None, - mode_best_merge=False, figsize=(10, 5), ax=None, axes=None): - + def plot_sorting_accuracy( + self, + case_keys=None, + mode="ordered_accuracy", + legend=True, + colors=None, + mode_best_merge=False, + figsize=(10, 5), + ax=None, + axes=None, + ): if case_keys is None: case_keys = list(self.cases.keys()) if not mode_best_merge: - ls = '-' + ls = "-" else: - ls = '--' + ls = "--" if mode == "ordered_accuracy": if ax is None: @@ -176,7 +187,7 @@ def plot_sorting_accuracy(self, case_keys=None, mode="ordered_accuracy", legend= unit_locations = ext.get_data() unit_depth = unit_locations[:, 1] - snr= analyzer.get_extension("quality_metrics").get_data()["snr"].values + snr = analyzer.get_extension("quality_metrics").get_data()["snr"].values points = ax.scatter(unit_depth, snr, c=accuracy) points.set_clim(0.0, 1.0) @@ -203,9 +214,9 @@ def plot_sorting_accuracy(self, case_keys=None, mode="ordered_accuracy", legend= else: accuracy = result["accuracy_merged"] - analyzer = self.get_sorting_analyzer(key) - snr= analyzer.get_extension("quality_metrics").get_data()["snr"].values - + analyzer = self.get_sorting_analyzer(key) + snr = analyzer.get_extension("quality_metrics").get_data()["snr"].values + ax.scatter(snr, accuracy, label=label) ax.set_xlabel("snr") ax.set_ylabel("accuracy") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 18b988953e..415429881d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -10,7 +10,11 @@ plot_unit_templates, plot_unit_waveforms, ) -from spikeinterface.postprocessing.unit_localization import compute_center_of_mass, compute_monopolar_triangulation, compute_grid_convolution +from spikeinterface.postprocessing.unit_localization import ( + compute_center_of_mass, + compute_monopolar_triangulation, + compute_grid_convolution, +) from spikeinterface.core import get_noise_levels import pylab as plt @@ -36,34 +40,31 @@ def __init__(self, recording, gt_sorting, params, gt_positions): self.params[key] = 2 def run(self, **job_kwargs): - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates', **self.templates_params) - templates = ext.get_data(outputs='Templates') + ext = sorting_analyzer.compute("fast_templates", **self.templates_params) + templates = ext.get_data(outputs="Templates") ext = sorting_analyzer.compute("spike_locations", **self.params) spikes_locations = ext.get_data(outputs="by_unit") - self.result = {'spikes_locations' : spikes_locations} - self.result['templates'] = templates + self.result = {"spikes_locations": spikes_locations} + self.result["templates"] = templates def compute_result(self, **result_params): errors = {} for unit_ind, unit_id in enumerate(self.gt_sorting.unit_ids): - data = self.result['spikes_locations'][0][unit_id] + data = self.result["spikes_locations"][0][unit_id] errors[unit_id] = np.sqrt( (data["x"] - self.gt_positions[unit_ind, 0]) ** 2 + (data["y"] - self.gt_positions[unit_ind, 1]) ** 2 ) - self.result['medians_over_templates'] = np.array( + self.result["medians_over_templates"] = np.array( [np.median(errors[unit_id]) for unit_id in self.gt_sorting.unit_ids] ) - self.result['mads_over_templates'] = np.array( - [ - np.median(np.abs(errors[unit_id] - np.median(errors[unit_id]))) - for unit_id in self.gt_sorting.unit_ids - ] + self.result["mads_over_templates"] = np.array( + [np.median(np.abs(errors[unit_id] - np.median(errors[unit_id]))) for unit_id in self.gt_sorting.unit_ids] ) - self.result['errors'] = errors + self.result["errors"] = errors _run_key_saved = [ ("spikes_locations", "pickle"), @@ -95,15 +96,14 @@ def plot_comparison_positions(self, case_keys=None): fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) - for count, key in enumerate(case_keys): analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() snrs = metrics["snr"].values result = self.get_result(key) - norms = np.linalg.norm(result['templates'].templates_array, axis=(1, 2)) + norms = np.linalg.norm(result["templates"].templates_array, axis=(1, 2)) - coordinates = self.benchmarks[key].gt_positions[:, :2].copy() + coordinates = self.benchmarks[key].gt_positions[:, :2].copy() coordinates[:, 0] -= coordinates[:, 0].mean() coordinates[:, 1] -= coordinates[:, 1].mean() distances_to_center = np.linalg.norm(coordinates, axis=1) @@ -114,9 +114,7 @@ def plot_comparison_positions(self, case_keys=None): data = result["medians_over_templates"] - axs[0].plot( - snrs[wdx], data[wdx], lw=2, label=self.cases[key]['label'] - ) + axs[0].plot(snrs[wdx], data[wdx], lw=2, label=self.cases[key]["label"]) ymin = (data - result["mads_over_templates"])[wdx] ymax = (data + result["mads_over_templates"])[wdx] @@ -128,7 +126,7 @@ def plot_comparison_positions(self, case_keys=None): distances_to_center[zdx], data[zdx], lw=2, - label=self.cases[key]['label'], + label=self.cases[key]["label"], ) ymin = (data - result["mads_over_templates"])[zdx] ymax = (data + result["mads_over_templates"])[zdx] @@ -139,14 +137,14 @@ def plot_comparison_positions(self, case_keys=None): x_means = [] x_stds = [] for count, key in enumerate(case_keys): - result = self.get_result(key)['medians_over_templates'] + result = self.get_result(key)["medians_over_templates"] x_means += [result.mean()] x_stds += [result.std()] y_means = [] y_stds = [] for count, key in enumerate(case_keys): - result = self.get_result(key)['mads_over_templates'] + result = self.get_result(key)["mads_over_templates"] y_means += [result.mean()] y_stds += [result.std()] @@ -161,15 +159,14 @@ def plot_comparison_positions(self, case_keys=None): axs[1].legend() - class UnitLocalizationBenchmark(Benchmark): def __init__(self, recording, gt_sorting, params, gt_positions): self.recording = recording self.gt_sorting = gt_sorting self.gt_positions = gt_positions - self.method = params['method'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.method_kwargs = params["method_kwargs"] self.result = {} self.waveforms_params = {} for key in ["ms_before", "ms_after"]: @@ -179,11 +176,11 @@ def __init__(self, recording, gt_sorting, params, gt_positions): self.waveforms_params[key] = 2 def run(self, **job_kwargs): - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory') + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory") sorting_analyzer.compute("random_spikes") - sorting_analyzer.compute('waveforms', **self.waveforms_params, **job_kwargs) - ext = sorting_analyzer.compute('templates') - templates = ext.get_data(outputs='Templates') + sorting_analyzer.compute("waveforms", **self.waveforms_params, **job_kwargs) + ext = sorting_analyzer.compute("templates") + templates = ext.get_data(outputs="Templates") if self.method == "center_of_mass": unit_locations = compute_center_of_mass(sorting_analyzer, **self.method_kwargs) @@ -192,23 +189,21 @@ def run(self, **job_kwargs): elif self.method == "grid_convolution": unit_locations = compute_grid_convolution(sorting_analyzer, **self.method_kwargs) - if (unit_locations.shape[1] == 2): + if unit_locations.shape[1] == 2: unit_locations = np.hstack((unit_locations, np.zeros((len(unit_locations), 1)))) - - self.result = {'unit_locations' : unit_locations} - self.result['templates'] = templates + + self.result = {"unit_locations": unit_locations} + self.result["templates"] = templates def compute_result(self, **result_params): - errors = np.linalg.norm(self.gt_positions[:, :2] - self.result['unit_locations'][:, :2], axis=1) - self.result['errors'] = errors - + errors = np.linalg.norm(self.gt_positions[:, :2] - self.result["unit_locations"][:, :2], axis=1) + self.result["errors"] = errors + _run_key_saved = [ ("unit_locations", "npy"), ("templates", "zarr_templates"), ] - _result_key_saved = [ - ("errors", "npy") - ] + _result_key_saved = [("errors", "npy")] class UnitLocalizationStudy(BenchmarkStudy): @@ -229,21 +224,21 @@ def plot_template_errors(self, case_keys=None): case_keys = list(self.cases.keys()) fig, axs = plt.subplots(ncols=1, nrows=1, figsize=(15, 5)) from spikeinterface.widgets import plot_probe_map - #plot_probe_map(self.benchmarks[case_keys[0]].recording, ax=axs) + + # plot_probe_map(self.benchmarks[case_keys[0]].recording, ax=axs) axs.scatter(self.gt_positions[:, 0], self.gt_positions[:, 1], c=np.arange(len(self.gt_positions)), cmap="jet") - + for count, key in enumerate(case_keys): result = self.get_result(key) axs.scatter( - result['unit_locations'][:, 0], - result['unit_locations'][:, 1], - c=f'C{count}', + result["unit_locations"][:, 0], + result["unit_locations"][:, 1], + c=f"C{count}", marker="v", - label=self.cases[key]['label'] + label=self.cases[key]["label"], ) axs.legend() - def plot_comparison_positions(self, case_keys=None): if case_keys is None: @@ -253,12 +248,12 @@ def plot_comparison_positions(self, case_keys=None): for count, key in enumerate(case_keys): analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() snrs = metrics["snr"].values result = self.get_result(key) - norms = np.linalg.norm(result['templates'].templates_array, axis=(1, 2)) + norms = np.linalg.norm(result["templates"].templates_array, axis=(1, 2)) - coordinates = self.benchmarks[key].gt_positions[:, :2].copy() + coordinates = self.benchmarks[key].gt_positions[:, :2].copy() coordinates[:, 0] -= coordinates[:, 0].mean() coordinates[:, 1] -= coordinates[:, 1].mean() distances_to_center = np.linalg.norm(coordinates, axis=1) @@ -266,14 +261,13 @@ def plot_comparison_positions(self, case_keys=None): idx = np.argsort(norms) from scipy.signal import savgol_filter + wdx = np.argsort(snrs) data = result["errors"] - axs[0].plot( - snrs[wdx], data[wdx], lw=2, label=self.cases[key]['label'] - ) - + axs[0].plot(snrs[wdx], data[wdx], lw=2, label=self.cases[key]["label"]) + axs[0].set_xlabel("snr") axs[0].set_ylabel("error (um)") @@ -281,7 +275,7 @@ def plot_comparison_positions(self, case_keys=None): distances_to_center[zdx], data[zdx], lw=2, - label=self.cases[key]['label'], + label=self.cases[key]["label"], ) axs[1].legend() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1f97f0a0c6..a51c8c8145 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -39,9 +39,9 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.gt_sorting = gt_sorting self.indices = indices - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') + ext = sorting_analyzer.compute("fast_templates") extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) @@ -50,55 +50,57 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.peaks = peaks[self.indices] self.params = params self.exhaustive_gt = exhaustive_gt - self.method = params['method'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.method_kwargs = params["method_kwargs"] self.result = {} - - def run(self, **job_kwargs): + + def run(self, **job_kwargs): labels, peak_labels = find_cluster_from_peaks( - self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs + self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs ) - self.result['peak_labels'] = peak_labels + self.result["peak_labels"] = peak_labels def compute_result(self, **result_params): - self.noise = self.result['peak_labels'] < 0 + self.noise = self.result["peak_labels"] < 0 spikes = self.gt_sorting.to_spike_vector() - self.result['sliced_gt_sorting'] = NumpySorting(spikes[self.indices], - self.recording.sampling_frequency, - self.gt_sorting.unit_ids) + self.result["sliced_gt_sorting"] = NumpySorting( + spikes[self.indices], self.recording.sampling_frequency, self.gt_sorting.unit_ids + ) data = spikes[self.indices][~self.noise] - data["unit_index"] = self.result['peak_labels'][~self.noise] + data["unit_index"] = self.result["peak_labels"][~self.noise] - self.result['clustering'] = NumpySorting.from_times_labels(data["sample_index"], - self.result['peak_labels'][~self.noise], - self.recording.sampling_frequency) - - self.result['gt_comparison'] = GroundTruthComparison(self.result['sliced_gt_sorting'], - self.result['clustering'], - exhaustive_gt=self.exhaustive_gt) + self.result["clustering"] = NumpySorting.from_times_labels( + data["sample_index"], self.result["peak_labels"][~self.noise], self.recording.sampling_frequency + ) + + self.result["gt_comparison"] = GroundTruthComparison( + self.result["sliced_gt_sorting"], self.result["clustering"], exhaustive_gt=self.exhaustive_gt + ) - sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates") - sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["clustering"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['clustering_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["clustering_templates"] = ext.get_data(outputs="Templates") - _run_key_saved = [ - ("peak_labels", "npy") - ] + _run_key_saved = [("peak_labels", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), ("sliced_gt_sorting", "sorting"), ("clustering", "sorting"), ("sliced_gt_templates", "zarr_templates"), - ("clustering_templates", "zarr_templates") + ("clustering_templates", "zarr_templates"), ] @@ -115,8 +117,6 @@ def create_benchmark(self, key): return benchmark - - # class BenchmarkPeakSelection: # def __init__(self, recording, gt_sorting, exhaustive_gt=True, job_kwargs={}, tmp_folder=None, verbose=True): # self.verbose = verbose diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 8a239453e5..5f23fab255 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -17,6 +17,7 @@ _key_separator = "_-°°-_" + class BenchmarkStudy: """ Generic study for sorting components. @@ -31,7 +32,9 @@ class BenchmarkStudy: """ + benchmark_class = None + def __init__(self, study_folder): self.folder = Path(study_folder) self.datasets = {} @@ -147,7 +150,7 @@ def remove_benchmark(self, key): if result_folder.exists(): shutil.rmtree(result_folder) - for f in (log_file, ): + for f in (log_file,): if f.exists(): f.unlink() self.benchmarks[key] = None @@ -178,17 +181,17 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): benchmark.save_run(bench_folder) benchmark.result["run_time"] = float(t1 - t0) benchmark.save_main(bench_folder) - + def get_run_times(self, case_keys=None): if case_keys is None: case_keys = list(self.cases.keys()) - + run_times = {} for key in case_keys: benchmark = self.benchmarks[key] assert benchmark is not None run_times[key] = benchmark.result["run_time"] - + df = pd.DataFrame(dict(run_times=run_times)) if not isinstance(self.levels, str): df.index.names = self.levels @@ -199,9 +202,7 @@ def plot_run_times(self, case_keys=None): case_keys = list(self.cases.keys()) run_times = self.get_run_times(case_keys=case_keys) - run_times.plot(kind='bar') - - + run_times.plot(kind="bar") def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: @@ -264,7 +265,7 @@ def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], f else: continue sorting_analyzer = self.get_sorting_analyzer(key) - qm_ext = sorting_analyzer.compute("quality_metrics", metric_names=metric_names) + qm_ext = sorting_analyzer.compute("quality_metrics", metric_names=metric_names) metrics = qm_ext.get_data() metrics.to_csv(filename, sep="\t", index=True) @@ -285,16 +286,16 @@ def get_metrics(self, key): def get_units_snr(self, key): """ """ return self.get_metrics(key)["snr"] - + def get_result(self, key): return self.benchmarks[key].result - class Benchmark: """ Responsible to make a unique run() and compute_result() for one case. """ + def __init__(self): self.result = {} @@ -310,14 +311,14 @@ def _save_keys(self, saved_keys, folder): for k, format in saved_keys: if format == "npy": np.save(folder / f"{k}.npy", self.result[k]) - elif format =="pickle": - with open(folder / f"{k}.pickle", mode="wb") as f: + elif format == "pickle": + with open(folder / f"{k}.pickle", mode="wb") as f: pickle.dump(self.result[k], f) - elif format == 'sorting': - self.result[k].save(folder = folder / k, format="numpy_folder") - elif format == 'zarr_templates': + elif format == "sorting": + self.result[k].save(folder=folder / k, format="numpy_folder") + elif format == "zarr_templates": self.result[k].to_zarr(folder / k) - elif format == 'sorting_analyzer': + elif format == "sorting_analyzer": pass else: raise ValueError(f"Save error {k} {format}") @@ -328,7 +329,7 @@ def save_main(self, folder): def save_run(self, folder): self._save_keys(self._run_key_saved, folder) - + def save_result(self, folder): self._save_keys(self._result_key_saved, folder) @@ -340,20 +341,22 @@ def load_folder(cls, folder): file = folder / f"{k}.npy" if file.exists(): result[k] = np.load(file) - elif format =="pickle": + elif format == "pickle": file = folder / f"{k}.pickle" if file.exists(): with open(file, mode="rb") as f: result[k] = pickle.load(f) - elif format =="sorting": + elif format == "sorting": from spikeinterface.core import load_extractor + result[k] = load_extractor(folder / k) - elif format =="zarr_templates": + elif format == "zarr_templates": from spikeinterface.core.template import Templates + result[k] = Templates.from_zarr(folder / k) return result - + def run(self): # run method raise NotImplementedError diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py index 091ab0820e..eb94b553a2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py @@ -3,6 +3,7 @@ This is not tested on github because not relevant at all. This only a local testing. """ + import pytest from pathlib import Path import os @@ -18,11 +19,7 @@ NoiseGeneratorRecording, ) from spikeinterface.core.generate import generate_unit_locations -from spikeinterface.generation import ( - DriftingTemplates, - make_linear_displacement, - InjectDriftingTemplatesRecording -) +from spikeinterface.generation import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording from probeinterface import generate_multi_columns_probe @@ -37,7 +34,6 @@ cache_folder = Path("cache_folder") / "sortingcomponents_benchmark" - def make_dataset(): recording, gt_sorting = generate_ground_truth_recording( durations=[60.0], @@ -57,26 +53,31 @@ def make_dataset(): ) return recording, gt_sorting -def compute_gt_templates(recording, gt_sorting, ms_before=2., ms_after=3., return_scaled=False, **job_kwargs): - spikes = gt_sorting.to_spike_vector()#[spike_indices] + +def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs): + spikes = gt_sorting.to_spike_vector() # [spike_indices] fs = recording.sampling_frequency nbefore = int(ms_before * fs / 1000) nafter = int(ms_after * fs / 1000) templates_array = estimate_templates( - recording, spikes, - gt_sorting.unit_ids, nbefore, nafter, return_scaled=return_scaled, + recording, + spikes, + gt_sorting.unit_ids, + nbefore, + nafter, + return_scaled=return_scaled, **job_kwargs, ) - + gt_templates = Templates( - templates_array=templates_array, - sampling_frequency=fs, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=gt_sorting.unit_ids, - probe=recording.get_probe(), - ) + templates_array=templates_array, + sampling_frequency=fs, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=gt_sorting.unit_ids, + probe=recording.get_probe(), + ) return gt_templates @@ -84,11 +85,10 @@ def make_drifting_dataset(): num_units = 15 duration = 125.5 - sampling_frequency = 30000. - ms_before = 1. - ms_after = 3. - displacement_sampling_frequency = 5. - + sampling_frequency = 30000.0 + ms_before = 1.0 + ms_after = 3.0 + displacement_sampling_frequency = 5.0 probe = generate_multi_columns_probe( num_columns=3, @@ -100,8 +100,6 @@ def make_drifting_dataset(): ) probe.set_device_channel_indices(np.arange(probe.contact_ids.size)) - - channel_locations = probe.contact_positions unit_locations = generate_unit_locations( @@ -116,9 +114,7 @@ def make_drifting_dataset(): seed=None, ) - - - nbefore = int(sampling_frequency * ms_before / 1000.) + nbefore = int(sampling_frequency * ms_before / 1000.0) generate_kwargs = dict( sampling_frequency=sampling_frequency, @@ -130,12 +126,9 @@ def make_drifting_dataset(): repolarization_ms=np.ones(num_units) * 0.8, ), unit_params_range=dict( - alpha=(4_000., 8_000.), + alpha=(4_000.0, 8_000.0), depolarization_ms=(0.09, 0.16), - ), - - ) templates_array = generate_templates(channel_locations, unit_locations, **generate_kwargs) @@ -149,28 +142,27 @@ def make_drifting_dataset(): drifting_templates = DriftingTemplates.from_static(templates) channel_locations = probe.contact_positions - start = np.array([0, -15.]) + start = np.array([0, -15.0]) stop = np.array([0, 12]) displacements = make_linear_displacement(start, stop, num_step=29) - sorting = generate_sorting( num_units=num_units, sampling_frequency=sampling_frequency, - durations = [duration,], - firing_rates=25.) + durations=[ + duration, + ], + firing_rates=25.0, + ) sorting - - - times = np.arange(0, duration, 1 / displacement_sampling_frequency) times # 2 rythm mid = (start + stop) / 2 freq0 = 0.1 - displacement_vector0 = np.sin(2 * np.pi * freq0 *times)[:, np.newaxis] * (start - stop) + mid + displacement_vector0 = np.sin(2 * np.pi * freq0 * times)[:, np.newaxis] * (start - stop) + mid # freq1 = 0.01 # displacement_vector1 = 0.2 * np.sin(2 * np.pi * freq1 *times)[:, np.newaxis] * (start - stop) + mid @@ -183,7 +175,6 @@ def make_drifting_dataset(): displacement_unit_factor = np.zeros((num_units, num_motion)) displacement_unit_factor[:, 0] = 1 - drifting_templates.precompute_displacements(displacements) direction = 1 @@ -196,7 +187,7 @@ def make_drifting_dataset(): num_channels=probe.contact_ids.size, sampling_frequency=sampling_frequency, durations=[duration], - noise_level=1., + noise_level=1.0, dtype="float32", ) @@ -207,7 +198,7 @@ def make_drifting_dataset(): displacement_vectors=[displacement_vectors], displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=displacement_unit_factor, - num_samples=[int(duration*sampling_frequency)], + num_samples=[int(duration * sampling_frequency)], amplitude_factor=None, ) @@ -218,19 +209,23 @@ def make_drifting_dataset(): displacement_vectors=[displacement_vectors], displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=np.zeros_like(displacement_unit_factor), - num_samples=[int(duration*sampling_frequency)], + num_samples=[int(duration * sampling_frequency)], amplitude_factor=None, ) - my_dict = _variable_from_namespace([ - drifting_rec, - static_rec, - sorting, - displacement_vectors, - displacement_sampling_frequency, - unit_locations, displacement_unit_factor, - unit_displacements - ], locals()) + my_dict = _variable_from_namespace( + [ + drifting_rec, + static_rec, + sorting, + displacement_vectors, + displacement_sampling_frequency, + unit_locations, + displacement_unit_factor, + unit_displacements, + ], + locals(), + ) return my_dict @@ -241,5 +236,3 @@ def _variable_from_namespace(objs, namespace): if namespace[name] is obj: d[name] = obj return d - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py index b60fb963fd..9e8f0e7404 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py @@ -12,10 +12,9 @@ from spikeinterface.sortingcomponents.benchmark.benchmark_clustering import ClusteringStudy - @pytest.mark.skip() def test_benchmark_clustering(): - + job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") recording, gt_sorting = make_dataset() @@ -23,19 +22,18 @@ def test_benchmark_clustering(): num_spikes = gt_sorting.to_spike_vector().size spike_indices = np.arange(0, num_spikes, 5) - # create study - study_folder = cache_folder / 'study_clustering' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_clustering" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for method in ['random_projections', 'circus']: + for method in ["random_projections", "circus"]: cases[method] = { "label": f"{method} on toy", "dataset": "toy", - "init_kwargs": {'indices' : spike_indices}, - "params" : {"method" : method, "method_kwargs" : {}}, + "init_kwargs": {"indices": spike_indices}, + "params": {"method": method, "method_kwargs": {}}, } - + if study_folder.exists(): shutil.rmtree(study_folder) study = ClusteringStudy.create(study_folder, datasets=datasets, cases=cases) @@ -45,7 +43,6 @@ def test_benchmark_clustering(): study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() - study = ClusteringStudy(study_folder) # run and result @@ -64,8 +61,5 @@ def test_benchmark_clustering(): plt.show() - if __name__ == "__main__": test_benchmark_clustering() - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index 2af8bff1e5..805f5d8327 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -12,33 +12,41 @@ compute_sparsity, ) -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder, compute_gt_templates +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( + make_dataset, + cache_folder, + compute_gt_templates, +) from spikeinterface.sortingcomponents.benchmark.benchmark_matching import MatchingStudy @pytest.mark.skip() def test_benchmark_matching(): - + job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") recording, gt_sorting = make_dataset() # templates sparse - gt_templates = compute_gt_templates(recording, gt_sorting, ms_before=2., ms_after=3., return_scaled=False, **job_kwargs) + gt_templates = compute_gt_templates( + recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs + ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity(gt_templates, noise_levels, method='ptp', threshold=0.25) + sparsity = compute_sparsity(gt_templates, noise_levels, method="ptp", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) - # create study - study_folder = cache_folder / 'study_matching' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_matching" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for engine in ['wobble', 'circus-omp-svd',]: + for engine in [ + "wobble", + "circus-omp-svd", + ]: cases[engine] = { "label": f"{engine} on toy", "dataset": "toy", - "params" : {"method" : engine, "method_kwargs" : {"templates" : gt_templates}}, + "params": {"method": engine, "method_kwargs": {"templates": gt_templates}}, } if study_folder.exists(): shutil.rmtree(study_folder) @@ -66,5 +74,3 @@ def test_benchmark_matching(): if __name__ == "__main__": test_benchmark_matching() - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index 0f009afa9a..7f24c07d3d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -7,16 +7,19 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_drifting_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( + make_drifting_dataset, + cache_folder, +) from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import MotionEstimationStudy + @pytest.mark.skip() def test_benchmark_motion_estimaton(): job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") - data = make_drifting_dataset() datasets = { @@ -25,39 +28,38 @@ def test_benchmark_motion_estimaton(): cases = {} for label, loc_method, est_method in [ - ("COM + KS", "center_of_mass", "iterative_template"), - ("Grid + Dec", "grid_convolution", "decentralized"), - ]: + ("COM + KS", "center_of_mass", "iterative_template"), + ("Grid + Dec", "grid_convolution", "decentralized"), + ]: cases[label] = dict( - label = label, + label=label, dataset="drifting_rec", init_kwargs=dict( unit_locations=data["unit_locations"], unit_displacements=data["unit_displacements"], displacement_sampling_frequency=data["displacement_sampling_frequency"], - direction="y" + direction="y", ), params=dict( - detect_kwargs=dict(method="locally_exclusive", detect_threshold=10.), + detect_kwargs=dict(method="locally_exclusive", detect_threshold=10.0), select_kwargs=None, localize_kwargs=dict(method=loc_method), estimate_motion_kwargs=dict( method=est_method, - bin_duration_s=1., - bin_um=5., + bin_duration_s=1.0, + bin_um=5.0, rigid=False, - win_step_um=50., - win_sigma_um=200., + win_step_um=50.0, + win_sigma_um=200.0, ), - ) + ), ) - study_folder = cache_folder / 'study_motion_estimation' + study_folder = cache_folder / "study_motion_estimation" if study_folder.exists(): shutil.rmtree(study_folder) study = MotionEstimationStudy.create(study_folder, datasets, cases) - # run and result study.run(**job_kwargs) study.compute_results() @@ -76,6 +78,3 @@ def test_benchmark_motion_estimaton(): if __name__ == "__main__": test_benchmark_motion_estimaton() - - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index cb8cc50b68..924b9ef385 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -10,10 +10,16 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_drifting_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( + make_drifting_dataset, + cache_folder, +) from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import get_unit_disclacement, get_gt_motion_from_unit_discplacement +from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import ( + get_unit_disclacement, + get_gt_motion_from_unit_discplacement, +) @pytest.mark.skip() @@ -21,7 +27,6 @@ def test_benchmark_motion_interpolation(): job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") - data = make_drifting_dataset() datasets = { @@ -31,33 +36,31 @@ def test_benchmark_motion_interpolation(): duration = data["drifting_rec"].get_duration() channel_locations = data["drifting_rec"].get_channel_locations() - - - unit_displacements = get_unit_disclacement(data["displacement_vectors"], data["displacement_unit_factor"], direction_dim=1) + unit_displacements = get_unit_disclacement( + data["displacement_vectors"], data["displacement_unit_factor"], direction_dim=1 + ) bin_s = 1 temporal_bins = np.arange(0, duration, bin_s) - spatial_bins = np.linspace(np.min(channel_locations[:, 1]), - np.max(channel_locations[:, 1]), - 10 - ) + spatial_bins = np.linspace(np.min(channel_locations[:, 1]), np.max(channel_locations[:, 1]), 10) print(spatial_bins) gt_motion = get_gt_motion_from_unit_discplacement( - unit_displacements, data["displacement_sampling_frequency"], + unit_displacements, + data["displacement_sampling_frequency"], data["unit_locations"], - temporal_bins, spatial_bins, - direction_dim=1 + temporal_bins, + spatial_bins, + direction_dim=1, ) # fig, ax = plt.subplots() # ax.imshow(gt_motion.T) # plt.show() - cases = {} - bin_duration_s = 1. + bin_duration_s = 1.0 cases["static_SC2"] = dict( - label = "No drift - no correction - SC2", + label="No drift - no correction - SC2", dataset="data_static", init_kwargs=dict( drifting_recording=data["drifting_rec"], @@ -69,11 +72,11 @@ def test_benchmark_motion_interpolation(): recording_source="static", sorter_name="spykingcircus2", sorter_params=dict(), - ) + ), ) cases["drifting_SC2"] = dict( - label = "Drift - no correction - SC2", + label="Drift - no correction - SC2", dataset="data_static", init_kwargs=dict( drifting_recording=data["drifting_rec"], @@ -85,11 +88,11 @@ def test_benchmark_motion_interpolation(): recording_source="drifting", sorter_name="spykingcircus2", sorter_params=dict(), - ) + ), ) cases["drifting_SC2"] = dict( - label = "Drift - correction with GT - SC2", + label="Drift - correction with GT - SC2", dataset="data_static", init_kwargs=dict( drifting_recording=data["drifting_rec"], @@ -102,10 +105,10 @@ def test_benchmark_motion_interpolation(): sorter_name="spykingcircus2", sorter_params=dict(), correct_motion_kwargs=dict(spatial_interpolation_method="kriging"), - ) + ), ) - study_folder = cache_folder / 'study_motion_interpolation' + study_folder = cache_folder / "study_motion_interpolation" if study_folder.exists(): shutil.rmtree(study_folder) study = MotionInterpolationStudy.create(study_folder, datasets, cases) @@ -114,7 +117,6 @@ def test_benchmark_motion_interpolation(): study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() - # run and result study.run(**job_kwargs) study.compute_results() @@ -135,9 +137,5 @@ def test_benchmark_motion_interpolation(): plt.show() - - if __name__ == "__main__": test_benchmark_motion_interpolation() - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py index a1555b8dba..297ebc1cf2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -22,22 +22,23 @@ def test_benchmark_peak_localization(): recording, gt_sorting = make_dataset() - # create study - study_folder = cache_folder / 'study_peak_localization' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_peak_localization" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for method in ['center_of_mass', 'grid_convolution', 'monopolar_triangulation']: + for method in ["center_of_mass", "grid_convolution", "monopolar_triangulation"]: cases[method] = { "label": f"{method} on toy", "dataset": "toy", - "init_kwargs": {"gt_positions" : gt_sorting.get_property('gt_unit_locations')}, - "params" : {"ms_before" : 2, - "method" : method, - "method_kwargs" : {}, - "spike_retriver_kwargs" : {"channel_from_template" : False}} + "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, + "params": { + "ms_before": 2, + "method": method, + "method_kwargs": {}, + "spike_retriver_kwargs": {"channel_from_template": False}, + }, } - + if study_folder.exists(): shutil.rmtree(study_folder) study = PeakLocalizationStudy.create(study_folder, datasets=datasets, cases=cases) @@ -66,20 +67,22 @@ def test_benchmark_unit_localization(): recording, gt_sorting = make_dataset() # create study - study_folder = cache_folder / 'study_unit_localization' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_unit_localization" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for method in ['center_of_mass', 'grid_convolution', 'monopolar_triangulation']: + for method in ["center_of_mass", "grid_convolution", "monopolar_triangulation"]: cases[method] = { "label": f"{method} on toy", "dataset": "toy", - "init_kwargs": {"gt_positions" : gt_sorting.get_property('gt_unit_locations')}, - "params" : {"ms_before" : 2, - "method" : method, - "method_kwargs" : {}, - "spike_retriver_kwargs" : {"channel_from_template" : False}} + "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, + "params": { + "ms_before": 2, + "method": method, + "method_kwargs": {}, + "spike_retriver_kwargs": {"channel_from_template": False}, + }, } - + if study_folder.exists(): shutil.rmtree(study_folder) study = UnitLocalizationStudy.create(study_folder, datasets=datasets, cases=cases) @@ -101,8 +104,6 @@ def test_benchmark_unit_localization(): plt.show() - if __name__ == "__main__": # test_benchmark_peak_localization() test_benchmark_unit_localization() - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py index 78b59be489..f90a0c56d6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py @@ -13,9 +13,5 @@ def test_benchmark_peak_selection(): pass - - if __name__ == "__main__": test_benchmark_peak_selection() - - diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index c4fce76dad..24b4ca8022 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -61,7 +61,9 @@ def __init__( **backend_kwargs, ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - self.check_extensions(sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"]) + self.check_extensions( + sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"] + ) sorting = sorting_analyzer.sorting if unit_ids is None: @@ -183,10 +185,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): sorting_analyzer = data_plot["sorting_analyzer"] - import spikeinterface_gui - app = spikeinterface_gui.mkQApp() + + app = spikeinterface_gui.mkQApp() win = spikeinterface_gui.MainWindow(sorting_analyzer) win.show() app.exec_() - diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 7f761190f4..2d228d7d5f 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -103,7 +103,12 @@ def setUpClass(cls): print(f"Widgets tests: skipping backends - {cls.skip_backends}") - cls.backend_kwargs = {"matplotlib": {}, "sortingview": {}, "ipywidgets": {"display": False}, "spikeinterface_gui": {}} + cls.backend_kwargs = { + "matplotlib": {}, + "sortingview": {}, + "ipywidgets": {"display": False}, + "spikeinterface_gui": {}, + } cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) From ae274db37f589a80fa143fbb91fe71de7195060c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 27 Feb 2024 09:09:22 +0100 Subject: [PATCH 148/192] black --- examples/how_to/analyse_neuropixels.py | 93 ++++++++----- examples/how_to/get_started.py | 61 ++++----- examples/how_to/handle_drift.py | 80 ++++++----- .../comparison/generate_erroneous_sorting.py | 27 ++-- .../plot_5_comparison_sorter_weaknesses.py | 14 +- .../core/plot_1_recording_extractor.py | 59 +++++---- .../core/plot_2_sorting_extractor.py | 32 ++--- .../core/plot_3_handle_probe_info.py | 13 +- .../core/plot_4_sorting_analyzer.py | 34 ++--- .../plot_5_append_concatenate_segments.py | 16 +-- .../core/plot_6_handle_times.py | 1 + .../extractors/plot_1_read_various_formats.py | 16 +-- .../plot_2_working_with_unscaled_traces.py | 12 +- .../qualitymetrics/plot_3_quality_mertics.py | 23 +++- .../qualitymetrics/plot_4_curation.py | 11 +- .../widgets/plot_1_rec_gallery.py | 10 +- .../widgets/plot_2_sort_gallery.py | 7 +- .../widgets/plot_3_waveforms_gallery.py | 25 ++-- .../widgets/plot_4_peaks_gallery.py | 37 +++--- .../core/analyzer_extension_core.py | 25 ++-- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 14 +- .../tests/test_analyzer_extension_core.py | 7 +- ...forms_extractor_backwards_compatibility.py | 4 +- .../postprocessing/spike_amplitudes.py | 2 +- .../benchmark/benchmark_clustering.py | 124 +++++++++--------- .../benchmark/benchmark_matching.py | 76 +++++------ .../benchmark/benchmark_motion_estimation.py | 80 +++++------ .../benchmark_motion_interpolation.py | 61 +++++---- .../benchmark/benchmark_peak_localization.py | 107 +++++++-------- .../benchmark/benchmark_peak_selection.py | 64 ++++----- .../benchmark/benchmark_tools.py | 45 ++++--- .../tests/common_benchmark_testing.py | 105 +++++++-------- .../tests/test_benchmark_clustering.py | 20 +-- .../tests/test_benchmark_matching.py | 28 ++-- .../tests/test_benchmark_motion_estimation.py | 35 +++-- .../test_benchmark_motion_interpolation.py | 52 ++++---- .../tests/test_benchmark_peak_localization.py | 43 +++--- .../tests/test_benchmark_peak_selection.py | 4 - src/spikeinterface/widgets/sorting_summary.py | 9 +- .../widgets/tests/test_widgets.py | 7 +- 41 files changed, 780 insertions(+), 705 deletions(-) diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index eed05a0ee5..29d19f2331 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -28,9 +28,9 @@ from pathlib import Path # + -base_folder = Path('/mnt/data/sam/DataSpikeSorting/neuropixel_example/') +base_folder = Path("/mnt/data/sam/DataSpikeSorting/neuropixel_example/") -spikeglx_folder = base_folder / 'Rec_1_10_11_2021_g0' +spikeglx_folder = base_folder / "Rec_1_10_11_2021_g0" # - @@ -40,11 +40,11 @@ # We need to specify which one to read: # -stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder) +stream_names, stream_ids = si.get_neo_streams("spikeglx", spikeglx_folder) stream_names # we do not load the sync channel, so the probe is automatically loaded -raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False) +raw_rec = si.read_spikeglx(spikeglx_folder, stream_name="imec0.ap", load_sync_channel=False) raw_rec # we automatically have the probe loaded! @@ -63,10 +63,10 @@ # # + -rec1 = si.highpass_filter(raw_rec, freq_min=400.) +rec1 = si.highpass_filter(raw_rec, freq_min=400.0) bad_channel_ids, channel_labels = si.detect_bad_channels(rec1) rec2 = rec1.remove_channels(bad_channel_ids) -print('bad_channel_ids', bad_channel_ids) +print("bad_channel_ids", bad_channel_ids) rec3 = si.phase_shift(rec2) rec4 = si.common_reference(rec3, operator="median", reference="global") @@ -94,17 +94,23 @@ # here we use a static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) -si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) -si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) -si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) -for i, label in enumerate(('filter', 'cmr', 'final')): +si.plot_traces(rec1, backend="matplotlib", clim=(-50, 50), ax=axs[0]) +si.plot_traces(rec4, backend="matplotlib", clim=(-50, 50), ax=axs[1]) +si.plot_traces(rec, backend="matplotlib", clim=(-50, 50), ax=axs[2]) +for i, label in enumerate(("filter", "cmr", "final")): axs[i].set_title(label) # - # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) -some_chans = rec.channel_ids[[100, 150, 200, ]] -si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) +some_chans = rec.channel_ids[ + [ + 100, + 150, + 200, + ] +] +si.plot_traces({"filter": rec1, "cmr": rec4}, backend="matplotlib", mode="line", ax=ax, channel_ids=some_chans) # ### Should we save the preprocessed data to a binary file? @@ -118,9 +124,9 @@ # Depending on the complexity of the preprocessing chain, this operation can take a while. However, we can make use of the powerful parallelization mechanism of SpikeInterface. # + -job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) +job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True) -rec = rec.save(folder=base_folder / 'preprocess', format='binary', **job_kwargs) +rec = rec.save(folder=base_folder / "preprocess", format="binary", **job_kwargs) # - # our recording now points to the new binary folder @@ -149,7 +155,7 @@ fig, ax = plt.subplots() _ = ax.hist(noise_levels_microV, bins=np.arange(5, 30, 2.5)) -ax.set_xlabel('noise [microV]') +ax.set_xlabel("noise [microV]") # ### Detect and localize peaks # @@ -168,15 +174,16 @@ # + from spikeinterface.sortingcomponents.peak_detection import detect_peaks -job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) -peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, - detect_threshold=5, radius_um=50., **job_kwargs) +job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True) +peaks = detect_peaks( + rec, method="locally_exclusive", noise_levels=noise_levels_int16, detect_threshold=5, radius_um=50.0, **job_kwargs +) peaks # + from spikeinterface.sortingcomponents.peak_localization import localize_peaks -peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) +peak_locations = localize_peaks(rec, peaks, method="center_of_mass", radius_um=50.0, **job_kwargs) # - # ### Check for drift @@ -190,7 +197,7 @@ # check for drift fs = rec.sampling_frequency fig, ax = plt.subplots(figsize=(10, 8)) -ax.scatter(peaks['sample_index'] / fs, peak_locations['y'], color='k', marker='.', alpha=0.002) +ax.scatter(peaks["sample_index"] / fs, peak_locations["y"], color="k", marker=".", alpha=0.002) # + @@ -199,7 +206,7 @@ si.plot_probe_map(rec, ax=ax, with_channel_ids=True) ax.set_ylim(-100, 150) -ax.scatter(peak_locations['x'], peak_locations['y'], color='purple', alpha=0.002) +ax.scatter(peak_locations["x"], peak_locations["y"], color="purple", alpha=0.002) # - # ## Run a spike sorter @@ -222,18 +229,24 @@ # # check default params for kilosort2.5 -si.get_default_sorter_params('kilosort2_5') +si.get_default_sorter_params("kilosort2_5") # + # run kilosort2.5 without drift correction -params_kilosort2_5 = {'do_correction': False} - -sorting = si.run_sorter('kilosort2_5', rec, output_folder=base_folder / 'kilosort2.5_output', - docker_image=True, verbose=True, **params_kilosort2_5) +params_kilosort2_5 = {"do_correction": False} + +sorting = si.run_sorter( + "kilosort2_5", + rec, + output_folder=base_folder / "kilosort2.5_output", + docker_image=True, + verbose=True, + **params_kilosort2_5, +) # - # the results can be read back for future sessions -sorting = si.read_sorter_folder(base_folder / 'kilosort2.5_output') +sorting = si.read_sorter_folder(base_folder / "kilosort2.5_output") # here we have 31 units in our recording sorting @@ -247,16 +260,23 @@ # Note that we use the `sparse=True` option. This option is important because the waveforms will be extracted only for a few channels around the main channel of each unit. This saves tons of disk space and speeds up the waveforms extraction and further processing. # -we = si.extract_waveforms(rec, sorting, folder=base_folder / 'waveforms_kilosort2.5', - sparse=True, max_spikes_per_unit=500, ms_before=1.5,ms_after=2., - **job_kwargs) +we = si.extract_waveforms( + rec, + sorting, + folder=base_folder / "waveforms_kilosort2.5", + sparse=True, + max_spikes_per_unit=500, + ms_before=1.5, + ms_after=2.0, + **job_kwargs, +) # the `WaveformExtractor` contains all information and is persistent on disk print(we) print(we.folder) # the `WaveformExtrator` can be easily loaded back from its folder -we = si.load_waveforms(base_folder / 'waveforms_kilosort2.5') +we = si.load_waveforms(base_folder / "waveforms_kilosort2.5") we # Many additional computations rely on the `WaveformExtractor`. @@ -281,8 +301,9 @@ # # `si.compute_principal_components(waveform_extractor)` -metrics = si.compute_quality_metrics(we, metric_names=['firing_rate', 'presence_ratio', 'snr', - 'isi_violation', 'amplitude_cutoff']) +metrics = si.compute_quality_metrics( + we, metric_names=["firing_rate", "presence_ratio", "snr", "isi_violation", "amplitude_cutoff"] +) metrics # ## Curation using metrics @@ -306,16 +327,16 @@ # # In order to export the final results we need to make a copy of the the waveforms, but only for the selected units (so we can avoid computing them again). -we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / 'waveforms_clean') +we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / "waveforms_clean") we_clean # Then we export figures to a report folder # export spike sorting report to a folder -si.export_report(we_clean, base_folder / 'report', format='png') +si.export_report(we_clean, base_folder / "report", format="png") -we_clean = si.load_waveforms(base_folder / 'waveforms_clean') +we_clean = si.load_waveforms(base_folder / "waveforms_clean") we_clean # And push the results to sortingview webased viewer diff --git a/examples/how_to/get_started.py b/examples/how_to/get_started.py index 329a2b32b0..556153fce5 100644 --- a/examples/how_to/get_started.py +++ b/examples/how_to/get_started.py @@ -79,7 +79,7 @@ # Then we can open it. Note that [MEArec](https://mearec.readthedocs.io>) simulated file # contains both a "recording" and a "sorting" object. -local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') +local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") recording, sorting_true = se.read_mearec(local_path) print(recording) print(sorting_true) @@ -103,10 +103,10 @@ num_chan = recording.get_num_channels() num_seg = recording.get_num_segments() -print('Channel ids:', channel_ids) -print('Sampling frequency:', fs) -print('Number of channels:', num_chan) -print('Number of segments:', num_seg) +print("Channel ids:", channel_ids) +print("Sampling frequency:", fs) +print("Number of channels:", num_chan) +print("Number of segments:", num_seg) # - # ...and from a `BaseSorting` @@ -116,9 +116,9 @@ unit_ids = sorting_true.get_unit_ids() spike_train = sorting_true.get_unit_spike_train(unit_id=unit_ids[0]) -print('Number of segments:', num_seg) -print('Unit ids:', unit_ids) -print('Spike train of first unit:', spike_train) +print("Number of segments:", num_seg) +print("Unit ids:", unit_ids) +print("Spike train of first unit:", spike_train) # - # SpikeInterface internally uses the [`ProbeInterface`](https://probeinterface.readthedocs.io/en/main/) to handle `probeinterface.Probe` and @@ -144,19 +144,19 @@ recording_cmr = recording recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000) print(recording_f) -recording_cmr = si.common_reference(recording_f, reference='global', operator='median') +recording_cmr = si.common_reference(recording_f, reference="global", operator="median") print(recording_cmr) # this computes and saves the recording after applying the preprocessing chain -recording_preprocessed = recording_cmr.save(format='binary') +recording_preprocessed = recording_cmr.save(format="binary") print(recording_preprocessed) # - # Now you are ready to spike sort using the `spikeinterface.sorters` module! # Let's first check which sorters are implemented and which are installed -print('Available sorters', ss.available_sorters()) -print('Installed sorters', ss.installed_sorters()) +print("Available sorters", ss.available_sorters()) +print("Installed sorters", ss.installed_sorters()) # The `ss.installed_sorters()` will list the sorters installed on the machine. # We can see we have HerdingSpikes and Tridesclous installed. @@ -164,9 +164,9 @@ # The available parameters are dictionaries and can be accessed with: print("Tridesclous params:") -pprint(ss.get_default_sorter_params('tridesclous')) +pprint(ss.get_default_sorter_params("tridesclous")) print("SpykingCircus2 params:") -pprint(ss.get_default_sorter_params('spykingcircus2')) +pprint(ss.get_default_sorter_params("spykingcircus2")) # Let's run `tridesclous` and change one of the parameters, say, the `detect_threshold`: @@ -176,12 +176,13 @@ # Alternatively we can pass a full dictionary containing the parameters: # + -other_params = ss.get_default_sorter_params('tridesclous') -other_params['detect_threshold'] = 6 +other_params = ss.get_default_sorter_params("tridesclous") +other_params["detect_threshold"] = 6 # parameters set by params dictionary -sorting_TDC_2 = ss.run_sorter(sorter_name="tridesclous", recording=recording_preprocessed, - output_folder="tdc_output2", **other_params) +sorting_TDC_2 = ss.run_sorter( + sorter_name="tridesclous", recording=recording_preprocessed, output_folder="tdc_output2", **other_params +) print(sorting_TDC_2) # - @@ -192,13 +193,12 @@ # The `sorting_TDC` and `sorting_SC2` are `BaseSorting` objects. We can print the units found using: -print('Units found by tridesclous:', sorting_TDC.get_unit_ids()) -print('Units found by spyking-circus2:', sorting_SC2.get_unit_ids()) +print("Units found by tridesclous:", sorting_TDC.get_unit_ids()) +print("Units found by spyking-circus2:", sorting_SC2.get_unit_ids()) # If a sorter is not installed locally, we can also avoid installing it and run it anyways, using a container (Docker or Singularity). For example, let's run `Kilosort2` using Docker: -sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed, - docker_image=True, verbose=True) +sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed, docker_image=True, verbose=True) print(sorting_KS2) # SpikeInterface provides a efficient way to extract waveforms from paired recording/sorting objects. @@ -206,7 +206,7 @@ # for each unit, extracts their waveforms, and stores them to disk. These waveforms are helpful to compute the average waveform, or "template", for each unit and then to compute, for example, quality metrics. # + -we_TDC = si.extract_waveforms(recording_preprocessed, sorting_TDC, 'waveforms_folder', overwrite=True) +we_TDC = si.extract_waveforms(recording_preprocessed, sorting_TDC, "waveforms_folder", overwrite=True) print(we_TDC) unit_id0 = sorting_TDC.unit_ids[0] @@ -236,7 +236,7 @@ # Importantly, waveform extractors (and all extensions) can be reloaded at later times: -we_loaded = si.load_waveforms('waveforms_folder') +we_loaded = si.load_waveforms("waveforms_folder") print(we_loaded.get_available_extension_names()) # Once we have computed all of the postprocessing information, we can compute quality metrics (different quality metrics require different extensions - e.g., drift metrics require `spike_locations`): @@ -277,21 +277,21 @@ # Alternatively, we can export the data locally to Phy. [Phy]() is a GUI for manual # curation of the spike sorting output. To export to phy you can run: -sexp.export_to_phy(we_TDC, 'phy_folder_for_TDC', verbose=True) +sexp.export_to_phy(we_TDC, "phy_folder_for_TDC", verbose=True) # Then you can run the template-gui with: `phy template-gui phy_folder_for_TDC/params.py` # and manually curate the results. # After curating with Phy, the curated sorting can be reloaded to SpikeInterface. In this case, we exclude the units that have been labeled as "noise": -sorting_curated_phy = se.read_phy('phy_folder_for_TDC', exclude_cluster_groups=["noise"]) +sorting_curated_phy = se.read_phy("phy_folder_for_TDC", exclude_cluster_groups=["noise"]) # Quality metrics can be also used to automatically curate the spike sorting # output. For example, you can select sorted units with a SNR above a # certain threshold: # + -keep_mask = (qm['snr'] > 10) & (qm['isi_violations_ratio'] < 0.01) +keep_mask = (qm["snr"] > 10) & (qm["isi_violations_ratio"] < 0.01) print("Mask:", keep_mask.values) sorting_curated_auto = sorting_TDC.select_units(sorting_TDC.unit_ids[keep_mask]) @@ -310,8 +310,9 @@ comp_gt = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC) comp_pair = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_SC2) -comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_TDC, sorting_SC2, sorting_KS2], - name_list=['tdc', 'sc2', 'ks2']) +comp_multi = sc.compare_multiple_sorters( + sorting_list=[sorting_TDC, sorting_SC2, sorting_KS2], name_list=["tdc", "sc2", "ks2"] +) # When comparing with a ground-truth sorting (1,), you can get the sorting performance and plot a confusion # matrix @@ -335,7 +336,7 @@ # + sorting_agreement = comp_multi.get_agreement_sorting(minimum_agreement_count=2) -print('Units in agreement between TDC, SC2, and KS2:', sorting_agreement.get_unit_ids()) +print("Units in agreement between TDC, SC2, and KS2:", sorting_agreement.get_unit_ids()) w_multi = sw.plot_multicomparison_agreement(comp_multi) w_multi = sw.plot_multicomparison_agreement_by_sorter(comp_multi) diff --git a/examples/how_to/handle_drift.py b/examples/how_to/handle_drift.py index a1671a7424..79a7c899f5 100644 --- a/examples/how_to/handle_drift.py +++ b/examples/how_to/handle_drift.py @@ -54,10 +54,11 @@ import shutil import spikeinterface.full as si + # - -base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick') -dataset_folder = base_folder / 'dataset1/NP1' +base_folder = Path("/mnt/data/sam/DataSpikeSorting/imposed_motion_nick") +dataset_folder = base_folder / "dataset1/NP1" # read the file raw_rec = si.read_spikeglx(dataset_folder) @@ -67,13 +68,16 @@ # We preprocess the recording with bandpass filter and a common median reference. # Note, that it is better to not whiten the recording before motion estimation to get a better estimate of peak locations! + def preprocess_chain(rec): - rec = si.bandpass_filter(rec, freq_min=300., freq_max=6000.) - rec = si.common_reference(rec, reference='global', operator='median') + rec = si.bandpass_filter(rec, freq_min=300.0, freq_max=6000.0) + rec = si.common_reference(rec, reference="global", operator="median") return rec + + rec = preprocess_chain(raw_rec) -job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) +job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True) # ### Run motion correction with one function! # @@ -87,21 +91,22 @@ def preprocess_chain(rec): # internally, we can explore a preset like this # every parameter can be overwritten at runtime from spikeinterface.preprocessing.motion import motion_options_preset -motion_options_preset['kilosort_like'] + +motion_options_preset["kilosort_like"] # lets try theses 3 presets -some_presets = ('rigid_fast', 'kilosort_like', 'nonrigid_accurate') +some_presets = ("rigid_fast", "kilosort_like", "nonrigid_accurate") # some_presets = ('kilosort_like', ) # compute motion with 3 presets for preset in some_presets: - print('Computing with', preset) - folder = base_folder / 'motion_folder_dataset1' / preset + print("Computing with", preset) + folder = base_folder / "motion_folder_dataset1" / preset if folder.exists(): shutil.rmtree(folder) - recording_corrected, motion_info = si.correct_motion(rec, preset=preset, - folder=folder, - output_motion_info=True, **job_kwargs) + recording_corrected, motion_info = si.correct_motion( + rec, preset=preset, folder=folder, output_motion_info=True, **job_kwargs + ) # ### Plot the results # @@ -130,13 +135,19 @@ def preprocess_chain(rec): for preset in some_presets: # load - folder = base_folder / 'motion_folder_dataset1' / preset + folder = base_folder / "motion_folder_dataset1" / preset motion_info = si.load_motion_info(folder) # and plot fig = plt.figure(figsize=(14, 8)) - si.plot_motion(motion_info, figure=fig, depth_lim=(400, 600), - color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) + si.plot_motion( + motion_info, + figure=fig, + depth_lim=(400, 600), + color_amplitude=True, + amplitude_cmap="inferno", + scatter_decimate=10, + ) fig.suptitle(f"{preset=}") @@ -159,7 +170,7 @@ def preprocess_chain(rec): from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks for preset in some_presets: - folder = base_folder / 'motion_folder_dataset1' / preset + folder = base_folder / "motion_folder_dataset1" / preset motion_info = si.load_motion_info(folder) fig, axs = plt.subplots(ncols=2, figsize=(12, 8), sharey=True) @@ -167,29 +178,36 @@ def preprocess_chain(rec): ax = axs[0] si.plot_probe_map(rec, ax=ax) - peaks = motion_info['peaks'] + peaks = motion_info["peaks"] sr = rec.get_sampling_frequency() - time_lim0 = 750. - time_lim1 = 1500. - mask = (peaks['sample_index'] > int(sr * time_lim0)) & (peaks['sample_index'] < int(sr * time_lim1)) + time_lim0 = 750.0 + time_lim1 = 1500.0 + mask = (peaks["sample_index"] > int(sr * time_lim0)) & (peaks["sample_index"] < int(sr * time_lim1)) sl = slice(None, None, 5) - amps = np.abs(peaks['amplitude'][mask][sl]) + amps = np.abs(peaks["amplitude"][mask][sl]) amps /= np.quantile(amps, 0.95) - c = plt.get_cmap('inferno')(amps) + c = plt.get_cmap("inferno")(amps) color_kargs = dict(alpha=0.2, s=2, c=c) - loc = motion_info['peak_locations'] - #color='black', - ax.scatter(loc['x'][mask][sl], loc['y'][mask][sl], **color_kargs) + loc = motion_info["peak_locations"] + # color='black', + ax.scatter(loc["x"][mask][sl], loc["y"][mask][sl], **color_kargs) - loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.sampling_frequency, - motion_info['motion'], motion_info['temporal_bins'], motion_info['spatial_bins'], direction="y") + loc2 = correct_motion_on_peaks( + motion_info["peaks"], + motion_info["peak_locations"], + rec.sampling_frequency, + motion_info["motion"], + motion_info["temporal_bins"], + motion_info["spatial_bins"], + direction="y", + ) ax = axs[1] si.plot_probe_map(rec, ax=ax) # color='black', - ax.scatter(loc2['x'][mask][sl], loc2['y'][mask][sl], **color_kargs) + ax.scatter(loc2["x"][mask][sl], loc2["y"][mask][sl], **color_kargs) ax.set_ylim(400, 600) fig.suptitle(f"{preset=}") @@ -204,16 +222,16 @@ def preprocess_chain(rec): # + run_times = [] for preset in some_presets: - folder = base_folder / 'motion_folder_dataset1' / preset + folder = base_folder / "motion_folder_dataset1" / preset motion_info = si.load_motion_info(folder) - run_times.append(motion_info['run_times']) + run_times.append(motion_info["run_times"]) keys = run_times[0].keys() bottom = np.zeros(len(run_times)) fig, ax = plt.subplots() for k in keys: rtimes = np.array([rt[k] for rt in run_times]) - if np.any(rtimes>0.): + if np.any(rtimes > 0.0): ax.bar(some_presets, rtimes, bottom=bottom, label=k) bottom += rtimes ax.legend() diff --git a/examples/modules_gallery/comparison/generate_erroneous_sorting.py b/examples/modules_gallery/comparison/generate_erroneous_sorting.py index d62a15bdc0..608e23d7f5 100644 --- a/examples/modules_gallery/comparison/generate_erroneous_sorting.py +++ b/examples/modules_gallery/comparison/generate_erroneous_sorting.py @@ -11,6 +11,7 @@ import spikeinterface.comparison as sc import spikeinterface.widgets as sw + def generate_erroneous_sorting(): """ Generate an erroneous spike sorting for illustration purposes. @@ -36,14 +37,13 @@ def generate_erroneous_sorting(): rec, sorting_true = se.toy_example(num_channels=4, num_units=10, duration=10, seed=10, num_segments=1) # artificially remap to one based - sorting_true = sorting_true.select_units(unit_ids=None, - renamed_unit_ids=np.arange(10, dtype='int64')+1) + sorting_true = sorting_true.select_units(unit_ids=None, renamed_unit_ids=np.arange(10, dtype="int64") + 1) sampling_frequency = sorting_true.get_sampling_frequency() units_err = {} - # sorting_true have 10 units + # sorting_true have 10 units np.random.seed(0) # unit 1 2 are perfect @@ -52,16 +52,16 @@ def generate_erroneous_sorting(): units_err[u] = st # unit 3 4 (medium) 10 (low) have medium to low agreement - for u, score in [(3, 0.8), (4, 0.75), (10, 0.3)]: + for u, score in [(3, 0.8), (4, 0.75), (10, 0.3)]: st = sorting_true.get_unit_spike_train(u) - st = np.sort(np.random.choice(st, size=int(st.size*score), replace=False)) + st = np.sort(np.random.choice(st, size=int(st.size * score), replace=False)) units_err[u] = st # unit 5 6 are over merge st5 = sorting_true.get_unit_spike_train(5) st6 = sorting_true.get_unit_spike_train(6) st = np.unique(np.concatenate([st5, st6])) - st = np.sort(np.random.choice(st, size=int(st.size*0.7), replace=False)) + st = np.sort(np.random.choice(st, size=int(st.size * 0.7), replace=False)) units_err[56] = st # unit 7 is over split in 2 part @@ -69,14 +69,14 @@ def generate_erroneous_sorting(): st70 = st7[::2] units_err[70] = st70 st71 = st7[1::2] - st71 = np.sort(np.random.choice(st71, size=int(st71.size*0.9), replace=False)) + st71 = np.sort(np.random.choice(st71, size=int(st71.size * 0.9), replace=False)) units_err[71] = st71 # unit 8 is redundant 3 times st8 = sorting_true.get_unit_spike_train(8) - st80 = np.sort(np.random.choice(st8, size=int(st8.size*0.65), replace=False)) - st81 = np.sort(np.random.choice(st8, size=int(st8.size*0.6), replace=False)) - st82 = np.sort(np.random.choice(st8, size=int(st8.size*0.55), replace=False)) + st80 = np.sort(np.random.choice(st8, size=int(st8.size * 0.65), replace=False)) + st81 = np.sort(np.random.choice(st8, size=int(st8.size * 0.6), replace=False)) + st82 = np.sort(np.random.choice(st8, size=int(st8.size * 0.55), replace=False)) units_err[80] = st80 units_err[81] = st81 units_err[82] = st82 @@ -85,18 +85,15 @@ def generate_erroneous_sorting(): # there are some units that do not exist 15 16 and 17 nframes = rec.get_num_frames(segment_index=0) - for u in [15,16,17]: + for u in [15, 16, 17]: st = np.sort(np.random.randint(0, high=nframes, size=35)) units_err[u] = st sorting_err = se.NumpySorting.from_unit_dict(units_err, sampling_frequency) - return sorting_true, sorting_err - - -if __name__ == '__main__': +if __name__ == "__main__": # just for check sorting_true, sorting_err = generate_erroneous_sorting() comp = sc.compare_sorter_to_ground_truth(sorting_true, sorting_err, exhaustive_gt=True) diff --git a/examples/modules_gallery/comparison/plot_5_comparison_sorter_weaknesses.py b/examples/modules_gallery/comparison/plot_5_comparison_sorter_weaknesses.py index c32c683941..562b174a31 100644 --- a/examples/modules_gallery/comparison/plot_5_comparison_sorter_weaknesses.py +++ b/examples/modules_gallery/comparison/plot_5_comparison_sorter_weaknesses.py @@ -31,7 +31,6 @@ """ - ############################################################################## # Import @@ -55,36 +54,36 @@ ############################################################################## # Here the same matrix but **ordered** -# It is now quite trivial to check that fake injected errors are enlighted here. +# It is now quite trivial to check that fake injected errors are enlighted here. sw.plot_agreement_matrix(comp, ordered=True) ############################################################################## # Here we can see that only Units 1 2 and 3 are well detected with 'accuracy'>0.75 -print('well_detected', comp.get_well_detected_units(well_detected_score=0.75)) +print("well_detected", comp.get_well_detected_units(well_detected_score=0.75)) ############################################################################## # Here we can explore **"false positive units"** units that do not exists in ground truth -print('false_positive', comp.get_false_positive_units(redundant_score=0.2)) +print("false_positive", comp.get_false_positive_units(redundant_score=0.2)) ############################################################################## # Here we can explore **"redundant units"** units that do not exists in ground truth -print('redundant', comp.get_redundant_units(redundant_score=0.2)) +print("redundant", comp.get_redundant_units(redundant_score=0.2)) ############################################################################## # Here we can explore **"overmerged units"** units that do not exists in ground truth -print('overmerged', comp.get_overmerged_units(overmerged_score=0.2)) +print("overmerged", comp.get_overmerged_units(overmerged_score=0.2)) ############################################################################## # Here we can explore **"bad units"** units that a mixed a several possible errors. -print('bad', comp.get_bad_units()) +print("bad", comp.get_bad_units()) ############################################################################## @@ -93,5 +92,4 @@ comp.print_summary(well_detected_score=0.75, redundant_score=0.2, overmerged_score=0.2) - plt.show() diff --git a/examples/modules_gallery/core/plot_1_recording_extractor.py b/examples/modules_gallery/core/plot_1_recording_extractor.py index f5d3ee1db2..aa59abd76d 100644 --- a/examples/modules_gallery/core/plot_1_recording_extractor.py +++ b/examples/modules_gallery/core/plot_1_recording_extractor.py @@ -1,4 +1,4 @@ -''' +""" Recording objects ================= @@ -12,7 +12,8 @@ * saving (caching) -''' +""" + import matplotlib.pyplot as plt import numpy as np @@ -25,8 +26,8 @@ # Let's define the properties of the dataset: num_channels = 7 -sampling_frequency = 30000. # in Hz -durations = [10., 15.] # in s for 2 segments +sampling_frequency = 30000.0 # in Hz +durations = [10.0, 15.0] # in s for 2 segments num_segments = 2 num_timepoints = [int(sampling_frequency * d) for d in durations] @@ -47,11 +48,11 @@ ############################################################################## # We can now print properties that the :code:`RecordingExtractor` retrieves from the underlying recording. -print(f'Number of channels = {recording.get_channel_ids()}') -print(f'Sampling frequency = {recording.get_sampling_frequency()} Hz') -print(f'Number of segments= {recording.get_num_segments()}') -print(f'Number of timepoints in seg0= {recording.get_num_frames(segment_index=0)}') -print(f'Number of timepoints in seg1= {recording.get_num_frames(segment_index=1)}') +print(f"Number of channels = {recording.get_channel_ids()}") +print(f"Sampling frequency = {recording.get_sampling_frequency()} Hz") +print(f"Number of segments= {recording.get_num_segments()}") +print(f"Number of timepoints in seg0= {recording.get_num_frames(segment_index=0)}") +print(f"Number of timepoints in seg1= {recording.get_num_frames(segment_index=1)}") ############################################################################## # The geometry of the Probe is handled with the :probeinterface:`ProbeInterface <>` library. @@ -62,7 +63,7 @@ from probeinterface import generate_linear_probe from probeinterface.plotting import plot_probe -probe = generate_linear_probe(num_elec=7, ypitch=20, contact_shapes='circle', contact_shape_params={'radius': 6}) +probe = generate_linear_probe(num_elec=7, ypitch=20, contact_shapes="circle", contact_shape_params={"radius": 6}) # the probe has to be wired to the recording device (i.e., which electrode corresponds to an entry in the data # matrix) @@ -75,7 +76,7 @@ ############################################################################## # Some extractors also implement a :code:`write` function. -file_paths = ['traces0.raw', 'traces1.raw'] +file_paths = ["traces0.raw", "traces1.raw"] se.BinaryRecordingExtractor.write_recording(recording, file_paths) ############################################################################## @@ -83,7 +84,9 @@ # Note that this new recording is now "on disk" and not "in memory" as the Numpy recording was. # This means that the loading is "lazy" and the data are not loaded into memory. -recording2 = se.BinaryRecordingExtractor(file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=traces0.dtype) +recording2 = se.BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=traces0.dtype +) print(recording2) ############################################################################## @@ -100,38 +103,40 @@ # Internally, a recording has :code:`channel_ids`: that are a vector that can have a # dtype of :code:`int` or :code:`str`: -print('chan_ids (dtype=int):', recording.get_channel_ids()) +print("chan_ids (dtype=int):", recording.get_channel_ids()) -recording3 = se.NumpyRecording(traces_list=[traces0, traces1], - sampling_frequency=sampling_frequency, - channel_ids=['a', 'b', 'c', 'd', 'e', 'f', 'g']) -print('chan_ids (dtype=str):', recording3.get_channel_ids()) +recording3 = se.NumpyRecording( + traces_list=[traces0, traces1], + sampling_frequency=sampling_frequency, + channel_ids=["a", "b", "c", "d", "e", "f", "g"], +) +print("chan_ids (dtype=str):", recording3.get_channel_ids()) ############################################################################## # :code:`channel_ids` are used to retrieve information (e.g. traces) only on a # subset of channels: -traces = recording3.get_traces(segment_index=1, end_frame=50, channel_ids=['a', 'd']) +traces = recording3.get_traces(segment_index=1, end_frame=50, channel_ids=["a", "d"]) print(traces.shape) ############################################################################## # You can also get a recording with a subset of channels (i.e. a channel slice): -recording4 = recording3.channel_slice(channel_ids=['a', 'c', 'e']) +recording4 = recording3.channel_slice(channel_ids=["a", "c", "e"]) print(recording4) print(recording4.get_channel_ids()) # which is equivalent to from spikeinterface import ChannelSliceRecording -recording4 = ChannelSliceRecording(recording3, channel_ids=['a', 'c', 'e']) +recording4 = ChannelSliceRecording(recording3, channel_ids=["a", "c", "e"]) ############################################################################## # Another possibility is to split a recording based on a certain property (e.g. 'group') -recording3.set_property('group', [0, 0, 0, 1, 1, 1, 2]) +recording3.set_property("group", [0, 0, 0, 1, 1, 1, 2]) -recordings = recording3.split_by(property='group') +recordings = recording3.split_by(property="group") print(recordings) print(recordings[0].get_channel_ids()) print(recordings[1].get_channel_ids()) @@ -158,9 +163,9 @@ ############################################################################### # The dictionary can also be dumped directly to a JSON file on disk: -recording2.dump('my_recording.json') +recording2.dump("my_recording.json") -recording2_loaded = load_extractor('my_recording.json') +recording2_loaded = load_extractor("my_recording.json") print(recording2_loaded) ############################################################################### @@ -170,11 +175,11 @@ # :code:`save()` function. This operation is very useful to save traces obtained # after long computations (e.g. filtering or referencing): -recording2.save(folder='./my_recording') +recording2.save(folder="./my_recording") import os -pprint(os.listdir('./my_recording')) +pprint(os.listdir("./my_recording")) -recording2_cached = load_extractor('my_recording.json') +recording2_cached = load_extractor("my_recording.json") print(recording2_cached) diff --git a/examples/modules_gallery/core/plot_2_sorting_extractor.py b/examples/modules_gallery/core/plot_2_sorting_extractor.py index 59acf82712..b572218ed8 100644 --- a/examples/modules_gallery/core/plot_2_sorting_extractor.py +++ b/examples/modules_gallery/core/plot_2_sorting_extractor.py @@ -1,4 +1,4 @@ -''' +""" Sorting objects =============== @@ -11,7 +11,7 @@ * dumping to/loading from dict-json * saving (caching) -''' +""" import numpy as np import spikeinterface.extractors as se @@ -22,8 +22,8 @@ # # Let's define the properties of the dataset: -sampling_frequency = 30000. -duration = 20. +sampling_frequency = 30000.0 +duration = 20.0 num_timepoints = int(sampling_frequency * duration) num_units = 4 num_spikes = 1000 @@ -47,18 +47,18 @@ # We can now print properties that the :code:`SortingExtractor` retrieves from # the underlying sorted dataset. -print('Unit ids = {}'.format(sorting.get_unit_ids())) +print("Unit ids = {}".format(sorting.get_unit_ids())) st = sorting.get_unit_spike_train(unit_id=1, segment_index=0) -print('Num. events for unit 1seg0 = {}'.format(len(st))) +print("Num. events for unit 1seg0 = {}".format(len(st))) st1 = sorting.get_unit_spike_train(unit_id=1, start_frame=0, end_frame=30000, segment_index=1) -print('Num. events for first second of unit 1 seg1 = {}'.format(len(st1))) +print("Num. events for first second of unit 1 seg1 = {}".format(len(st1))) ############################################################################## # Some extractors also implement a :code:`write` function. We can for example # save our newly created sorting object to NPZ format (a simple format based # on numpy used in :code:`spikeinterface`): -file_path = 'my_sorting.npz' +file_path = "my_sorting.npz" se.NpzSortingExtractor.write_sorting(sorting, file_path) ############################################################################## @@ -76,9 +76,9 @@ for unit_id in sorting2.get_unit_ids(): st = sorting2.get_unit_spike_train(unit_id=unit_id, segment_index=0) firing_rates.append(st.size / duration) -sorting2.set_property('firing_rate', firing_rates) +sorting2.set_property("firing_rate", firing_rates) -print(sorting2.get_property('firing_rate')) +print(sorting2.get_property("firing_rate")) ############################################################################## # You can also get a a sorting with a subset of unit. Properties are @@ -87,7 +87,7 @@ sorting3 = sorting2.select_units(unit_ids=[1, 4]) print(sorting3) -print(sorting3.get_property('firing_rate')) +print(sorting3.get_property("firing_rate")) # which is equivalent to from spikeinterface import UnitsSelectionSorting @@ -115,9 +115,9 @@ ############################################################################### # The dictionary can also be dumped directly to a JSON file on disk: -sorting2.dump('my_sorting.json') +sorting2.dump("my_sorting.json") -sorting2_loaded = load_extractor('my_sorting.json') +sorting2_loaded = load_extractor("my_sorting.json") print(sorting2_loaded) ############################################################################### @@ -127,11 +127,11 @@ # :code:`save()` function: -sorting2.save(folder='./my_sorting') +sorting2.save(folder="./my_sorting") import os -pprint(os.listdir('./my_sorting')) +pprint(os.listdir("./my_sorting")) -sorting2_cached = load_extractor('./my_sorting') +sorting2_cached = load_extractor("./my_sorting") print(sorting2_cached) diff --git a/examples/modules_gallery/core/plot_3_handle_probe_info.py b/examples/modules_gallery/core/plot_3_handle_probe_info.py index d134b29ec5..75b2b56be8 100644 --- a/examples/modules_gallery/core/plot_3_handle_probe_info.py +++ b/examples/modules_gallery/core/plot_3_handle_probe_info.py @@ -1,4 +1,4 @@ -''' +""" Handling probe information =========================== @@ -10,7 +10,8 @@ manually. Here's how! -''' +""" + import numpy as np import spikeinterface.extractors as se @@ -38,11 +39,11 @@ from probeinterface import get_probe -other_probe = get_probe(manufacturer='cambridgeneurotech', probe_name='ASSY-37-E-1') +other_probe = get_probe(manufacturer="cambridgeneurotech", probe_name="ASSY-37-E-1") print(other_probe) other_probe.set_device_channel_indices(np.arange(32)) -recording_2_shanks = recording.set_probe(other_probe, group_mode='by_shank') +recording_2_shanks = recording.set_probe(other_probe, group_mode="by_shank") plot_probe(recording_2_shanks.get_probe()) ############################################################################### @@ -51,9 +52,9 @@ # We can use this information to split the recording into two sub-recordings: print(recording_2_shanks) -print(recording_2_shanks.get_property('group')) +print(recording_2_shanks.get_property("group")) -rec0, rec1 = recording_2_shanks.split_by(property='group') +rec0, rec1 = recording_2_shanks.split_by(property="group") print(rec0) print(rec1) diff --git a/examples/modules_gallery/core/plot_4_sorting_analyzer.py b/examples/modules_gallery/core/plot_4_sorting_analyzer.py index 593b1103e1..864f11ad1d 100644 --- a/examples/modules_gallery/core/plot_4_sorting_analyzer.py +++ b/examples/modules_gallery/core/plot_4_sorting_analyzer.py @@ -1,4 +1,4 @@ -''' +""" SortingAnalyzer =============== @@ -23,7 +23,8 @@ Here the how! -''' +""" + import matplotlib.pyplot as plt from spikeinterface import download_dataset @@ -35,8 +36,8 @@ # to download a MEArec dataset. It is a simulated dataset that contains "ground truth" # sorting information: -repo = 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' -remote_path = 'mearec/mearec_test_10s.h5' +repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data" +remote_path = "mearec/mearec_test_10s.h5" local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None) ############################################################################## @@ -84,19 +85,23 @@ # extract waveforms (sparse in this examples) and compute templates. # You can see that printing the object indicate which extension are computed yet. -analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500,) +analyzer.compute( + "random_spikes", + method="uniform", + max_spikes_per_unit=500, +) analyzer.compute("waveforms", ms_before=1.0, ms_after=2.0, return_scaled=True) analyzer.compute("templates", operators=["average", "median", "std"]) print(analyzer) - ############################################################################### -# To speed up computation, some steps like ""waveforms" can also be extracted +# To speed up computation, some steps like ""waveforms" can also be extracted # using parallel processing (recommended!). Like this -analyzer.compute("waveforms", ms_before=1.0, ms_after=2.0, return_scaled=True, - n_jobs=8, chunk_duration="1s", progress_bar=True) +analyzer.compute( + "waveforms", ms_before=1.0, ms_after=2.0, return_scaled=True, n_jobs=8, chunk_duration="1s", progress_bar=True +) # which is equivalent of this job_kwargs = dict(n_jobs=8, chunk_duration="1s", progress_bar=True) @@ -111,7 +116,7 @@ ext_wf = analyzer.get_extension("waveforms") for unit_id in analyzer.unit_ids: wfs = ext_wf.get_waveforms_one_unit(unit_id) - print(unit_id, ':', wfs.shape) + print(unit_id, ":", wfs.shape) ############################################################################### # Same for the "templates" extension. Here we can get all templates at once @@ -128,7 +133,6 @@ print(median_templates.shape) - ############################################################################### # This can be plot easily. @@ -136,7 +140,7 @@ fig, ax = plt.subplots() template = av_templates[unit_index] ax.plot(template) - ax.set_title(f'{unit_id}') + ax.set_title(f"{unit_id}") ############################################################################### @@ -150,10 +154,10 @@ # The SortingAnalyzer offer also select_units() method wich allows to export # only some relevant units for instance to a new SortingAnalyzer instance. -analyzer_some_units = analyzer.select_units(unit_ids=analyzer.unit_ids[:5], - format="binary_folder", folder="analyzer_some_units") +analyzer_some_units = analyzer.select_units( + unit_ids=analyzer.unit_ids[:5], format="binary_folder", folder="analyzer_some_units" +) print(analyzer_some_units) - plt.show() diff --git a/examples/modules_gallery/core/plot_5_append_concatenate_segments.py b/examples/modules_gallery/core/plot_5_append_concatenate_segments.py index db179859b0..b67a1ff0c2 100644 --- a/examples/modules_gallery/core/plot_5_append_concatenate_segments.py +++ b/examples/modules_gallery/core/plot_5_append_concatenate_segments.py @@ -32,16 +32,16 @@ ############################################################################## # First let's generate 2 recordings with 2 and 3 segments respectively: -sampling_frequency = 1000. +sampling_frequency = 1000.0 -trace0 = np.zeros((150, 5), dtype='float32') -trace1 = np.zeros((100, 5), dtype='float32') +trace0 = np.zeros((150, 5), dtype="float32") +trace1 = np.zeros((100, 5), dtype="float32") rec0 = NumpyRecording([trace0, trace1], sampling_frequency) print(rec0) -trace2 = np.zeros((50, 5), dtype='float32') -trace3 = np.zeros((200, 5), dtype='float32') -trace4 = np.zeros((120, 5), dtype='float32') +trace2 = np.zeros((50, 5), dtype="float32") +trace3 = np.zeros((200, 5), dtype="float32") +trace4 = np.zeros((120, 5), dtype="float32") rec1 = NumpyRecording([trace2, trace3, trace4], sampling_frequency) print(rec1) @@ -54,7 +54,7 @@ print(rec) for i in range(rec.get_num_segments()): s = rec.get_num_samples(segment_index=i) - print(f'segment {i} num_samples {s}') + print(f"segment {i} num_samples {s}") ############################################################################## # Let's use the :py:func:`~spikeinterface.core.concatenate_recordings()`: @@ -63,4 +63,4 @@ rec = concatenate_recordings(recording_list) print(rec) s = rec.get_num_samples(segment_index=0) -print(f'segment {0} num_samples {s}') +print(f"segment {0} num_samples {s}") diff --git a/examples/modules_gallery/core/plot_6_handle_times.py b/examples/modules_gallery/core/plot_6_handle_times.py index 4ca116e3c6..28abf68d84 100644 --- a/examples/modules_gallery/core/plot_6_handle_times.py +++ b/examples/modules_gallery/core/plot_6_handle_times.py @@ -7,6 +7,7 @@ This notebook shows how to handle time information in SpikeInterface recording and sorting objects. """ + from spikeinterface.extractors import toy_example ############################################################################## diff --git a/examples/modules_gallery/extractors/plot_1_read_various_formats.py b/examples/modules_gallery/extractors/plot_1_read_various_formats.py index df85946530..ef31b1dc76 100644 --- a/examples/modules_gallery/extractors/plot_1_read_various_formats.py +++ b/examples/modules_gallery/extractors/plot_1_read_various_formats.py @@ -1,4 +1,4 @@ -''' +""" Read various format into SpikeInterface ======================================= @@ -14,7 +14,7 @@ * file formats can be file-based (NWB, ...) or folder based (SpikeGLX, OpenEphys, ...) In this example we demonstrate how to read different file formats into SI -''' +""" import matplotlib.pyplot as plt @@ -29,10 +29,10 @@ # * Spike2: file from spike2 devices. It contains "recording" information only. -spike2_file_path = si.download_dataset(remote_path='spike2/130322-1LY.smr') +spike2_file_path = si.download_dataset(remote_path="spike2/130322-1LY.smr") print(spike2_file_path) -mearec_folder_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') +mearec_folder_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") print(mearec_folder_path) ############################################################################## @@ -45,13 +45,13 @@ # want to retrieve ('0' in our case). # the stream information can be retrieved by using the :py:func:`~spikeinterface.extractors.get_neo_streams` function. -stream_names, stream_ids = se.get_neo_streams('spike2', spike2_file_path) +stream_names, stream_ids = se.get_neo_streams("spike2", spike2_file_path) print(stream_names) print(stream_ids) stream_id = stream_ids[0] -print('stream_id', stream_id) +print("stream_id", stream_id) -recording = se.read_spike2(spike2_file_path, stream_id='0') +recording = se.read_spike2(spike2_file_path, stream_id="0") print(recording) print(type(recording)) print(isinstance(recording, si.BaseRecording)) @@ -61,7 +61,7 @@ # :py:class:`~spikeinterface.extractors.Spike2RecordingExtractor` object: # -recording = se.Spike2RecordingExtractor(spike2_file_path, stream_id='0') +recording = se.Spike2RecordingExtractor(spike2_file_path, stream_id="0") print(recording) ############################################################################## diff --git a/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py b/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py index a6a68a91f1..f2282297ea 100644 --- a/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py +++ b/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py @@ -1,4 +1,4 @@ -''' +""" Working with unscaled traces ============================ @@ -6,7 +6,7 @@ traces to uV. This example shows how to work with unscaled and scaled traces in the :py:mod:`spikeinterface.extractors` module. -''' +""" import numpy as np import matplotlib.pyplot as plt @@ -36,7 +36,7 @@ # (where 10 is the number of bits of our ADC) gain = 0.1 -offset = -2 ** (10 - 1) * gain +offset = -(2 ** (10 - 1)) * gain ############################################################################### # We are now ready to set gains and offsets for our extractor. We also have to set the :code:`has_unscaled` field to @@ -49,14 +49,14 @@ # Internally the gain and offset are handled with properties # So the gain could be "by channel". -print(recording.get_property('gain_to_uV')) -print(recording.get_property('offset_to_uV')) +print(recording.get_property("gain_to_uV")) +print(recording.get_property("offset_to_uV")) ############################################################################### # With gain and offset information, we can retrieve traces both in their unscaled (raw) type, and in their scaled # type: -traces_unscaled = recording.get_traces(return_scaled=False) # return_scaled is False by default +traces_unscaled = recording.get_traces(return_scaled=False) # return_scaled is False by default traces_scaled = recording.get_traces(return_scaled=True) print(f"Traces dtype after scaling: {traces_scaled.dtype}") diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 557fff229b..bfa6880cb0 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -10,14 +10,19 @@ import spikeinterface.core as si import spikeinterface.extractors as se from spikeinterface.postprocessing import compute_principal_components -from spikeinterface.qualitymetrics import (compute_snrs, compute_firing_rates, - compute_isi_violations, calculate_pc_metrics, compute_quality_metrics) +from spikeinterface.qualitymetrics import ( + compute_snrs, + compute_firing_rates, + compute_isi_violations, + calculate_pc_metrics, + compute_quality_metrics, +) ############################################################################## # First, let's download a simulated dataset # from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' -local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') +local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") recording, sorting = se.read_mearec(local_path) print(recording) print(sorting) @@ -32,11 +37,11 @@ print(analyzer) ############################################################################## -# Depending on which metrics we want to compute we will need first to compute +# Depending on which metrics we want to compute we will need first to compute # some necessary extensions. (if not computed an error message will be raised) analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205) -analyzer.compute("waveforms",ms_before=1.3, ms_after=2.6, n_jobs=2) +analyzer.compute("waveforms", ms_before=1.3, ms_after=2.6, n_jobs=2) analyzer.compute("templates", operators=["average", "median", "std"]) analyzer.compute("noise_levels") @@ -69,5 +74,11 @@ analyzer.compute("principal_components", n_components=3, mode="by_channel_global", whiten=True) -metrics = compute_quality_metrics(analyzer, metric_names=["isolation_distance", "d_prime",]) +metrics = compute_quality_metrics( + analyzer, + metric_names=[ + "isolation_distance", + "d_prime", + ], +) print(metrics) diff --git a/examples/modules_gallery/qualitymetrics/plot_4_curation.py b/examples/modules_gallery/qualitymetrics/plot_4_curation.py index da379f0789..f625914191 100644 --- a/examples/modules_gallery/qualitymetrics/plot_4_curation.py +++ b/examples/modules_gallery/qualitymetrics/plot_4_curation.py @@ -6,6 +6,7 @@ quality metrics that you have calculated. """ + ############################################################################# # Import the modules and/or functions necessary from spikeinterface @@ -22,7 +23,7 @@ # # Let's imagine that the ground-truth sorting is in fact the output of a sorter. -local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') +local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") recording, sorting = se.read_mearec(file_path=local_path) print(recording) print(sorting) @@ -38,14 +39,14 @@ analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") analyzer.compute(["random_spikes", "waveforms", "templates", "noise_levels"]) -analyzer.compute("principal_components", n_components=3, mode='by_channel_local') +analyzer.compute("principal_components", n_components=3, mode="by_channel_local") print(analyzer) ############################################################################## # Then we compute some quality metrics: -metrics = compute_quality_metrics(analyzer, metric_names=['snr', 'isi_violation', 'nearest_neighbor']) +metrics = compute_quality_metrics(analyzer, metric_names=["snr", "isi_violation", "nearest_neighbor"]) print(metrics) ############################################################################## @@ -55,7 +56,7 @@ # # Then create a list of unit ids that we want to keep -keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_ratio'] < 0.2) & (metrics['nn_hit_rate'] > 0.90) +keep_mask = (metrics["snr"] > 7.5) & (metrics["isi_violations_ratio"] < 0.2) & (metrics["nn_hit_rate"] > 0.90) print(keep_mask) keep_unit_ids = keep_mask[keep_mask].index.values @@ -69,7 +70,7 @@ print(curated_sorting) -curated_sorting.save(folder='curated_sorting') +curated_sorting.save(folder="curated_sorting") ############################################################################## # We can also save the analyzer with only theses units diff --git a/examples/modules_gallery/widgets/plot_1_rec_gallery.py b/examples/modules_gallery/widgets/plot_1_rec_gallery.py index 1544bbfc54..bb121e26a2 100644 --- a/examples/modules_gallery/widgets/plot_1_rec_gallery.py +++ b/examples/modules_gallery/widgets/plot_1_rec_gallery.py @@ -1,9 +1,10 @@ -''' +""" RecordingExtractor Widgets Gallery =================================== Here is a gallery of all the available widgets using RecordingExtractor objects. -''' +""" + import matplotlib.pyplot as plt import spikeinterface.extractors as se @@ -39,10 +40,9 @@ w_ts.ax.set_ylabel("Channel_ids") ############################################################################## -# We can also use the 'map' mode useful for high channel count +# We can also use the 'map' mode useful for high channel count -w_ts = sw.plot_traces(recording, mode='map', time_range=(5, 8), - show_channel_ids=True, order_channel_by_depth=True) +w_ts = sw.plot_traces(recording, mode="map", time_range=(5, 8), show_channel_ids=True, order_channel_by_depth=True) ############################################################################## # plot_electrode_geometry() diff --git a/examples/modules_gallery/widgets/plot_2_sort_gallery.py b/examples/modules_gallery/widgets/plot_2_sort_gallery.py index bea6f34e4d..da5c611ce4 100644 --- a/examples/modules_gallery/widgets/plot_2_sort_gallery.py +++ b/examples/modules_gallery/widgets/plot_2_sort_gallery.py @@ -1,9 +1,10 @@ -''' +""" SortingExtractor Widgets Gallery =================================== Here is a gallery of all the available widgets using SortingExtractor objects. -''' +""" + import matplotlib.pyplot as plt import spikeinterface.extractors as se @@ -24,7 +25,7 @@ # plot_isi_distribution() # ~~~~~~~~~~~~~~~~~~~~~~~ -w_isi = sw.plot_isi_distribution(sorting, window_ms=150.0, bin_ms=5.0, figsize=(20,8)) +w_isi = sw.plot_isi_distribution(sorting, window_ms=150.0, bin_ms=5.0, figsize=(20, 8)) ############################################################################## # plot_autocorrelograms() diff --git a/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py b/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py index a02b62bbd8..fc4a7775d2 100644 --- a/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py +++ b/examples/modules_gallery/widgets/plot_3_waveforms_gallery.py @@ -1,9 +1,10 @@ -''' +""" Waveforms Widgets Gallery ========================= Here is a gallery of all the available widgets using a pair of RecordingExtractor-SortingExtractor objects. -''' +""" + import matplotlib.pyplot as plt import spikeinterface as si @@ -15,7 +16,7 @@ # First, let's download a simulated dataset # from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' -local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') +local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") recording, sorting = si.read_mearec(local_path) print(recording) print(sorting) @@ -42,7 +43,7 @@ unit_ids = sorting.unit_ids[:4] -sw.plot_unit_waveforms(analyzer, unit_ids=unit_ids, figsize=(16,4)) +sw.plot_unit_waveforms(analyzer, unit_ids=unit_ids, figsize=(16, 4)) ############################################################################## # plot_unit_templates() @@ -50,21 +51,21 @@ unit_ids = sorting.unit_ids -sw.plot_unit_templates(analyzer, unit_ids=unit_ids, ncols=5, figsize=(16,8)) +sw.plot_unit_templates(analyzer, unit_ids=unit_ids, ncols=5, figsize=(16, 8)) ############################################################################## # plot_amplitudes() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_amplitudes(analyzer, plot_histograms=True, figsize=(12,8)) +sw.plot_amplitudes(analyzer, plot_histograms=True, figsize=(12, 8)) ############################################################################## # plot_unit_locations() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_unit_locations(analyzer, figsize=(4,8)) +sw.plot_unit_locations(analyzer, figsize=(4, 8)) ############################################################################## @@ -74,21 +75,20 @@ # This is your best friend to check over merge unit_ids = sorting.unit_ids[:4] -sw.plot_unit_waveforms_density_map(analyzer, unit_ids=unit_ids, figsize=(14,8)) - +sw.plot_unit_waveforms_density_map(analyzer, unit_ids=unit_ids, figsize=(14, 8)) ############################################################################## # plot_amplitudes_distribution() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_all_amplitudes_distributions(analyzer, figsize=(10,10)) +sw.plot_all_amplitudes_distributions(analyzer, figsize=(10, 10)) ############################################################################## # plot_units_depths() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -sw.plot_unit_depths(analyzer, figsize=(10,10)) +sw.plot_unit_depths(analyzer, figsize=(10, 10)) ############################################################################## @@ -96,8 +96,7 @@ # ~~~~~~~~~~~~~~~~~~~~~ unit_ids = sorting.unit_ids[:4] -sw.plot_unit_probe_map(analyzer, unit_ids=unit_ids, figsize=(20,8)) - +sw.plot_unit_probe_map(analyzer, unit_ids=unit_ids, figsize=(20, 8)) plt.show() diff --git a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py index e3464dd1e8..cce04ae5a0 100644 --- a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py +++ b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py @@ -1,4 +1,4 @@ -''' +""" Peaks Widgets Gallery ===================== @@ -7,7 +7,8 @@ They are useful to check drift before running sorters. -''' +""" + import matplotlib.pyplot as plt import spikeinterface.full as si @@ -16,34 +17,40 @@ # First, let's download a simulated dataset # from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' -local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') +local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") rec, sorting = si.read_mearec(local_path) ############################################################################## -# Let's filter and detect peaks on it +# Let's filter and detect peaks on it from spikeinterface.sortingcomponents.peak_detection import detect_peaks -rec_filtred = si.bandpass_filter(recording=rec, freq_min=300., freq_max=6000., margin_ms=5.0) +rec_filtred = si.bandpass_filter(recording=rec, freq_min=300.0, freq_max=6000.0, margin_ms=5.0) print(rec_filtred) peaks = detect_peaks( - recording=rec_filtred, method='locally_exclusive', - peak_sign='neg', detect_threshold=6, exclude_sweep_ms=0.3, - radius_um=100, - noise_levels=None, - random_chunk_kwargs={}, - chunk_memory='10M', n_jobs=1, progress_bar=True) + recording=rec_filtred, + method="locally_exclusive", + peak_sign="neg", + detect_threshold=6, + exclude_sweep_ms=0.3, + radius_um=100, + noise_levels=None, + random_chunk_kwargs={}, + chunk_memory="10M", + n_jobs=1, + progress_bar=True, +) ############################################################################## -# peaks is a numpy 1D array with structured dtype that contains several fields: +# peaks is a numpy 1D array with structured dtype that contains several fields: print(peaks.dtype) print(peaks.shape) print(peaks.dtype.fields.keys()) ############################################################################## -# This "peaks" vector can be used in several widgets, for instance +# This "peaks" vector can be used in several widgets, for instance # plot_peak_activity() si.plot_peak_activity(recording=rec_filtred, peaks=peaks) @@ -51,9 +58,9 @@ plt.show() ############################################################################## -# can be also animated with bin_duration_s=1. +# can be also animated with bin_duration_s=1. -si.plot_peak_activity(recording=rec_filtred, peaks=peaks, bin_duration_s=1.) +si.plot_peak_activity(recording=rec_filtred, peaks=peaks, bin_duration_s=1.0) plt.show() diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 79fe6ab600..268513dac8 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -41,24 +41,24 @@ class SelectRandomSpikes(AnalyzerExtension): ------- """ + extension_name = "random_spikes" depend_on = [] need_recording = False use_nodepipeline = False need_job_kwargs = False - def _run(self, + def _run( + self, ): - self.data["random_spikes_indices"] = random_spikes_selection( - self.sorting_analyzer.sorting, num_samples=self.sorting_analyzer.rec_attributes["num_samples"], - **self.params) + self.data["random_spikes_indices"] = random_spikes_selection( + self.sorting_analyzer.sorting, + num_samples=self.sorting_analyzer.rec_attributes["num_samples"], + **self.params, + ) def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None): - params = dict( - method=method, - max_spikes_per_unit=max_spikes_per_unit, - margin_size=margin_size, - seed=seed) + params = dict(method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed) return params def _select_extension_data(self, unit_ids): @@ -76,7 +76,6 @@ def _select_extension_data(self, unit_ids): new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_spike_mask]) return new_data - def _get_data(self): return self.data["random_spikes_indices"] @@ -88,7 +87,6 @@ def some_spikes(self): self._some_spikes = spikes[self.data["random_spikes_indices"]] return self._some_spikes - def get_selected_indices_in_spike_train(self, unit_id, segment_index): # usefull for Waveforms extractor backwars compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain @@ -107,12 +105,9 @@ def get_selected_indices_in_spike_train(self, unit_id, segment_index): return selected_spikes_in_spike_train - register_result_extension(SelectRandomSpikes) - - class ComputeWaveforms(AnalyzerExtension): """ AnalyzerExtension that extract some waveforms of each units. @@ -197,7 +192,7 @@ def _set_params( if return_scaled: # check if has scaled values: - if not recording.has_scaled() and recording.get_dtype().kind == 'i': + if not recording.has_scaled() and recording.get_dtype().kind == "i": print("Setting 'return_scaled' to False") return_scaled = False diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b834cbac96..74937d0861 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -310,7 +310,7 @@ def get_traces( warnings.warn(message) if not self.has_scaled(): - if self._dtype.kind == 'f': + if self._dtype.kind == "f": # here we do not truely have scale but we assume this is scaled # this helps a lot for simulated data pass diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 2ba9737ee5..f1858810fb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -161,7 +161,12 @@ class SortingAnalyzer: """ def __init__( - self, sorting=None, recording=None, rec_attributes=None, format=None, sparsity=None, + self, + sorting=None, + recording=None, + rec_attributes=None, + format=None, + sparsity=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -409,7 +414,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, rec_attributes): # the recording rec_dict = recording.to_dict(relative_to=folder, recursive=True) - + if recording.check_serializability("json"): # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) zarr_rec = np.array([check_json(rec_dict)], dtype=object) @@ -760,12 +765,11 @@ def compute(self, input, save=True, **kwargs): elif isinstance(input, list): params_, job_kwargs = split_job_kwargs(kwargs) assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()" - extensions = {k : {} for k in input} - self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs) + extensions = {k: {} for k in input} + self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs) else: raise ValueError("SortingAnalyzer.compute() need str, dict or list") - def compute_one_extension(self, extension_name, save=True, **kwargs): """ Compute one extension diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index f94226110f..cb70b21d69 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -74,7 +74,12 @@ def _check_result_extension(sorting_analyzer, extension_name): @pytest.mark.parametrize("format", ["memory", "binary_folder", "zarr"]) -@pytest.mark.parametrize("sparse", [False, ]) +@pytest.mark.parametrize( + "sparse", + [ + False, + ], +) def test_SelectRandomSpikes(format, sparse): sorting_analyzer = get_sorting_analyzer(format=format, sparse=sparse) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 56dc17817b..f10454e085 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -231,7 +231,9 @@ def get_sampled_indices(self, unit_id): selected_spikes = [] for segment_index in range(self.get_num_segments()): # inds = self.sorting_analyzer.get_selected_indices_in_spike_train(unit_id, segment_index) - inds = self.sorting_analyzer.get_extension("random_spikes").get_selected_indices_in_spike_train(unit_id, segment_index) + inds = self.sorting_analyzer.get_extension("random_spikes").get_selected_indices_in_spike_train( + unit_id, segment_index + ) sampled_index = np.zeros(inds.size, dtype=[("spike_index", "int64"), ("segment_index", "int64")]) sampled_index["spike_index"] = inds sampled_index["segment_index"][:] = segment_index diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 30aeca4b4b..7362dfc4dd 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -98,7 +98,7 @@ def _get_pipeline_nodes(self): if return_scaled: # check if has scaled values: - if not recording.has_scaled_traces() and recording.get_dtype().kind == 'i': + if not recording.has_scaled_traces() and recording.get_dtype().kind == "i": warnings.warn("Recording doesn't have scaled traces! Setting 'return_scaled' to False") return_scaled = False diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 96b83c1fe4..e9470a18ed 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -11,7 +11,8 @@ plot_unit_waveforms, ) from spikeinterface.comparison.comparisontools import make_matching_events -#from spikeinterface.postprocessing import get_template_extremum_channel + +# from spikeinterface.postprocessing import get_template_extremum_channel from spikeinterface.core import get_noise_levels import time @@ -35,9 +36,9 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.gt_sorting = gt_sorting self.indices = indices - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') + ext = sorting_analyzer.compute("fast_templates") extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) @@ -46,55 +47,57 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.peaks = peaks[self.indices] self.params = params self.exhaustive_gt = exhaustive_gt - self.method = params['method'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.method_kwargs = params["method_kwargs"] self.result = {} - - def run(self, **job_kwargs): + + def run(self, **job_kwargs): labels, peak_labels = find_cluster_from_peaks( - self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs + self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs ) - self.result['peak_labels'] = peak_labels + self.result["peak_labels"] = peak_labels def compute_result(self, **result_params): - self.noise = self.result['peak_labels'] < 0 + self.noise = self.result["peak_labels"] < 0 spikes = self.gt_sorting.to_spike_vector() - self.result['sliced_gt_sorting'] = NumpySorting(spikes[self.indices], - self.recording.sampling_frequency, - self.gt_sorting.unit_ids) + self.result["sliced_gt_sorting"] = NumpySorting( + spikes[self.indices], self.recording.sampling_frequency, self.gt_sorting.unit_ids + ) data = spikes[self.indices][~self.noise] - data["unit_index"] = self.result['peak_labels'][~self.noise] + data["unit_index"] = self.result["peak_labels"][~self.noise] + + self.result["clustering"] = NumpySorting.from_times_labels( + data["sample_index"], self.result["peak_labels"][~self.noise], self.recording.sampling_frequency + ) - self.result['clustering'] = NumpySorting.from_times_labels(data["sample_index"], - self.result['peak_labels'][~self.noise], - self.recording.sampling_frequency) - - self.result['gt_comparison'] = GroundTruthComparison(self.result['sliced_gt_sorting'], - self.result['clustering'], - exhaustive_gt=self.exhaustive_gt) + self.result["gt_comparison"] = GroundTruthComparison( + self.result["sliced_gt_sorting"], self.result["clustering"], exhaustive_gt=self.exhaustive_gt + ) - sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates") - sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["clustering"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['clustering_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["clustering_templates"] = ext.get_data(outputs="Templates") - _run_key_saved = [ - ("peak_labels", "npy") - ] + _run_key_saved = [("peak_labels", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), ("sliced_gt_sorting", "sorting"), ("clustering", "sorting"), ("sliced_gt_templates", "zarr_templates"), - ("clustering_templates", "zarr_templates") + ("clustering_templates", "zarr_templates"), ] @@ -111,23 +114,24 @@ def create_benchmark(self, key): return benchmark def homogeneity_score(self, ignore_noise=True, case_keys=None): - + if case_keys is None: case_keys = list(self.cases.keys()) - + for count, key in enumerate(case_keys): result = self.get_result(key) noise = result["peak_labels"] < 0 from sklearn.metrics import homogeneity_score + gt_labels = self.benchmarks[key].gt_sorting.to_spike_vector()["unit_index"] gt_labels = gt_labels[self.benchmarks[key].indices] - found_labels = result['peak_labels'] + found_labels = result["peak_labels"] if ignore_noise: gt_labels = gt_labels[~noise] found_labels = found_labels[~noise] - print(self.cases[key]['label'], homogeneity_score(gt_labels, found_labels), np.mean(noise)) + print(self.cases[key]["label"], homogeneity_score(gt_labels, found_labels), np.mean(noise)) - def plot_agreements(self, case_keys=None, figsize=(15,15)): + def plot_agreements(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -135,32 +139,32 @@ def plot_agreements(self, case_keys=None, figsize=(15,15)): for count, key in enumerate(case_keys): ax = axs[count] - ax.set_title(self.cases[key]['label']) - plot_agreement_matrix(self.get_result(key)['gt_comparison'], ax=ax) + ax.set_title(self.cases[key]["label"]) + plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - def plot_performances_vs_snr(self, case_keys=None, figsize=(15,15)): + def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) for count, k in enumerate(("accuracy", "recall", "precision")): - + ax = axs[count] for key in case_keys: label = self.cases[key]["label"] - + analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() x = metrics["snr"].values - y = self.get_result(key)['gt_comparison'].get_performance()[k].values + y = self.get_result(key)["gt_comparison"].get_performance()[k].values ax.scatter(x, y, marker=".", label=label) ax.set_title(k) if count == 2: ax.legend() - - def plot_error_metrics(self, metric='cosine', case_keys=None, figsize=(15,5)): + + def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -170,14 +174,14 @@ def plot_error_metrics(self, metric='cosine', case_keys=None, figsize=(15,5)): for count, key in enumerate(case_keys): result = self.get_result(key) - scores = result['gt_comparison'].get_ordered_agreement_scores() + scores = result["gt_comparison"].get_ordered_agreement_scores() unit_ids1 = scores.index.values unit_ids2 = scores.columns.values - inds_1 = result['gt_comparison'].sorting1.ids_to_indices(unit_ids1) - inds_2 = result['gt_comparison'].sorting2.ids_to_indices(unit_ids2) + inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1) + inds_2 = result["gt_comparison"].sorting2.ids_to_indices(unit_ids2) t1 = result["sliced_gt_templates"].templates_array - t2 = result['clustering_templates'].templates_array + t2 = result["clustering_templates"].templates_array a = t1.reshape(len(t1), -1)[inds_1] b = t2.reshape(len(t2), -1)[inds_2] @@ -194,8 +198,7 @@ def plot_error_metrics(self, metric='cosine', case_keys=None, figsize=(15,5)): label = self.cases[key]["label"] axs[count].set_title(label) - - def plot_metrics_vs_snr(self, metric='cosine', case_keys=None, figsize=(15,5)): + def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -205,17 +208,17 @@ def plot_metrics_vs_snr(self, metric='cosine', case_keys=None, figsize=(15,5)): for count, key in enumerate(case_keys): result = self.get_result(key) - scores = result['gt_comparison'].get_ordered_agreement_scores() + scores = result["gt_comparison"].get_ordered_agreement_scores() analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() - + metrics = analyzer.get_extension("quality_metrics").get_data() + unit_ids1 = scores.index.values unit_ids2 = scores.columns.values - inds_1 = result['gt_comparison'].sorting1.ids_to_indices(unit_ids1) - inds_2 = result['gt_comparison'].sorting2.ids_to_indices(unit_ids2) + inds_1 = result["gt_comparison"].sorting1.ids_to_indices(unit_ids1) + inds_2 = result["gt_comparison"].sorting2.ids_to_indices(unit_ids2) t1 = result["sliced_gt_templates"].templates_array - t2 = result['clustering_templates'].templates_array + t2 = result["clustering_templates"].templates_array a = t1.reshape(len(t1), -1) b = t2.reshape(len(t2), -1) @@ -225,19 +228,18 @@ def plot_metrics_vs_snr(self, metric='cosine', case_keys=None, figsize=(15,5)): distances = sklearn.metrics.pairwise.cosine_similarity(a, b) else: distances = sklearn.metrics.pairwise_distances(a, b, metric) - + snr = metrics["snr"][unit_ids1][inds_1[: len(inds_2)]] to_plot = [] for found, real in zip(inds_2, inds_1): to_plot += [distances[real, found]] - axs[count].plot(snr, to_plot, '.') - axs[count].set_xlabel('snr') + axs[count].plot(snr, to_plot, ".") + axs[count].set_xlabel("snr") axs[count].set_ylabel(metric) label = self.cases[key]["label"] axs[count].set_title(label) - # def _scatter_clusters( # self, # xs, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 4a5221e16d..ffecbe028f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -24,17 +24,14 @@ class MatchingBenchmark(Benchmark): def __init__(self, recording, gt_sorting, params): self.recording = recording self.gt_sorting = gt_sorting - self.method = params['method'] - self.templates = params["method_kwargs"]['templates'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.templates = params["method_kwargs"]["templates"] + self.method_kwargs = params["method_kwargs"] self.result = {} def run(self, **job_kwargs): spikes = find_spikes_from_templates( - self.recording, - method=self.method, - method_kwargs=self.method_kwargs, - **job_kwargs + self.recording, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs ) unit_ids = self.templates.unit_ids sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) @@ -42,36 +39,33 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - self.result = {'sorting' : sorting} - self.result['templates'] = self.templates + self.result = {"sorting": sorting} + self.result["templates"] = self.templates def compute_result(self, **result_params): - sorting = self.result['sorting'] + sorting = self.result["sorting"] comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) - self.result['gt_comparison'] = comp - self.result['gt_collision'] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) - + self.result["gt_comparison"] = comp + self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True) + _run_key_saved = [ ("sorting", "sorting"), ("templates", "zarr_templates"), ] - _result_key_saved = [ - ("gt_collision", "pickle"), - ("gt_comparison", "pickle") - ] + _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] class MatchingStudy(BenchmarkStudy): benchmark_class = MatchingBenchmark - def create_benchmark(self,key): + def create_benchmark(self, key): dataset_key = self.cases[key]["dataset"] recording, gt_sorting = self.datasets[dataset_key] params = self.cases[key]["params"] benchmark = MatchingBenchmark(recording, gt_sorting, params) return benchmark - + def plot_agreements(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -80,9 +74,9 @@ def plot_agreements(self, case_keys=None, figsize=None): for count, key in enumerate(case_keys): ax = axs[count] - ax.set_title(self.cases[key]['label']) - plot_agreement_matrix(self.get_result(key)['gt_comparison'], ax=ax) - + ax.set_title(self.cases[key]["label"]) + plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + def plot_performances_vs_snr(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -90,15 +84,15 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=None): fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) for count, k in enumerate(("accuracy", "recall", "precision")): - + ax = axs[count] for key in case_keys: label = self.cases[key]["label"] - + analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() x = metrics["snr"].values - y = self.get_result(key)['gt_comparison'].get_performance()[k].values + y = self.get_result(key)["gt_comparison"].get_performance()[k].values ax.scatter(x, y, marker=".", label=label) ax.set_title(k) @@ -108,23 +102,29 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=None): def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) - + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) for count, key in enumerate(case_keys): - templates_array = self.get_result(key)['templates'].templates_array + templates_array = self.get_result(key)["templates"].templates_array plot_comparison_collision_by_similarity( - self.get_result(key)['gt_collision'], templates_array, ax=axs[count], - show_legend=True, mode="lines", good_only=False + self.get_result(key)["gt_collision"], + templates_array, + ax=axs[count], + show_legend=True, + mode="lines", + good_only=False, ) - def plot_comparison_matching(self, case_keys=None, + def plot_comparison_matching( + self, + case_keys=None, performance_names=["accuracy", "recall", "precision"], colors=["g", "b", "r"], ylim=(-0.1, 1.1), - figsize=None + figsize=None, ): - + if case_keys is None: case_keys = list(self.cases.keys()) @@ -136,8 +136,8 @@ def plot_comparison_matching(self, case_keys=None, ax = axs[i, j] else: ax = axs[j] - comp1 = self.get_result(key1)['gt_comparison'] - comp2 = self.get_result(key2)['gt_comparison'] + comp1 = self.get_result(key1)["gt_comparison"] + comp2 = self.get_result(key2)["gt_comparison"] if i <= j: for performance, color in zip(performance_names, colors): perf1 = comp1.get_performance()[performance] @@ -150,8 +150,8 @@ def plot_comparison_matching(self, case_keys=None, ax.spines[["right", "top"]].set_visible(False) ax.set_aspect("equal") - label1 = self.cases[key1]['label'] - label2 = self.cases[key2]['label'] + label1 = self.cases[key1]["label"] + label2 = self.cases[key2]["label"] if j == i: ax.set_ylabel(f"{label1}") else: @@ -172,4 +172,4 @@ def plot_comparison_matching(self, case_keys=None, ax.spines["right"].set_visible(False) ax.set_xticks([]) ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) \ No newline at end of file + plt.tight_layout(h_pad=0, w_pad=0) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index be048ede50..12f0ff7a4a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -28,9 +28,7 @@ # TODO : read from mearec - - -def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direction_dim = 1): +def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direction_dim=1): """ Get final displacement vector unit per units. @@ -57,7 +55,7 @@ def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direct """ - num_units = displacement_unit_factor.shape[0] + num_units = displacement_unit_factor.shape[0] unit_displacements = np.zeros((displacement_vectors.shape[0], num_units)) for i in range(displacement_vectors.shape[2]): m = displacement_vectors[:, direction_dim, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :] @@ -66,10 +64,14 @@ def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direct return unit_displacements -def get_gt_motion_from_unit_discplacement(unit_displacements, displacement_sampling_frequency, - unit_locations, - temporal_bins, spatial_bins, - direction_dim=1,): +def get_gt_motion_from_unit_discplacement( + unit_displacements, + displacement_sampling_frequency, + unit_locations, + temporal_bins, + spatial_bins, + direction_dim=1, +): times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) @@ -83,17 +85,25 @@ def get_gt_motion_from_unit_discplacement(unit_displacements, displacement_sampl # non rigid gt_motion = np.zeros((temporal_bins.size, spatial_bins.size)) for t in range(temporal_bins.shape[0]): - f = scipy.interpolate.interp1d(unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate") + f = scipy.interpolate.interp1d( + unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate" + ) gt_motion[t, :] = f(spatial_bins) return gt_motion - class MotionEstimationBenchmark(Benchmark): - def __init__(self, recording, gt_sorting, params, - unit_locations, unit_displacements, displacement_sampling_frequency, - direction="y"): + def __init__( + self, + recording, + gt_sorting, + params, + unit_locations, + unit_displacements, + displacement_sampling_frequency, + direction="y", + ): Benchmark.__init__(self) self.recording = recording self.gt_sorting = gt_sorting @@ -110,9 +120,7 @@ def run(self, **job_kwargs): noise_levels = get_noise_levels(self.recording, return_scaled=False) t0 = time.perf_counter() - peaks = detect_peaks( - self.recording, noise_levels=noise_levels, **p["detect_kwargs"], **job_kwargs - ) + peaks = detect_peaks(self.recording, noise_levels=noise_levels, **p["detect_kwargs"], **job_kwargs) t1 = time.perf_counter() if p["select_kwargs"] is not None: selected_peaks = select_peaks(self.peaks, **p["select_kwargs"], **job_kwargs) @@ -120,9 +128,7 @@ def run(self, **job_kwargs): selected_peaks = peaks t2 = time.perf_counter() - peak_locations = localize_peaks( - self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs - ) + peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs) t3 = time.perf_counter() motion, temporal_bins, spatial_bins = estimate_motion( self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] @@ -159,7 +165,9 @@ def compute_result(self, **result_params): # non rigid gt_motion = np.zeros_like(raw_motion) for t in range(temporal_bins.shape[0]): - f = scipy.interpolate.interp1d(self.unit_locations[:, self.direction_dim], unit_displacements[t, :], fill_value="extrapolate") + f = scipy.interpolate.interp1d( + self.unit_locations[:, self.direction_dim], unit_displacements[t, :], fill_value="extrapolate" + ) gt_motion[t, :] = f(spatial_bins) # align globally gt_motion and motion to avoid offsets @@ -168,7 +176,6 @@ def compute_result(self, **result_params): self.result["gt_motion"] = gt_motion self.result["motion"] = motion - _run_key_saved = [ ("raw_motion", "npy"), ("temporal_bins", "npy"), @@ -176,14 +183,17 @@ def compute_result(self, **result_params): ("step_run_times", "pickle"), ] _result_key_saved = [ - ("gt_motion", "npy",), - ("motion", "npy",) + ( + "gt_motion", + "npy", + ), + ( + "motion", + "npy", + ), ] - - - class MotionEstimationStudy(BenchmarkStudy): benchmark_class = MotionEstimationBenchmark @@ -197,7 +207,6 @@ def create_benchmark(self, key): return benchmark def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): - if case_keys is None: case_keys = list(self.cases.keys()) @@ -218,7 +227,7 @@ def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): ax.set_ylabel("depth (um)") ax.set_xlabel(None) - ax.set_aspect('auto') + ax.set_aspect("auto") # dirft ax = ax1 = fig.add_subplot(gs[2:7]) @@ -227,7 +236,6 @@ def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): spatial_bins = bench.result["spatial_bins"] gt_motion = bench.result["gt_motion"] - # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") @@ -246,8 +254,7 @@ def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): ax.axhline(probe_y_min, color="k", ls="--", alpha=0.5) ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) - - ax = ax2= fig.add_subplot(gs[7]) + ax = ax2 = fig.add_subplot(gs[7]) ax2.sharey(ax0) _simpleaxis(ax) ax.hist(unit_locations[:, bench.direction_dim], bins=50, orientation="horizontal", color="0.5") @@ -274,7 +281,6 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): temporal_bins = bench.result["temporal_bins"] spatial_bins = bench.result["spatial_bins"] - fig = plt.figure(figsize=figsize) gs = fig.add_gridspec(2, 2) @@ -319,13 +325,13 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): if lim is not None: ax.set_ylim(0, lim) - def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, figsize=(15, 5)): + def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) fig, axes = plt.subplots(1, 3, figsize=figsize) - + for count, key in enumerate(case_keys): bench = self.benchmarks[key] @@ -336,9 +342,6 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fi temporal_bins = bench.result["temporal_bins"] spatial_bins = bench.result["spatial_bins"] - - - c = colors[count] if colors is not None else None errors = gt_motion - motion mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) @@ -384,9 +387,6 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fi # ax2.sharey(ax0) - - - # class BenchmarkMotionEstimationMearec(BenchmarkBase): # _array_names = ( # "noise_levels", diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index d2b83f181a..af45f7421f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -20,12 +20,18 @@ import matplotlib.pyplot as plt - class MotionInterpolationBenchmark(Benchmark): - def __init__(self, static_recording, gt_sorting, params, - sorter_folder, drifting_recording, - motion, temporal_bins, spatial_bins, - ): + def __init__( + self, + static_recording, + gt_sorting, + params, + sorter_folder, + drifting_recording, + motion, + temporal_bins, + spatial_bins, + ): Benchmark.__init__(self) self.static_recording = static_recording self.gt_sorting = gt_sorting @@ -37,14 +43,13 @@ def __init__(self, static_recording, gt_sorting, params, self.temporal_bins = temporal_bins self.spatial_bins = spatial_bins - def run(self, **job_kwargs): - if self.params["recording_source"] == 'static': + if self.params["recording_source"] == "static": recording = self.static_recording - elif self.params["recording_source"] == 'drifting': + elif self.params["recording_source"] == "drifting": recording = self.drifting_recording - elif self.params["recording_source"] == 'corrected': + elif self.params["recording_source"] == "corrected": correct_motion_kwargs = self.params["correct_motion_kwargs"] recording = InterpolateMotionRecording( self.drifting_recording, self.motion, self.temporal_bins, self.spatial_bins, **correct_motion_kwargs @@ -66,12 +71,11 @@ def run(self, **job_kwargs): def compute_result(self, exhaustive_gt=True, merging_score=0.2): sorting = self.result["sorting"] - # self.result[""] = + # self.result[""] = comparison = GroundTruthComparison(self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt) self.result["comparison"] = comparison self.result["accuracy"] = comparison.get_performance()["accuracy"].values.astype("float32") - gt_unit_ids = self.gt_sorting.unit_ids unit_ids = sorting.unit_ids @@ -90,7 +94,6 @@ def compute_result(self, exhaustive_gt=True, merging_score=0.2): self.result["comparison_merged"] = comparison_merged self.result["accuracy_merged"] = comparison_merged.get_performance()["accuracy"].values.astype("float32") - _run_key_saved = [ ("sorting", "sorting"), ] @@ -111,24 +114,32 @@ def create_benchmark(self, key): recording, gt_sorting = self.datasets[dataset_key] params = self.cases[key]["params"] init_kwargs = self.cases[key]["init_kwargs"] - sorter_folder = self.folder / "sorters" /self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) sorter_folder.parent.mkdir(exist_ok=True) - benchmark = MotionInterpolationBenchmark(recording, gt_sorting, params, - sorter_folder=sorter_folder, **init_kwargs) + benchmark = MotionInterpolationBenchmark( + recording, gt_sorting, params, sorter_folder=sorter_folder, **init_kwargs + ) return benchmark - - def plot_sorting_accuracy(self, case_keys=None, mode="ordered_accuracy", legend=True, colors=None, - mode_best_merge=False, figsize=(10, 5), ax=None, axes=None): - + def plot_sorting_accuracy( + self, + case_keys=None, + mode="ordered_accuracy", + legend=True, + colors=None, + mode_best_merge=False, + figsize=(10, 5), + ax=None, + axes=None, + ): if case_keys is None: case_keys = list(self.cases.keys()) if not mode_best_merge: - ls = '-' + ls = "-" else: - ls = '--' + ls = "--" if mode == "ordered_accuracy": if ax is None: @@ -176,7 +187,7 @@ def plot_sorting_accuracy(self, case_keys=None, mode="ordered_accuracy", legend= unit_locations = ext.get_data() unit_depth = unit_locations[:, 1] - snr= analyzer.get_extension("quality_metrics").get_data()["snr"].values + snr = analyzer.get_extension("quality_metrics").get_data()["snr"].values points = ax.scatter(unit_depth, snr, c=accuracy) points.set_clim(0.0, 1.0) @@ -203,9 +214,9 @@ def plot_sorting_accuracy(self, case_keys=None, mode="ordered_accuracy", legend= else: accuracy = result["accuracy_merged"] - analyzer = self.get_sorting_analyzer(key) - snr= analyzer.get_extension("quality_metrics").get_data()["snr"].values - + analyzer = self.get_sorting_analyzer(key) + snr = analyzer.get_extension("quality_metrics").get_data()["snr"].values + ax.scatter(snr, accuracy, label=label) ax.set_xlabel("snr") ax.set_ylabel("accuracy") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 61e1ce2098..6a548d579a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -10,7 +10,11 @@ plot_unit_templates, plot_unit_waveforms, ) -from spikeinterface.postprocessing.unit_localization import compute_center_of_mass, compute_monopolar_triangulation, compute_grid_convolution +from spikeinterface.postprocessing.unit_localization import ( + compute_center_of_mass, + compute_monopolar_triangulation, + compute_grid_convolution, +) from spikeinterface.core import get_noise_levels import pylab as plt @@ -36,34 +40,31 @@ def __init__(self, recording, gt_sorting, params, gt_positions): self.params[key] = 2 def run(self, **job_kwargs): - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates', **self.templates_params) - templates = ext.get_data(outputs='Templates') + ext = sorting_analyzer.compute("fast_templates", **self.templates_params) + templates = ext.get_data(outputs="Templates") ext = sorting_analyzer.compute("spike_locations", **self.params) spikes_locations = ext.get_data(outputs="by_unit") - self.result = {'spikes_locations' : spikes_locations} - self.result['templates'] = templates + self.result = {"spikes_locations": spikes_locations} + self.result["templates"] = templates def compute_result(self, **result_params): errors = {} for unit_ind, unit_id in enumerate(self.gt_sorting.unit_ids): - data = self.result['spikes_locations'][0][unit_id] + data = self.result["spikes_locations"][0][unit_id] errors[unit_id] = np.sqrt( (data["x"] - self.gt_positions[unit_ind, 0]) ** 2 + (data["y"] - self.gt_positions[unit_ind, 1]) ** 2 ) - self.result['medians_over_templates'] = np.array( + self.result["medians_over_templates"] = np.array( [np.median(errors[unit_id]) for unit_id in self.gt_sorting.unit_ids] ) - self.result['mads_over_templates'] = np.array( - [ - np.median(np.abs(errors[unit_id] - np.median(errors[unit_id]))) - for unit_id in self.gt_sorting.unit_ids - ] + self.result["mads_over_templates"] = np.array( + [np.median(np.abs(errors[unit_id] - np.median(errors[unit_id]))) for unit_id in self.gt_sorting.unit_ids] ) - self.result['errors'] = errors + self.result["errors"] = errors _run_key_saved = [ ("spikes_locations", "pickle"), @@ -95,15 +96,14 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) - for count, key in enumerate(case_keys): analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() snrs = metrics["snr"].values result = self.get_result(key) - norms = np.linalg.norm(result['templates'].templates_array, axis=(1, 2)) + norms = np.linalg.norm(result["templates"].templates_array, axis=(1, 2)) - coordinates = self.benchmarks[key].gt_positions[:, :2].copy() + coordinates = self.benchmarks[key].gt_positions[:, :2].copy() coordinates[:, 0] -= coordinates[:, 0].mean() coordinates[:, 1] -= coordinates[:, 1].mean() distances_to_center = np.linalg.norm(coordinates, axis=1) @@ -111,13 +111,12 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): idx = np.argsort(norms) from scipy.signal import savgol_filter + wdx = np.argsort(snrs) data = result["medians_over_templates"] - axs[0].plot( - snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]['label'] - ) + axs[0].plot(snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]["label"]) ymin = savgol_filter((data - result["mads_over_templates"])[wdx], smoothing_factor, 3) ymax = savgol_filter((data + result["mads_over_templates"])[wdx], smoothing_factor, 3) @@ -129,7 +128,7 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): distances_to_center[zdx], savgol_filter(data[zdx], smoothing_factor, 3), lw=2, - label=self.cases[key]['label'], + label=self.cases[key]["label"], ) ymin = savgol_filter((data - result["mads_over_templates"])[zdx], smoothing_factor, 3) ymax = savgol_filter((data + result["mads_over_templates"])[zdx], smoothing_factor, 3) @@ -140,14 +139,14 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): x_means = [] x_stds = [] for count, key in enumerate(case_keys): - result = self.get_result(key)['medians_over_templates'] + result = self.get_result(key)["medians_over_templates"] x_means += [result.mean()] x_stds += [result.std()] y_means = [] y_stds = [] for count, key in enumerate(case_keys): - result = self.get_result(key)['mads_over_templates'] + result = self.get_result(key)["mads_over_templates"] y_means += [result.mean()] y_stds += [result.std()] @@ -162,15 +161,14 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): axs[1].legend() - class UnitLocalizationBenchmark(Benchmark): def __init__(self, recording, gt_sorting, params, gt_positions): self.recording = recording self.gt_sorting = gt_sorting self.gt_positions = gt_positions - self.method = params['method'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.method_kwargs = params["method_kwargs"] self.result = {} self.waveforms_params = {} for key in ["ms_before", "ms_after"]: @@ -180,11 +178,11 @@ def __init__(self, recording, gt_sorting, params, gt_positions): self.waveforms_params[key] = 2 def run(self, **job_kwargs): - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory') + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory") sorting_analyzer.compute("random_spikes") - sorting_analyzer.compute('waveforms', **self.waveforms_params, **job_kwargs) - ext = sorting_analyzer.compute('templates') - templates = ext.get_data(outputs='Templates') + sorting_analyzer.compute("waveforms", **self.waveforms_params, **job_kwargs) + ext = sorting_analyzer.compute("templates") + templates = ext.get_data(outputs="Templates") if self.method == "center_of_mass": unit_locations = compute_center_of_mass(sorting_analyzer, **self.method_kwargs) @@ -193,23 +191,21 @@ def run(self, **job_kwargs): elif self.method == "grid_convolution": unit_locations = compute_grid_convolution(sorting_analyzer, **self.method_kwargs) - if (unit_locations.shape[1] == 2): + if unit_locations.shape[1] == 2: unit_locations = np.hstack((unit_locations, np.zeros((len(unit_locations), 1)))) - - self.result = {'unit_locations' : unit_locations} - self.result['templates'] = templates + + self.result = {"unit_locations": unit_locations} + self.result["templates"] = templates def compute_result(self, **result_params): - errors = np.linalg.norm(self.gt_positions[:, :2] - self.result['unit_locations'][:, :2], axis=1) - self.result['errors'] = errors - + errors = np.linalg.norm(self.gt_positions[:, :2] - self.result["unit_locations"][:, :2], axis=1) + self.result["errors"] = errors + _run_key_saved = [ ("unit_locations", "npy"), ("templates", "zarr_templates"), ] - _result_key_saved = [ - ("errors", "npy") - ] + _result_key_saved = [("errors", "npy")] class UnitLocalizationStudy(BenchmarkStudy): @@ -230,21 +226,21 @@ def plot_template_errors(self, case_keys=None): case_keys = list(self.cases.keys()) fig, axs = plt.subplots(ncols=1, nrows=1, figsize=(15, 5)) from spikeinterface.widgets import plot_probe_map - #plot_probe_map(self.benchmarks[case_keys[0]].recording, ax=axs) + + # plot_probe_map(self.benchmarks[case_keys[0]].recording, ax=axs) axs.scatter(self.gt_positions[:, 0], self.gt_positions[:, 1], c=np.arange(len(self.gt_positions)), cmap="jet") - + for count, key in enumerate(case_keys): result = self.get_result(key) axs.scatter( - result['unit_locations'][:, 0], - result['unit_locations'][:, 1], - c=f'C{count}', + result["unit_locations"][:, 0], + result["unit_locations"][:, 1], + c=f"C{count}", marker="v", - label=self.cases[key]['label'] + label=self.cases[key]["label"], ) axs.legend() - def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): if case_keys is None: @@ -254,12 +250,12 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): for count, key in enumerate(case_keys): analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension('quality_metrics').get_data() + metrics = analyzer.get_extension("quality_metrics").get_data() snrs = metrics["snr"].values result = self.get_result(key) - norms = np.linalg.norm(result['templates'].templates_array, axis=(1, 2)) + norms = np.linalg.norm(result["templates"].templates_array, axis=(1, 2)) - coordinates = self.benchmarks[key].gt_positions[:, :2].copy() + coordinates = self.benchmarks[key].gt_positions[:, :2].copy() coordinates[:, 0] -= coordinates[:, 0].mean() coordinates[:, 1] -= coordinates[:, 1].mean() distances_to_center = np.linalg.norm(coordinates, axis=1) @@ -267,14 +263,13 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): idx = np.argsort(norms) from scipy.signal import savgol_filter + wdx = np.argsort(snrs) data = result["errors"] - axs[0].plot( - snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]['label'] - ) - + axs[0].plot(snrs[wdx], savgol_filter(data[wdx], smoothing_factor, 3), lw=2, label=self.cases[key]["label"]) + axs[0].set_xlabel("snr") axs[0].set_ylabel("error (um)") @@ -282,7 +277,7 @@ def plot_comparison_positions(self, case_keys=None, smoothing_factor=5): distances_to_center[zdx], savgol_filter(data[zdx], smoothing_factor, 3), lw=2, - label=self.cases[key]['label'], + label=self.cases[key]["label"], ) axs[1].legend() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1f97f0a0c6..a51c8c8145 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -39,9 +39,9 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.gt_sorting = gt_sorting self.indices = indices - sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') + ext = sorting_analyzer.compute("fast_templates") extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) @@ -50,55 +50,57 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.peaks = peaks[self.indices] self.params = params self.exhaustive_gt = exhaustive_gt - self.method = params['method'] - self.method_kwargs = params['method_kwargs'] + self.method = params["method"] + self.method_kwargs = params["method_kwargs"] self.result = {} - - def run(self, **job_kwargs): + + def run(self, **job_kwargs): labels, peak_labels = find_cluster_from_peaks( - self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs + self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, **job_kwargs ) - self.result['peak_labels'] = peak_labels + self.result["peak_labels"] = peak_labels def compute_result(self, **result_params): - self.noise = self.result['peak_labels'] < 0 + self.noise = self.result["peak_labels"] < 0 spikes = self.gt_sorting.to_spike_vector() - self.result['sliced_gt_sorting'] = NumpySorting(spikes[self.indices], - self.recording.sampling_frequency, - self.gt_sorting.unit_ids) + self.result["sliced_gt_sorting"] = NumpySorting( + spikes[self.indices], self.recording.sampling_frequency, self.gt_sorting.unit_ids + ) data = spikes[self.indices][~self.noise] - data["unit_index"] = self.result['peak_labels'][~self.noise] + data["unit_index"] = self.result["peak_labels"][~self.noise] - self.result['clustering'] = NumpySorting.from_times_labels(data["sample_index"], - self.result['peak_labels'][~self.noise], - self.recording.sampling_frequency) - - self.result['gt_comparison'] = GroundTruthComparison(self.result['sliced_gt_sorting'], - self.result['clustering'], - exhaustive_gt=self.exhaustive_gt) + self.result["clustering"] = NumpySorting.from_times_labels( + data["sample_index"], self.result["peak_labels"][~self.noise], self.recording.sampling_frequency + ) + + self.result["gt_comparison"] = GroundTruthComparison( + self.result["sliced_gt_sorting"], self.result["clustering"], exhaustive_gt=self.exhaustive_gt + ) - sorting_analyzer = create_sorting_analyzer(self.result['sliced_gt_sorting'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['sliced_gt_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates") - sorting_analyzer = create_sorting_analyzer(self.result['clustering'], self.recording, format='memory', sparse=False) + sorting_analyzer = create_sorting_analyzer( + self.result["clustering"], self.recording, format="memory", sparse=False + ) sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute('fast_templates') - self.result['clustering_templates'] = ext.get_data(outputs="Templates") + ext = sorting_analyzer.compute("fast_templates") + self.result["clustering_templates"] = ext.get_data(outputs="Templates") - _run_key_saved = [ - ("peak_labels", "npy") - ] + _run_key_saved = [("peak_labels", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), ("sliced_gt_sorting", "sorting"), ("clustering", "sorting"), ("sliced_gt_templates", "zarr_templates"), - ("clustering_templates", "zarr_templates") + ("clustering_templates", "zarr_templates"), ] @@ -115,8 +117,6 @@ def create_benchmark(self, key): return benchmark - - # class BenchmarkPeakSelection: # def __init__(self, recording, gt_sorting, exhaustive_gt=True, job_kwargs={}, tmp_folder=None, verbose=True): # self.verbose = verbose diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 8a239453e5..5f23fab255 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -17,6 +17,7 @@ _key_separator = "_-°°-_" + class BenchmarkStudy: """ Generic study for sorting components. @@ -31,7 +32,9 @@ class BenchmarkStudy: """ + benchmark_class = None + def __init__(self, study_folder): self.folder = Path(study_folder) self.datasets = {} @@ -147,7 +150,7 @@ def remove_benchmark(self, key): if result_folder.exists(): shutil.rmtree(result_folder) - for f in (log_file, ): + for f in (log_file,): if f.exists(): f.unlink() self.benchmarks[key] = None @@ -178,17 +181,17 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): benchmark.save_run(bench_folder) benchmark.result["run_time"] = float(t1 - t0) benchmark.save_main(bench_folder) - + def get_run_times(self, case_keys=None): if case_keys is None: case_keys = list(self.cases.keys()) - + run_times = {} for key in case_keys: benchmark = self.benchmarks[key] assert benchmark is not None run_times[key] = benchmark.result["run_time"] - + df = pd.DataFrame(dict(run_times=run_times)) if not isinstance(self.levels, str): df.index.names = self.levels @@ -199,9 +202,7 @@ def plot_run_times(self, case_keys=None): case_keys = list(self.cases.keys()) run_times = self.get_run_times(case_keys=case_keys) - run_times.plot(kind='bar') - - + run_times.plot(kind="bar") def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: @@ -264,7 +265,7 @@ def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], f else: continue sorting_analyzer = self.get_sorting_analyzer(key) - qm_ext = sorting_analyzer.compute("quality_metrics", metric_names=metric_names) + qm_ext = sorting_analyzer.compute("quality_metrics", metric_names=metric_names) metrics = qm_ext.get_data() metrics.to_csv(filename, sep="\t", index=True) @@ -285,16 +286,16 @@ def get_metrics(self, key): def get_units_snr(self, key): """ """ return self.get_metrics(key)["snr"] - + def get_result(self, key): return self.benchmarks[key].result - class Benchmark: """ Responsible to make a unique run() and compute_result() for one case. """ + def __init__(self): self.result = {} @@ -310,14 +311,14 @@ def _save_keys(self, saved_keys, folder): for k, format in saved_keys: if format == "npy": np.save(folder / f"{k}.npy", self.result[k]) - elif format =="pickle": - with open(folder / f"{k}.pickle", mode="wb") as f: + elif format == "pickle": + with open(folder / f"{k}.pickle", mode="wb") as f: pickle.dump(self.result[k], f) - elif format == 'sorting': - self.result[k].save(folder = folder / k, format="numpy_folder") - elif format == 'zarr_templates': + elif format == "sorting": + self.result[k].save(folder=folder / k, format="numpy_folder") + elif format == "zarr_templates": self.result[k].to_zarr(folder / k) - elif format == 'sorting_analyzer': + elif format == "sorting_analyzer": pass else: raise ValueError(f"Save error {k} {format}") @@ -328,7 +329,7 @@ def save_main(self, folder): def save_run(self, folder): self._save_keys(self._run_key_saved, folder) - + def save_result(self, folder): self._save_keys(self._result_key_saved, folder) @@ -340,20 +341,22 @@ def load_folder(cls, folder): file = folder / f"{k}.npy" if file.exists(): result[k] = np.load(file) - elif format =="pickle": + elif format == "pickle": file = folder / f"{k}.pickle" if file.exists(): with open(file, mode="rb") as f: result[k] = pickle.load(f) - elif format =="sorting": + elif format == "sorting": from spikeinterface.core import load_extractor + result[k] = load_extractor(folder / k) - elif format =="zarr_templates": + elif format == "zarr_templates": from spikeinterface.core.template import Templates + result[k] = Templates.from_zarr(folder / k) return result - + def run(self): # run method raise NotImplementedError diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py index 091ab0820e..eb94b553a2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py @@ -3,6 +3,7 @@ This is not tested on github because not relevant at all. This only a local testing. """ + import pytest from pathlib import Path import os @@ -18,11 +19,7 @@ NoiseGeneratorRecording, ) from spikeinterface.core.generate import generate_unit_locations -from spikeinterface.generation import ( - DriftingTemplates, - make_linear_displacement, - InjectDriftingTemplatesRecording -) +from spikeinterface.generation import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording from probeinterface import generate_multi_columns_probe @@ -37,7 +34,6 @@ cache_folder = Path("cache_folder") / "sortingcomponents_benchmark" - def make_dataset(): recording, gt_sorting = generate_ground_truth_recording( durations=[60.0], @@ -57,26 +53,31 @@ def make_dataset(): ) return recording, gt_sorting -def compute_gt_templates(recording, gt_sorting, ms_before=2., ms_after=3., return_scaled=False, **job_kwargs): - spikes = gt_sorting.to_spike_vector()#[spike_indices] + +def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs): + spikes = gt_sorting.to_spike_vector() # [spike_indices] fs = recording.sampling_frequency nbefore = int(ms_before * fs / 1000) nafter = int(ms_after * fs / 1000) templates_array = estimate_templates( - recording, spikes, - gt_sorting.unit_ids, nbefore, nafter, return_scaled=return_scaled, + recording, + spikes, + gt_sorting.unit_ids, + nbefore, + nafter, + return_scaled=return_scaled, **job_kwargs, ) - + gt_templates = Templates( - templates_array=templates_array, - sampling_frequency=fs, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=gt_sorting.unit_ids, - probe=recording.get_probe(), - ) + templates_array=templates_array, + sampling_frequency=fs, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=gt_sorting.unit_ids, + probe=recording.get_probe(), + ) return gt_templates @@ -84,11 +85,10 @@ def make_drifting_dataset(): num_units = 15 duration = 125.5 - sampling_frequency = 30000. - ms_before = 1. - ms_after = 3. - displacement_sampling_frequency = 5. - + sampling_frequency = 30000.0 + ms_before = 1.0 + ms_after = 3.0 + displacement_sampling_frequency = 5.0 probe = generate_multi_columns_probe( num_columns=3, @@ -100,8 +100,6 @@ def make_drifting_dataset(): ) probe.set_device_channel_indices(np.arange(probe.contact_ids.size)) - - channel_locations = probe.contact_positions unit_locations = generate_unit_locations( @@ -116,9 +114,7 @@ def make_drifting_dataset(): seed=None, ) - - - nbefore = int(sampling_frequency * ms_before / 1000.) + nbefore = int(sampling_frequency * ms_before / 1000.0) generate_kwargs = dict( sampling_frequency=sampling_frequency, @@ -130,12 +126,9 @@ def make_drifting_dataset(): repolarization_ms=np.ones(num_units) * 0.8, ), unit_params_range=dict( - alpha=(4_000., 8_000.), + alpha=(4_000.0, 8_000.0), depolarization_ms=(0.09, 0.16), - ), - - ) templates_array = generate_templates(channel_locations, unit_locations, **generate_kwargs) @@ -149,28 +142,27 @@ def make_drifting_dataset(): drifting_templates = DriftingTemplates.from_static(templates) channel_locations = probe.contact_positions - start = np.array([0, -15.]) + start = np.array([0, -15.0]) stop = np.array([0, 12]) displacements = make_linear_displacement(start, stop, num_step=29) - sorting = generate_sorting( num_units=num_units, sampling_frequency=sampling_frequency, - durations = [duration,], - firing_rates=25.) + durations=[ + duration, + ], + firing_rates=25.0, + ) sorting - - - times = np.arange(0, duration, 1 / displacement_sampling_frequency) times # 2 rythm mid = (start + stop) / 2 freq0 = 0.1 - displacement_vector0 = np.sin(2 * np.pi * freq0 *times)[:, np.newaxis] * (start - stop) + mid + displacement_vector0 = np.sin(2 * np.pi * freq0 * times)[:, np.newaxis] * (start - stop) + mid # freq1 = 0.01 # displacement_vector1 = 0.2 * np.sin(2 * np.pi * freq1 *times)[:, np.newaxis] * (start - stop) + mid @@ -183,7 +175,6 @@ def make_drifting_dataset(): displacement_unit_factor = np.zeros((num_units, num_motion)) displacement_unit_factor[:, 0] = 1 - drifting_templates.precompute_displacements(displacements) direction = 1 @@ -196,7 +187,7 @@ def make_drifting_dataset(): num_channels=probe.contact_ids.size, sampling_frequency=sampling_frequency, durations=[duration], - noise_level=1., + noise_level=1.0, dtype="float32", ) @@ -207,7 +198,7 @@ def make_drifting_dataset(): displacement_vectors=[displacement_vectors], displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=displacement_unit_factor, - num_samples=[int(duration*sampling_frequency)], + num_samples=[int(duration * sampling_frequency)], amplitude_factor=None, ) @@ -218,19 +209,23 @@ def make_drifting_dataset(): displacement_vectors=[displacement_vectors], displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=np.zeros_like(displacement_unit_factor), - num_samples=[int(duration*sampling_frequency)], + num_samples=[int(duration * sampling_frequency)], amplitude_factor=None, ) - my_dict = _variable_from_namespace([ - drifting_rec, - static_rec, - sorting, - displacement_vectors, - displacement_sampling_frequency, - unit_locations, displacement_unit_factor, - unit_displacements - ], locals()) + my_dict = _variable_from_namespace( + [ + drifting_rec, + static_rec, + sorting, + displacement_vectors, + displacement_sampling_frequency, + unit_locations, + displacement_unit_factor, + unit_displacements, + ], + locals(), + ) return my_dict @@ -241,5 +236,3 @@ def _variable_from_namespace(objs, namespace): if namespace[name] is obj: d[name] = obj return d - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py index b60fb963fd..9e8f0e7404 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py @@ -12,10 +12,9 @@ from spikeinterface.sortingcomponents.benchmark.benchmark_clustering import ClusteringStudy - @pytest.mark.skip() def test_benchmark_clustering(): - + job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") recording, gt_sorting = make_dataset() @@ -23,19 +22,18 @@ def test_benchmark_clustering(): num_spikes = gt_sorting.to_spike_vector().size spike_indices = np.arange(0, num_spikes, 5) - # create study - study_folder = cache_folder / 'study_clustering' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_clustering" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for method in ['random_projections', 'circus']: + for method in ["random_projections", "circus"]: cases[method] = { "label": f"{method} on toy", "dataset": "toy", - "init_kwargs": {'indices' : spike_indices}, - "params" : {"method" : method, "method_kwargs" : {}}, + "init_kwargs": {"indices": spike_indices}, + "params": {"method": method, "method_kwargs": {}}, } - + if study_folder.exists(): shutil.rmtree(study_folder) study = ClusteringStudy.create(study_folder, datasets=datasets, cases=cases) @@ -45,7 +43,6 @@ def test_benchmark_clustering(): study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() - study = ClusteringStudy(study_folder) # run and result @@ -64,8 +61,5 @@ def test_benchmark_clustering(): plt.show() - if __name__ == "__main__": test_benchmark_clustering() - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index 2af8bff1e5..805f5d8327 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -12,33 +12,41 @@ compute_sparsity, ) -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset, cache_folder, compute_gt_templates +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( + make_dataset, + cache_folder, + compute_gt_templates, +) from spikeinterface.sortingcomponents.benchmark.benchmark_matching import MatchingStudy @pytest.mark.skip() def test_benchmark_matching(): - + job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") recording, gt_sorting = make_dataset() # templates sparse - gt_templates = compute_gt_templates(recording, gt_sorting, ms_before=2., ms_after=3., return_scaled=False, **job_kwargs) + gt_templates = compute_gt_templates( + recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs + ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity(gt_templates, noise_levels, method='ptp', threshold=0.25) + sparsity = compute_sparsity(gt_templates, noise_levels, method="ptp", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) - # create study - study_folder = cache_folder / 'study_matching' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_matching" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for engine in ['wobble', 'circus-omp-svd',]: + for engine in [ + "wobble", + "circus-omp-svd", + ]: cases[engine] = { "label": f"{engine} on toy", "dataset": "toy", - "params" : {"method" : engine, "method_kwargs" : {"templates" : gt_templates}}, + "params": {"method": engine, "method_kwargs": {"templates": gt_templates}}, } if study_folder.exists(): shutil.rmtree(study_folder) @@ -66,5 +74,3 @@ def test_benchmark_matching(): if __name__ == "__main__": test_benchmark_matching() - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index 0f009afa9a..7f24c07d3d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -7,16 +7,19 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_drifting_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( + make_drifting_dataset, + cache_folder, +) from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import MotionEstimationStudy + @pytest.mark.skip() def test_benchmark_motion_estimaton(): job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") - data = make_drifting_dataset() datasets = { @@ -25,39 +28,38 @@ def test_benchmark_motion_estimaton(): cases = {} for label, loc_method, est_method in [ - ("COM + KS", "center_of_mass", "iterative_template"), - ("Grid + Dec", "grid_convolution", "decentralized"), - ]: + ("COM + KS", "center_of_mass", "iterative_template"), + ("Grid + Dec", "grid_convolution", "decentralized"), + ]: cases[label] = dict( - label = label, + label=label, dataset="drifting_rec", init_kwargs=dict( unit_locations=data["unit_locations"], unit_displacements=data["unit_displacements"], displacement_sampling_frequency=data["displacement_sampling_frequency"], - direction="y" + direction="y", ), params=dict( - detect_kwargs=dict(method="locally_exclusive", detect_threshold=10.), + detect_kwargs=dict(method="locally_exclusive", detect_threshold=10.0), select_kwargs=None, localize_kwargs=dict(method=loc_method), estimate_motion_kwargs=dict( method=est_method, - bin_duration_s=1., - bin_um=5., + bin_duration_s=1.0, + bin_um=5.0, rigid=False, - win_step_um=50., - win_sigma_um=200., + win_step_um=50.0, + win_sigma_um=200.0, ), - ) + ), ) - study_folder = cache_folder / 'study_motion_estimation' + study_folder = cache_folder / "study_motion_estimation" if study_folder.exists(): shutil.rmtree(study_folder) study = MotionEstimationStudy.create(study_folder, datasets, cases) - # run and result study.run(**job_kwargs) study.compute_results() @@ -76,6 +78,3 @@ def test_benchmark_motion_estimaton(): if __name__ == "__main__": test_benchmark_motion_estimaton() - - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index cb8cc50b68..924b9ef385 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -10,10 +10,16 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_drifting_dataset, cache_folder +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( + make_drifting_dataset, + cache_folder, +) from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import get_unit_disclacement, get_gt_motion_from_unit_discplacement +from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import ( + get_unit_disclacement, + get_gt_motion_from_unit_discplacement, +) @pytest.mark.skip() @@ -21,7 +27,6 @@ def test_benchmark_motion_interpolation(): job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") - data = make_drifting_dataset() datasets = { @@ -31,33 +36,31 @@ def test_benchmark_motion_interpolation(): duration = data["drifting_rec"].get_duration() channel_locations = data["drifting_rec"].get_channel_locations() - - - unit_displacements = get_unit_disclacement(data["displacement_vectors"], data["displacement_unit_factor"], direction_dim=1) + unit_displacements = get_unit_disclacement( + data["displacement_vectors"], data["displacement_unit_factor"], direction_dim=1 + ) bin_s = 1 temporal_bins = np.arange(0, duration, bin_s) - spatial_bins = np.linspace(np.min(channel_locations[:, 1]), - np.max(channel_locations[:, 1]), - 10 - ) + spatial_bins = np.linspace(np.min(channel_locations[:, 1]), np.max(channel_locations[:, 1]), 10) print(spatial_bins) gt_motion = get_gt_motion_from_unit_discplacement( - unit_displacements, data["displacement_sampling_frequency"], + unit_displacements, + data["displacement_sampling_frequency"], data["unit_locations"], - temporal_bins, spatial_bins, - direction_dim=1 + temporal_bins, + spatial_bins, + direction_dim=1, ) # fig, ax = plt.subplots() # ax.imshow(gt_motion.T) # plt.show() - cases = {} - bin_duration_s = 1. + bin_duration_s = 1.0 cases["static_SC2"] = dict( - label = "No drift - no correction - SC2", + label="No drift - no correction - SC2", dataset="data_static", init_kwargs=dict( drifting_recording=data["drifting_rec"], @@ -69,11 +72,11 @@ def test_benchmark_motion_interpolation(): recording_source="static", sorter_name="spykingcircus2", sorter_params=dict(), - ) + ), ) cases["drifting_SC2"] = dict( - label = "Drift - no correction - SC2", + label="Drift - no correction - SC2", dataset="data_static", init_kwargs=dict( drifting_recording=data["drifting_rec"], @@ -85,11 +88,11 @@ def test_benchmark_motion_interpolation(): recording_source="drifting", sorter_name="spykingcircus2", sorter_params=dict(), - ) + ), ) cases["drifting_SC2"] = dict( - label = "Drift - correction with GT - SC2", + label="Drift - correction with GT - SC2", dataset="data_static", init_kwargs=dict( drifting_recording=data["drifting_rec"], @@ -102,10 +105,10 @@ def test_benchmark_motion_interpolation(): sorter_name="spykingcircus2", sorter_params=dict(), correct_motion_kwargs=dict(spatial_interpolation_method="kriging"), - ) + ), ) - study_folder = cache_folder / 'study_motion_interpolation' + study_folder = cache_folder / "study_motion_interpolation" if study_folder.exists(): shutil.rmtree(study_folder) study = MotionInterpolationStudy.create(study_folder, datasets, cases) @@ -114,7 +117,6 @@ def test_benchmark_motion_interpolation(): study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() - # run and result study.run(**job_kwargs) study.compute_results() @@ -135,9 +137,5 @@ def test_benchmark_motion_interpolation(): plt.show() - - if __name__ == "__main__": test_benchmark_motion_interpolation() - - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py index baa756d521..4b85e488ff 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -22,22 +22,23 @@ def test_benchmark_peak_localization(): recording, gt_sorting = make_dataset() - # create study - study_folder = cache_folder / 'study_peak_localization' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_peak_localization" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for method in ['center_of_mass', 'grid_convolution', 'monopolar_triangulation']: + for method in ["center_of_mass", "grid_convolution", "monopolar_triangulation"]: cases[method] = { "label": f"{method} on toy", "dataset": "toy", - "init_kwargs": {"gt_positions" : gt_sorting.get_property('gt_unit_locations')}, - "params" : {"ms_before" : 2, - "method" : method, - "method_kwargs" : {}, - "spike_retriver_kwargs" : {"channel_from_template" : False}} + "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, + "params": { + "ms_before": 2, + "method": method, + "method_kwargs": {}, + "spike_retriver_kwargs": {"channel_from_template": False}, + }, } - + if study_folder.exists(): shutil.rmtree(study_folder) study = PeakLocalizationStudy.create(study_folder, datasets=datasets, cases=cases) @@ -66,20 +67,22 @@ def test_benchmark_unit_localization(): recording, gt_sorting = make_dataset() # create study - study_folder = cache_folder / 'study_unit_localization' - datasets = {"toy" : (recording, gt_sorting)} + study_folder = cache_folder / "study_unit_localization" + datasets = {"toy": (recording, gt_sorting)} cases = {} - for method in ['center_of_mass', 'grid_convolution', 'monopolar_triangulation']: + for method in ["center_of_mass", "grid_convolution", "monopolar_triangulation"]: cases[method] = { "label": f"{method} on toy", "dataset": "toy", - "init_kwargs": {"gt_positions" : gt_sorting.get_property('gt_unit_locations')}, - "params" : {"ms_before" : 2, - "method" : method, - "method_kwargs" : {}, - "spike_retriver_kwargs" : {"channel_from_template" : False}} + "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, + "params": { + "ms_before": 2, + "method": method, + "method_kwargs": {}, + "spike_retriver_kwargs": {"channel_from_template": False}, + }, } - + if study_folder.exists(): shutil.rmtree(study_folder) study = UnitLocalizationStudy.create(study_folder, datasets=datasets, cases=cases) @@ -101,8 +104,6 @@ def test_benchmark_unit_localization(): plt.show() - if __name__ == "__main__": # test_benchmark_peak_localization() test_benchmark_unit_localization() - diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py index 78b59be489..f90a0c56d6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py @@ -13,9 +13,5 @@ def test_benchmark_peak_selection(): pass - - if __name__ == "__main__": test_benchmark_peak_selection() - - diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index c4fce76dad..24b4ca8022 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -61,7 +61,9 @@ def __init__( **backend_kwargs, ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - self.check_extensions(sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"]) + self.check_extensions( + sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"] + ) sorting = sorting_analyzer.sorting if unit_ids is None: @@ -183,10 +185,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): sorting_analyzer = data_plot["sorting_analyzer"] - import spikeinterface_gui - app = spikeinterface_gui.mkQApp() + + app = spikeinterface_gui.mkQApp() win = spikeinterface_gui.MainWindow(sorting_analyzer) win.show() app.exec_() - diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 7f761190f4..2d228d7d5f 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -103,7 +103,12 @@ def setUpClass(cls): print(f"Widgets tests: skipping backends - {cls.skip_backends}") - cls.backend_kwargs = {"matplotlib": {}, "sortingview": {}, "ipywidgets": {"display": False}, "spikeinterface_gui": {}} + cls.backend_kwargs = { + "matplotlib": {}, + "sortingview": {}, + "ipywidgets": {"display": False}, + "spikeinterface_gui": {}, + } cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) From 8cfb19a8f795844f425e19c74789707e46c65ba5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 09:56:41 +0000 Subject: [PATCH 149/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 8e6ccd3802..92999d5ea4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -92,7 +92,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = params["job_kwargs"] job_kwargs = fix_job_kwargs(job_kwargs) - job_kwargs.update({'verbose' : verbose, 'progress_bar' : verbose}) + job_kwargs.update({"verbose": verbose, "progress_bar": verbose}) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 826ac2abe8..c9e149fc75 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -44,7 +44,7 @@ class RandomProjectionClustering: _default_params = { "hdbscan_kwargs": { "min_cluster_size": 20, - "min_samples" : 10, + "min_samples": 10, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), "cluster_selection_method": "leaf", From 1d7f796a604f66874e7f1be608839fa9b2b55c15 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 27 Feb 2024 11:39:12 +0100 Subject: [PATCH 150/192] analyzer in howto analyze_neuropixel --- doc/how_to/analyse_neuropixels.rst | 515 +++++++++++++------------ examples/how_to/analyse_neuropixels.py | 215 ++++++----- 2 files changed, 380 insertions(+), 350 deletions(-) diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index 37646c2146..255172efc0 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -4,22 +4,22 @@ Analyse Neuropixels datasets This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing. -.. code:: ipython +.. code:: ipython3 %matplotlib inline -.. code:: ipython +.. code:: ipython3 import spikeinterface.full as si - + import numpy as np import matplotlib.pyplot as plt from pathlib import Path -.. code:: ipython - - base_folder = Path('/mnt/data/sam/DataSpikeSorting/neuropixel_example/') +.. code:: ipython3 + base_folder = Path('/mnt/data/sam/DataSpikeSorting/howto_si/neuropixel_example/') + spikeglx_folder = base_folder / 'Rec_1_10_11_2021_g0' @@ -29,7 +29,7 @@ Read the data The ``SpikeGLX`` folder can contain several “streams” (AP, LF and NIDQ). We need to specify which one to read: -.. code:: ipython +.. code:: ipython3 stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder) stream_names @@ -43,7 +43,7 @@ We need to specify which one to read: -.. code:: ipython +.. code:: ipython3 # we do not load the sync channel, so the probe is automatically loaded raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False) @@ -54,11 +54,12 @@ We need to specify which one to read: .. parsed-literal:: - SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1138.145s + SpikeGLXRecordingExtractor: 384 channels - 30.0kHz - 1 segments - 34,145,070 samples + 1,138.15s (18.97 minutes) - int16 dtype - 24.42 GiB -.. code:: ipython +.. code:: ipython3 # we automaticaly have the probe loaded! raw_rec.get_probe().to_dataframe() @@ -73,11 +74,11 @@ We need to specify which one to read: .dataframe tbody tr th:only-of-type { vertical-align: middle; } - + .dataframe tbody tr th { vertical-align: top; } - + .dataframe thead th { text-align: right; } @@ -201,7 +202,7 @@ We need to specify which one to read: -.. code:: ipython +.. code:: ipython3 fig, ax = plt.subplots(figsize=(15, 10)) si.plot_probe_map(raw_rec, ax=ax, with_channel_ids=True) @@ -229,13 +230,13 @@ Let’s do something similar to the IBL destriping chain (See - instead of interpolating bad channels, we remove then. - instead of highpass_spatial_filter() we use common_reference() -.. code:: ipython +.. code:: ipython3 rec1 = si.highpass_filter(raw_rec, freq_min=400.) bad_channel_ids, channel_labels = si.detect_bad_channels(rec1) rec2 = rec1.remove_channels(bad_channel_ids) print('bad_channel_ids', bad_channel_ids) - + rec3 = si.phase_shift(rec2) rec4 = si.common_reference(rec3, operator="median", reference="global") rec = rec4 @@ -251,7 +252,8 @@ Let’s do something similar to the IBL destriping chain (See .. parsed-literal:: - CommonReferenceRecording: 383 channels - 1 segments - 30.0kHz - 1138.145s + CommonReferenceRecording: 383 channels - 30.0kHz - 1 segments - 34,145,070 samples + 1,138.15s (18.97 minutes) - int16 dtype - 24.36 GiB @@ -264,18 +266,18 @@ the ipywydgets interactive ploter .. code:: python %matplotlib widget - si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') + si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') Note that using this ipywydgets make possible to explore diffrents preprocessing chain wihtout to save the entire file to disk. Everything is lazy, so you can change the previsous cell (parameters, step order, …) and visualize it immediatly. -.. code:: ipython +.. code:: ipython3 # here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) - + si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) @@ -287,7 +289,7 @@ is lazy, so you can change the previsous cell (parameters, step order, .. image:: analyse_neuropixels_files/analyse_neuropixels_13_0.png -.. code:: ipython +.. code:: ipython3 # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) @@ -299,7 +301,7 @@ is lazy, so you can change the previsous cell (parameters, step order, .. parsed-literal:: - + @@ -326,25 +328,13 @@ Depending on the complexity of the preprocessing chain, this operation can take a while. However, we can make use of the powerful parallelization mechanism of SpikeInterface. -.. code:: ipython +.. code:: ipython3 job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) - + rec = rec.save(folder=base_folder / 'preprocess', format='binary', **job_kwargs) - -.. parsed-literal:: - - write_binary_recording with n_jobs = 40 and chunk_size = 30000 - - - -.. parsed-literal:: - - write_binary_recording: 0%| | 0/1139 [00:00`__ for more information and a list of all supported metrics. @@ -697,19 +714,24 @@ Some metrics are based on PCA (like ``'isolation_distance', 'l_ratio', 'd_prime'``) and require to estimate PCA for their computation. This can be achieved with: -``si.compute_principal_components(waveform_extractor)`` +``analyzer.compute("principal_components")`` -.. code:: ipython +.. code:: ipython3 - metrics = si.compute_quality_metrics(we, metric_names=['firing_rate', 'presence_ratio', 'snr', - 'isi_violation', 'amplitude_cutoff']) + metric_names=['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff'] + + + # metrics = analyzer.compute("quality_metrics").get_data() + # equivalent to + metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names) + metrics .. parsed-literal:: - /home/samuel.garcia/Documents/SpikeInterface/spikeinterface/spikeinterface/qualitymetrics/misc_metrics.py:511: UserWarning: Units [11, 13, 15, 18, 21, 22] have too few spikes and amplitude_cutoff is set to NaN - warnings.warn(f"Units {nan_units} have too few spikes and " + /home/samuel.garcia/Documents/SpikeInterface/spikeinterface/src/spikeinterface/qualitymetrics/misc_metrics.py:846: UserWarning: Some units have too few spikes : amplitude_cutoff is set to NaN + warnings.warn(f"Some units have too few spikes : amplitude_cutoff is set to NaN") @@ -721,11 +743,11 @@ PCA for their computation. This can be achieved with: .dataframe tbody tr th:only-of-type { vertical-align: middle; } - + .dataframe tbody tr th { vertical-align: top; } - + .dataframe thead th { text-align: right; } @@ -734,293 +756,293 @@ PCA for their computation. This can be achieved with: + amplitude_cutoff firing_rate - presence_ratio - snr isi_violations_ratio isi_violations_count - amplitude_cutoff + presence_ratio + snr 0 + 0.011528 0.798668 + 4.591436 + 10.0 1.000000 - 1.324698 - 4.591437 - 10 - 0.011528 + 1.430458 1 - 9.886261 - 1.000000 - 1.959527 - 5.333803 - 1780 0.000062 + 9.886262 + 5.333802 + 1780.0 + 1.000000 + 1.938214 2 + 0.002567 2.849373 - 1.000000 - 1.467690 3.859813 - 107 - 0.002567 + 107.0 + 1.000000 + 1.586939 3 + 0.000099 5.404408 + 3.519589 + 351.0 1.000000 - 1.253708 - 3.519590 - 351 - 0.000188 + 2.073651 4 + 0.001487 4.772678 - 1.000000 - 1.722377 3.947255 - 307 - 0.001487 + 307.0 + 1.000000 + 1.595303 5 + 0.001190 1.802055 - 1.000000 - 2.358286 6.403293 - 71 - 0.001422 + 71.0 + 1.000000 + 2.411436 6 + 0.003508 0.531567 + 94.320694 + 91.0 0.888889 - 3.359229 - 94.320701 - 91 - 0.004900 + 3.377035 7 - 5.400014 - 1.000000 - 4.653080 - 0.612662 - 61 0.000119 + 5.400015 + 0.612662 + 61.0 + 1.000000 + 4.631496 8 - 10.563679 - 1.000000 - 8.267220 - 0.073487 - 28 0.000265 + 10.563680 + 0.073487 + 28.0 + 1.000000 + 8.178637 9 + 0.000968 8.181734 - 1.000000 - 4.546735 0.730646 - 167 - 0.000968 + 167.0 + 1.000000 + 3.900670 10 - 16.839681 - 1.000000 - 5.094325 - 0.298477 - 289 0.000259 + 16.839682 + 0.298477 + 289.0 + 1.000000 + 5.044798 11 + NaN 0.007029 - 0.388889 - 4.032887 0.000000 - 0 - NaN + 0.0 + 0.388889 + 4.032886 12 - 10.184114 - 1.000000 - 4.780558 - 0.720070 - 255 0.000264 + 10.184115 + 0.720070 + 255.0 + 1.000000 + 4.767068 13 + NaN 0.005272 + 0.000000 + 0.0 0.222222 4.627749 - 0.000000 - 0 - NaN 14 - 10.047928 - 1.000000 - 4.984704 - 0.771631 - 266 0.000371 + 10.047929 + 0.771631 + 266.0 + 1.000000 + 5.185702 15 + NaN 0.107192 + 0.000000 + 0.0 0.888889 4.248180 - 0.000000 - 0 - NaN 16 + 0.000452 0.535081 - 0.944444 - 2.326990 8.183362 - 8 - 0.000452 + 8.0 + 0.944444 + 2.309993 17 - 4.650549 - 1.000000 - 1.998918 - 6.391674 - 472 0.000196 + 4.650550 + 6.391673 + 472.0 + 1.000000 + 2.064208 18 + NaN 0.077319 + 293.942411 + 6.0 0.722222 6.619197 - 293.942433 - 6 - NaN 19 - 7.088727 - 1.000000 - 1.715093 + 0.000053 + 7.088728 5.146421 - 883 - 0.000268 + 883.0 + 1.000000 + 2.057868 20 - 9.821243 + 0.000071 + 9.821244 + 5.322676 + 1753.0 1.000000 - 1.575338 - 5.322677 - 1753 - 0.000059 + 1.688922 21 + NaN 0.046567 + 405.178005 + 3.0 0.666667 - 5.899877 - 405.178035 - 3 - NaN + 5.899876 22 + NaN 0.094891 + 65.051727 + 2.0 0.722222 6.476350 - 65.051732 - 2 - NaN 23 + 0.002927 1.849501 + 13.699103 + 160.0 1.000000 - 2.493723 - 13.699104 - 160 - 0.002927 + 2.282473 24 + 0.003143 1.420733 - 1.000000 - 1.549977 4.352889 - 30 - 0.004044 + 30.0 + 1.000000 + 1.573989 25 + 0.002457 0.675661 + 56.455510 + 88.0 0.944444 - 4.110071 - 56.455515 - 88 - 0.002457 + 4.107643 26 + 0.003152 0.642273 - 1.000000 - 1.981111 2.129918 - 3 - 0.003152 + 3.0 + 1.000000 + 1.902601 27 + 0.000229 1.012173 + 6.860924 + 24.0 0.888889 - 1.843515 - 6.860925 - 24 - 0.000229 + 1.854307 28 + 0.002856 0.804818 + 38.433003 + 85.0 0.888889 - 3.662210 - 38.433006 - 85 - 0.002856 + 3.755829 29 + 0.002854 1.012173 - 1.000000 - 1.097260 1.143487 - 4 - 0.000845 + 4.0 + 1.000000 + 1.345607 30 + 0.005439 0.649302 + 63.910953 + 92.0 0.888889 - 4.243889 - 63.910958 - 92 - 0.005439 + 4.168347 @@ -1034,12 +1056,12 @@ Curation using metrics A very common curation approach is to threshold these metrics to select *good* units: -.. code:: ipython +.. code:: ipython3 amplitude_cutoff_thresh = 0.1 isi_violations_ratio_thresh = 1 presence_ratio_thresh = 0.9 - + our_query = f"(amplitude_cutoff < {amplitude_cutoff_thresh}) & (isi_violations_ratio < {isi_violations_ratio_thresh}) & (presence_ratio > {presence_ratio_thresh})" print(our_query) @@ -1049,7 +1071,7 @@ A very common curation approach is to threshold these metrics to select (amplitude_cutoff < 0.1) & (isi_violations_ratio < 1) & (presence_ratio > 0.9) -.. code:: ipython +.. code:: ipython3 keep_units = metrics.query(our_query) keep_unit_ids = keep_units.index.values @@ -1071,43 +1093,43 @@ In order to export the final results we need to make a copy of the the waveforms, but only for the selected units (so we can avoid to compute them again). -.. code:: ipython +.. code:: ipython3 - we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / 'waveforms_clean') + analyzer_clean = analyzer.select_units(keep_unit_ids, folder=base_folder / 'analyzer_clean', format='binary_folder') -.. code:: ipython +.. code:: ipython3 - we_clean + analyzer_clean .. parsed-literal:: - WaveformExtractor: 383 channels - 6 units - 1 segments - before:45 after:60 n_per_units:500 - sparse + SortingAnalyzer: 383 channels - 6 units - 1 segments - binary_folder - sparse - has recording + Loaded 9 extenstions: random_spikes, waveforms, templates, noise_levels, correlograms, unit_locations, spike_amplitudes, template_similarity, quality_metrics Then we export figures to a report folder -.. code:: ipython +.. code:: ipython3 # export spike sorting report to a folder - si.export_report(we_clean, base_folder / 'report', format='png') + si.export_report(analyzer_clean, base_folder / 'report', format='png') -.. code:: ipython +.. code:: ipython3 - we_clean = si.load_waveforms(base_folder / 'waveforms_clean') - we_clean + analyzer_clean = si.load_sorting_analyzer(base_folder / 'analyzer_clean') + analyzer_clean .. parsed-literal:: - WaveformExtractor: 383 channels - 6 units - 1 segments - before:45 after:60 n_per_units:500 - sparse + SortingAnalyzer: 383 channels - 6 units - 1 segments - binary_folder - sparse - has recording + Loaded 9 extenstions: random_spikes, waveforms, templates, noise_levels, template_similarity, spike_amplitudes, correlograms, unit_locations, quality_metrics @@ -1115,4 +1137,5 @@ And push the results to sortingview webased viewer .. code:: python - si.plot_sorting_summary(we_clean, backend='sortingview') + si.plot_sorting_summary(analyzer_clean, backend='sortingview') + diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index 29d19f2331..3a936b072c 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -7,7 +7,7 @@ # extension: .py # format_name: light # format_version: '1.5' -# jupytext_version: 1.14.4 +# jupytext_version: 1.14.6 # kernelspec: # display_name: Python 3 (ipykernel) # language: python @@ -28,9 +28,9 @@ from pathlib import Path # + -base_folder = Path("/mnt/data/sam/DataSpikeSorting/neuropixel_example/") +base_folder = Path('/mnt/data/sam/DataSpikeSorting/howto_si/neuropixel_example/') -spikeglx_folder = base_folder / "Rec_1_10_11_2021_g0" +spikeglx_folder = base_folder / 'Rec_1_10_11_2021_g0' # - @@ -40,14 +40,14 @@ # We need to specify which one to read: # -stream_names, stream_ids = si.get_neo_streams("spikeglx", spikeglx_folder) +stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder) stream_names # we do not load the sync channel, so the probe is automatically loaded -raw_rec = si.read_spikeglx(spikeglx_folder, stream_name="imec0.ap", load_sync_channel=False) +raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False) raw_rec -# we automatically have the probe loaded! +# we automaticaly have the probe loaded! raw_rec.get_probe().to_dataframe() fig, ax = plt.subplots(figsize=(15, 10)) @@ -58,15 +58,15 @@ # # Let's do something similar to the IBL destriping chain (See :ref:`ibl_destripe`) to preprocess the data but: # -# * instead of interpolating bad channels, we remove them. +# * instead of interpolating bad channels, we remove then. # * instead of highpass_spatial_filter() we use common_reference() # # + -rec1 = si.highpass_filter(raw_rec, freq_min=400.0) +rec1 = si.highpass_filter(raw_rec, freq_min=400.) bad_channel_ids, channel_labels = si.detect_bad_channels(rec1) rec2 = rec1.remove_channels(bad_channel_ids) -print("bad_channel_ids", bad_channel_ids) +print('bad_channel_ids', bad_channel_ids) rec3 = si.phase_shift(rec2) rec4 = si.common_reference(rec3, operator="median", reference="global") @@ -78,39 +78,33 @@ # # -# The preprocessing steps can be interactively explored with the ipywidgets interactive plotter +# Interactive explore the preprocess steps could de done with this with the ipywydgets interactive ploter # # ```python # # %matplotlib widget -# si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') +# si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') # ``` # -# Note that using this ipywidgets make possible to explore different preprocessing chains without saving the entire file to disk. -# Everything is lazy, so you can change the previous cell (parameters, step order, ...) and visualize it immediately. +# Note that using this ipywydgets make possible to explore diffrents preprocessing chain wihtout to save the entire file to disk. +# Everything is lazy, so you can change the previsous cell (parameters, step order, ...) and visualize it immediatly. # # # + -# here we use a static plot using matplotlib backend +# here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) -si.plot_traces(rec1, backend="matplotlib", clim=(-50, 50), ax=axs[0]) -si.plot_traces(rec4, backend="matplotlib", clim=(-50, 50), ax=axs[1]) -si.plot_traces(rec, backend="matplotlib", clim=(-50, 50), ax=axs[2]) -for i, label in enumerate(("filter", "cmr", "final")): +si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) +si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) +si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) +for i, label in enumerate(('filter', 'cmr', 'final')): axs[i].set_title(label) # - # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) -some_chans = rec.channel_ids[ - [ - 100, - 150, - 200, - ] -] -si.plot_traces({"filter": rec1, "cmr": rec4}, backend="matplotlib", mode="line", ax=ax, channel_ids=some_chans) +some_chans = rec.channel_ids[[100, 150, 200, ]] +si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) # ### Should we save the preprocessed data to a binary file? @@ -119,14 +113,14 @@ # # Saving is not necessarily a good choice, as it consumes a lot of disk space and sometimes the writing to disk can be slower than recomputing the preprocessing chain on-the-fly. # -# Here, we decide to save it because Kilosort requires a binary file as input, so the preprocessed recording will need to be saved at some point. +# Here, we decide to do save it because Kilosort requires a binary file as input, so the preprocessed recording will need to be saved at some point. # # Depending on the complexity of the preprocessing chain, this operation can take a while. However, we can make use of the powerful parallelization mechanism of SpikeInterface. # + -job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True) +job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) -rec = rec.save(folder=base_folder / "preprocess", format="binary", **job_kwargs) +rec = rec.save(folder=base_folder / 'preprocess', format='binary', **job_kwargs) # - # our recording now points to the new binary folder @@ -134,7 +128,7 @@ # ## Check spiking activity and drift before spike sorting # -# A good practice before running a spike sorter is to check the "peaks activity" and the presence of drift. +# A good practice before running a spike sorter is to check the "peaks activity" and the presence of drifts. # # SpikeInterface has several tools to: # @@ -148,24 +142,24 @@ # Noise levels can be estimated on the scaled traces or on the raw (`int16`) traces. # -# we can estimate the noise on the scaled traces (microV) or on the raw ones (which in our case are int16). +# we can estimate the noise on the scaled traces (microV) or on the raw one (which is in our case int16). noise_levels_microV = si.get_noise_levels(rec, return_scaled=True) noise_levels_int16 = si.get_noise_levels(rec, return_scaled=False) fig, ax = plt.subplots() _ = ax.hist(noise_levels_microV, bins=np.arange(5, 30, 2.5)) -ax.set_xlabel("noise [microV]") +ax.set_xlabel('noise [microV]') # ### Detect and localize peaks # -# SpikeInterface includes built-in algorithms to detect peaks and also to localize their positions. +# SpikeInterface includes built-in algorithms to detect peaks and also to localize their position. # # This is part of the **sortingcomponents** module and needs to be imported explicitly. # # The two functions (detect + localize): # -# * can be run in parallel +# * can be run parallel # * are very fast when the preprocessed recording is already saved (and a bit slower otherwise) # * implement several methods # @@ -174,54 +168,53 @@ # + from spikeinterface.sortingcomponents.peak_detection import detect_peaks -job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True) -peaks = detect_peaks( - rec, method="locally_exclusive", noise_levels=noise_levels_int16, detect_threshold=5, radius_um=50.0, **job_kwargs -) +job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) +peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, + detect_threshold=5, local_radius_um=50., **job_kwargs) peaks # + from spikeinterface.sortingcomponents.peak_localization import localize_peaks -peak_locations = localize_peaks(rec, peaks, method="center_of_mass", radius_um=50.0, **job_kwargs) +peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs) # - -# ### Check for drift +# ### Check for drifts # -# We can *manually* check for drift with a simple scatter plots of peak times VS estimated peak depths. +# We can *manually* check for drifts with a simple scatter plots of peak times VS estimated peak depths. # # In this example, we do not see any apparent drift. # -# In case we notice apparent drift in the recording, one can use the SpikeInterface modules to estimate and correct motion. See the documentation for motion estimation and correction for more details. +# In case we notice apparent drifts in the recording, one can use the SpikeInterface modules to estimate and correct motion. See the documentation for motion estimation and correction for more details. -# check for drift +# check for drifts fs = rec.sampling_frequency fig, ax = plt.subplots(figsize=(10, 8)) -ax.scatter(peaks["sample_index"] / fs, peak_locations["y"], color="k", marker=".", alpha=0.002) +ax.scatter(peaks['sample_ind'] / fs, peak_locations['y'], color='k', marker='.', alpha=0.002) # + -# we can also use the peak location estimates to have insight of cluster separation before sorting +# we can also use the peak location estimates to have an insight of cluster separation before sorting fig, ax = plt.subplots(figsize=(15, 10)) si.plot_probe_map(rec, ax=ax, with_channel_ids=True) ax.set_ylim(-100, 150) -ax.scatter(peak_locations["x"], peak_locations["y"], color="purple", alpha=0.002) +ax.scatter(peak_locations['x'], peak_locations['y'], color='purple', alpha=0.002) # - # ## Run a spike sorter # -# Despite beingthe most critical part of the pipeline, spike sorting in SpikeInterface is dead-simple: one function. +# Even if running spike sorting is probably the most critical part of the pipeline, in SpikeInterface this is dead-simple: one function. # # **Important notes**: # -# * most of sorters are wrapped from external tools (kilosort, kilosort2.5, spykingcircus, mountainsort4 ...) that often also need other requirements (e.g., MATLAB, CUDA) -# * some sorters are internally developed (spykingcircus2) -# * external sorters can be run inside of a container (docker, singularity) WITHOUT pre-installation +# * most of sorters are wrapped from external tools (kilosort, kisolort2.5, spykingcircus, montainsort4 ...) that often also need other requirements (e.g., MATLAB, CUDA) +# * some sorters are internally developed (spyekingcircus2) +# * external sorter can be run inside a container (docker, singularity) WITHOUT pre-installation # -# Please carefully read the `spikeinterface.sorters` documentation for more information. +# Please carwfully read the `spikeinterface.sorters` documentation for more information. # -# In this example: +# In this example: # # * we will run kilosort2.5 # * we apply no drift correction (because we don't have drift) @@ -229,82 +222,94 @@ # # check default params for kilosort2.5 -si.get_default_sorter_params("kilosort2_5") +si.get_default_sorter_params('kilosort2_5') # + # run kilosort2.5 without drift correction -params_kilosort2_5 = {"do_correction": False} - -sorting = si.run_sorter( - "kilosort2_5", - rec, - output_folder=base_folder / "kilosort2.5_output", - docker_image=True, - verbose=True, - **params_kilosort2_5, -) +params_kilosort2_5 = {'do_correction': False} + +sorting = si.run_sorter('kilosort2_5', rec, output_folder=base_folder / 'kilosort2.5_output', + docker_image=True, verbose=True, **params_kilosort2_5) # - -# the results can be read back for future sessions -sorting = si.read_sorter_folder(base_folder / "kilosort2.5_output") +# the results can be read back for futur session +sorting = si.read_sorter_folder(base_folder / 'kilosort2.5_output') -# here we have 31 units in our recording +# here we have 31 untis in our recording sorting # ## Post processing # -# All postprocessing steps are based on the **WaveformExtractor** object. +# All the postprocessing step is based on the **SortingAnalyzer** object. +# +# This object combines a `sorting` and a `recording` object. It will also help to run some computation aka "extensions" to +# get an insight on the qulity of units. # -# This object combines a `recording` and a `sorting` object and extracts some waveform snippets (500 by default) for each unit. +# The first extentions we will run are: +# * select some spikes per units +# * etxract waveforms +# * compute templates +# * compute noise levels # # Note that we use the `sparse=True` option. This option is important because the waveforms will be extracted only for a few channels around the main channel of each unit. This saves tons of disk space and speeds up the waveforms extraction and further processing. # +# Note that our object is not persistent to disk because we use `format="memory"` we could use `format="binary_folder"` or `format="zarr"`. -we = si.extract_waveforms( - rec, - sorting, - folder=base_folder / "waveforms_kilosort2.5", - sparse=True, - max_spikes_per_unit=500, - ms_before=1.5, - ms_after=2.0, - **job_kwargs, -) +# + -# the `WaveformExtractor` contains all information and is persistent on disk -print(we) -print(we.folder) +analyzer = si.create_sorting_analyzer(sorting, rec, sparse=True, format="memory") +analyzer +# - -# the `WaveformExtrator` can be easily loaded back from its folder -we = si.load_waveforms(base_folder / "waveforms_kilosort2.5") -we +analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500) +analyzer.compute("waveforms", ms_before=1.5,ms_after=2., **job_kwargs) +analyzer.compute("templates", operators=["average", "median", "std"]) +analyzer.compute("noise_levels") +analyzer -# Many additional computations rely on the `WaveformExtractor`. +# Many additional computations rely on the `SortingAnalyzer`. # Some computations are slower than others, but can be performed in parallel using the `**job_kwargs` mechanism. # -# Every computation will also be persistent on disk in the same folder, since they represent waveform extensions. +# -_ = si.compute_noise_levels(we) -_ = si.compute_correlograms(we) -_ = si.compute_unit_locations(we) -_ = si.compute_spike_amplitudes(we, **job_kwargs) -_ = si.compute_template_similarity(we) +analyzer.compute("correlograms") +analyzer.compute("unit_locations") +analyzer.compute("spike_amplitudes", **job_kwargs) +analyzer.compute("template_similarity") +analyzer +# Our `SortingAnalyzer` can be saved to disk using `save_as()` which make a copy of the analyzer and all computed extensions. + +analyzer_saved = analyzer.save_as(folder=base_folder / "analyzer", format="binary_folder") +analyzer_saved + # ## Quality metrics # -# We have a single function `compute_quality_metrics(WaveformExtractor)` that returns a `pandas.Dataframe` with the desired metrics. +# We have a single function `compute_quality_metrics(SortingAnalyzer)` that returns a `pandas.Dataframe` with the desired metrics. +# +# Note that this function is also an extension and so can be saved. And so this is equivalent to do : +# `metrics = analyzer.compute("quality_metrics").get_data()` +# # # Please visit the [metrics documentation](https://spikeinterface.readthedocs.io/en/latest/modules/qualitymetrics.html) for more information and a list of all supported metrics. # -# Some metrics are based on PCA (like `'isolation_distance', 'l_ratio', 'd_prime'`) and require PCA values for their computation. This can be achieved with: +# Some metrics are based on PCA (like `'isolation_distance', 'l_ratio', 'd_prime'`) and require to estimate PCA for their computation. This can be achieved with: # -# `si.compute_principal_components(waveform_extractor)` +# `analyzer.compute("principal_components")` +# +# + +# + +metric_names=['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff'] + + +# metrics = analyzer.compute("quality_metrics").get_data() +# equivalent to +metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names) -metrics = si.compute_quality_metrics( - we, metric_names=["firing_rate", "presence_ratio", "snr", "isi_violation", "amplitude_cutoff"] -) metrics +# - # ## Curation using metrics # @@ -325,22 +330,24 @@ # ## Export final results to disk folder and visulize with sortingview # -# In order to export the final results we need to make a copy of the the waveforms, but only for the selected units (so we can avoid computing them again). +# In order to export the final results we need to make a copy of the the waveforms, but only for the selected units (so we can avoid to compute them again). -we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / "waveforms_clean") +analyzer_clean = analyzer.select_units(keep_unit_ids, folder=base_folder / 'analyzer_clean', format='binary_folder') -we_clean +analyzer_clean # Then we export figures to a report folder # export spike sorting report to a folder -si.export_report(we_clean, base_folder / "report", format="png") +si.export_report(analyzer_clean, base_folder / 'report', format='png') -we_clean = si.load_waveforms(base_folder / "waveforms_clean") -we_clean +analyzer_clean = si.load_sorting_analyzer(base_folder / 'analyzer_clean') +analyzer_clean # And push the results to sortingview webased viewer # # ```python -# si.plot_sorting_summary(we_clean, backend='sortingview') +# si.plot_sorting_summary(analyzer_clean, backend='sortingview') # ``` + + From 1ccd7df3a35735e8f1b3a582f9a5c7469d0ed1d1 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 27 Feb 2024 11:50:22 +0100 Subject: [PATCH 151/192] wip --- .../sortingcomponents/clustering/random_projections.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 826ac2abe8..9b357b6511 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -44,7 +44,6 @@ class RandomProjectionClustering: _default_params = { "hdbscan_kwargs": { "min_cluster_size": 20, - "min_samples" : 10, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), "cluster_selection_method": "leaf", @@ -53,7 +52,7 @@ class RandomProjectionClustering: "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, "radius_um": 100, - "nb_projections": 20, + "nb_projections": 10, "ms_before": 0.5, "ms_after": 0.5, "random_seed": 42, From 8609fc419cd34fa7059eb465733b47e69acb45be Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 27 Feb 2024 15:26:01 +0100 Subject: [PATCH 152/192] Fix a bug in Silence_period --- src/spikeinterface/preprocessing/silence_periods.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 6413ec06b4..a1bff8a6bc 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -119,8 +119,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.mode == "zeros": traces[onset:offset, :] = 0 elif self.mode == "noise": + num_samples = traces[onset:offset, :].shape[0] traces[onset:offset, :] = self.noise_levels[channel_indices] * np.random.randn( - offset - onset, num_channels + num_samples, num_channels ) return traces From 0bbc8b1dadc7154e28e98a39cb94b0e726de529d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 27 Feb 2024 17:03:36 +0100 Subject: [PATCH 153/192] Fix in benchmarks --- .../sorters/internal/spyking_circus2.py | 2 +- .../benchmark/benchmark_clustering.py | 62 ++++++++++++++++++- .../sortingcomponents/clustering/circus.py | 7 ++- .../clustering/random_projections.py | 4 +- 4 files changed, 68 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 92999d5ea4..394973b204 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -28,7 +28,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 50}, "sparsity": {"method": "ptp", "threshold": 0.25}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 373fe6b37b..b67eddd6f6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -12,6 +12,7 @@ ) from spikeinterface.comparison.comparisontools import make_matching_events +import matplotlib.patches as mpatches # from spikeinterface.postprocessing import get_template_extremum_channel from spikeinterface.core import get_noise_levels @@ -126,7 +127,7 @@ def homogeneity_score(self, ignore_noise=True, case_keys=None): if ignore_noise: gt_labels = gt_labels[~noise] found_labels = found_labels[~noise] - print(self.cases[key]["label"], homogeneity_score(gt_labels, found_labels), np.mean(noise)) + print(self.cases[key]["label"], "Homogeneity:", homogeneity_score(gt_labels, found_labels), "Noise (%):", np.mean(noise)) def plot_agreements(self, case_keys=None, figsize=(15, 15)): if case_keys is None: @@ -236,6 +237,65 @@ def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): label = self.cases[key]["label"] axs[count].set_title(label) + def plot_comparison_clustering( + self, + case_keys=None, + performance_names=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), + figsize=None, + ): + + if case_keys is None: + case_keys = list(self.cases.keys()) + + num_methods = len(case_keys) + fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) + for i, key1 in enumerate(case_keys): + for j, key2 in enumerate(case_keys): + if len(axs.shape) > 1: + ax = axs[i, j] + else: + ax = axs[j] + comp1 = self.get_result(key1)["gt_comparison"] + comp2 = self.get_result(key2)["gt_comparison"] + if i <= j: + for performance, color in zip(performance_names, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.plot(perf2, perf1, ".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + label1 = self.cases[key1]["label"] + label2 = self.cases[key2]["label"] + if j == i: + ax.set_ylabel(f"{label1}") + else: + ax.set_yticks([]) + if i == j: + ax.set_xlabel(f"{label2}") + else: + ax.set_xticks([]) + if i == num_methods - 1 and j == num_methods - 1: + patches = [] + for color, name in zip(colors, performance_names): + patches.append(mpatches.Patch(color=color, label=name)) + ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) + else: + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + plt.tight_layout(h_pad=0, w_pad=0) + + # def _scatter_clusters( # self, diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 12217cf3a2..cf9fa8887a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -47,6 +47,7 @@ class CircusClustering: _default_params = { "hdbscan_kwargs": { "min_cluster_size": 20, + "min_samples" : 1, "allow_single_cluster": True, "core_dist_n_jobs": -1, "cluster_selection_method": "eom", @@ -54,7 +55,7 @@ class CircusClustering: "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "radius_um": 100, + "radius_um": 50, "n_svd": [5, 10], "ms_before": 0.5, "ms_after": 0.5, @@ -96,7 +97,7 @@ def main_function(cls, recording, peaks, params): tmp_folder.mkdir(parents=True, exist_ok=True) # SVD for time compression - few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) + few_peaks = select_peaks(peaks, method="uniform", n_peaks=10000) few_wfs = extract_waveform_at_max_channel( recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"] ) @@ -158,7 +159,7 @@ def main_function(cls, recording, peaks, params): clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) local_labels = clustering[0] except Exception: - local_labels = -1 * np.ones(len(hdbscan_data)) + local_labels = np.zeros(len(hdbscan_data)) valid_clusters = local_labels > -1 if np.sum(valid_clusters) > 0: local_labels[valid_clusters] += nb_clusters diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 9b357b6511..9e17df3e49 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,13 +45,13 @@ class RandomProjectionClustering: "hdbscan_kwargs": { "min_cluster_size": 20, "allow_single_cluster": True, - "core_dist_n_jobs": os.cpu_count(), + "core_dist_n_jobs": -1, "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "radius_um": 100, + "radius_um": 50, "nb_projections": 10, "ms_before": 0.5, "ms_after": 0.5, From a0cf6c7ac3652df8db6ed3d818d6fc3206c91008 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 16:05:13 +0000 Subject: [PATCH 154/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_clustering.py | 10 ++++++++-- .../sortingcomponents/clustering/circus.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index b67eddd6f6..b8afc813ab 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -13,6 +13,7 @@ from spikeinterface.comparison.comparisontools import make_matching_events import matplotlib.patches as mpatches + # from spikeinterface.postprocessing import get_template_extremum_channel from spikeinterface.core import get_noise_levels @@ -127,7 +128,13 @@ def homogeneity_score(self, ignore_noise=True, case_keys=None): if ignore_noise: gt_labels = gt_labels[~noise] found_labels = found_labels[~noise] - print(self.cases[key]["label"], "Homogeneity:", homogeneity_score(gt_labels, found_labels), "Noise (%):", np.mean(noise)) + print( + self.cases[key]["label"], + "Homogeneity:", + homogeneity_score(gt_labels, found_labels), + "Noise (%):", + np.mean(noise), + ) def plot_agreements(self, case_keys=None, figsize=(15, 15)): if case_keys is None: @@ -296,7 +303,6 @@ def plot_comparison_clustering( plt.tight_layout(h_pad=0, w_pad=0) - # def _scatter_clusters( # self, # xs, diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index cf9fa8887a..f4059f7d35 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -47,7 +47,7 @@ class CircusClustering: _default_params = { "hdbscan_kwargs": { "min_cluster_size": 20, - "min_samples" : 1, + "min_samples": 1, "allow_single_cluster": True, "core_dist_n_jobs": -1, "cluster_selection_method": "eom", From ccadee4e5841ea49e8504e0b726ab11e79254879 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 28 Feb 2024 09:20:08 +0100 Subject: [PATCH 155/192] WIP --- .../sortingcomponents/matching/circus.py | 7 +++---- .../waveforms/temporal_pca.py | 21 ++++++++++++------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 596ad84e64..95ff4b2d40 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -111,7 +111,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): "templates": None, "rank": 5, "ignored_ids": [], - "vicinity": 0, + "vicinity": 3, } @classmethod @@ -233,7 +233,7 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + margin = 2 * kwargs["vicinity"] return margin @classmethod @@ -371,11 +371,10 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - if True: # vicinity == 0: + if vicinity == 0: all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) all_amplitudes /= norms[selection[0]] else: - # This is not working, need to figure out why is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) all_amplitudes = np.append(all_amplitudes, np.float32(1)) L = M[is_in_vicinity, :][:, is_in_vicinity] diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 029a7f44b0..ffd4886610 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -219,10 +219,12 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) num_channels = waveforms.shape[2] - temporal_waveforms = to_temporal_representation(waveforms) - projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) - projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) - + if len(waveforms) > 0: + temporal_waveforms = to_temporal_representation(waveforms) + projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) + projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) + else: + projected_waveforms = np.zeros((0, self.pca_model.n_components, num_channels), dtype=np.float32) return projected_waveforms @@ -274,9 +276,12 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) """ num_channels = waveforms.shape[2] - temporal_waveform = to_temporal_representation(waveforms) - projected_temporal_waveforms = self.pca_model.transform(temporal_waveform) - temporal_denoised_waveforms = self.pca_model.inverse_transform(projected_temporal_waveforms) - denoised_waveforms = from_temporal_representation(temporal_denoised_waveforms, num_channels) + if len(waveforms) > 0: + temporal_waveform = to_temporal_representation(waveforms) + projected_temporal_waveforms = self.pca_model.transform(temporal_waveform) + temporal_denoised_waveforms = self.pca_model.inverse_transform(projected_temporal_waveforms) + denoised_waveforms = from_temporal_representation(temporal_denoised_waveforms, num_channels) + else: + denoised_waveforms = waveforms return denoised_waveforms From 088d4c4680eb95766ddf9ead46ee2397aba0e1d6 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 28 Feb 2024 11:40:03 +0100 Subject: [PATCH 156/192] WIP --- .../sorters/internal/spyking_circus2.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 394973b204..7d75223fb8 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -260,8 +260,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): shutil.rmtree(sorting_folder) folder_to_delete = None - cache_mode = params["cache_preprocessing"]["mode"] - delete_cache = params["cache_preprocessing"]["delete_cache"] + + if "mode" in params["cache_preprocessing"]: + cache_mode = params["cache_preprocessing"]["mode"] + else: + cache_mode = "memory" + + if "delete_cache" in params["cache_preprocessing"]: + delete_cache = params["cache_preprocessing"] + else: + delete_cache = True + if cache_mode in ["folder", "zarr"] and delete_cache: folder_to_delete = recording_f._kwargs["folder_path"] From e459f739dde3347a6429b7b46b6142158a12534a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 28 Feb 2024 13:25:37 +0100 Subject: [PATCH 157/192] Default no cache for Windows --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 7d75223fb8..ed06b679b4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -41,7 +41,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "clustering": {"legacy": False}, "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, - "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, + "cache_preprocessing": {"mode": None, "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, "job_kwargs": {"n_jobs": 0.8}, "debug": False, From 90089bb8ee788e8c58b5f8ed353ecfce39a03058 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 28 Feb 2024 14:05:57 +0100 Subject: [PATCH 158/192] rebirth of simple_sorter in spikeinterface --- .../sorters/internal/simplesorter.py | 190 ++++++++++++++++++ .../internal/tests/test_simplesorter.py | 15 ++ src/spikeinterface/sorters/sorterlist.py | 2 + .../waveforms/temporal_pca.py | 27 ++- 4 files changed, 224 insertions(+), 10 deletions(-) create mode 100644 src/spikeinterface/sorters/internal/simplesorter.py create mode 100644 src/spikeinterface/sorters/internal/tests/test_simplesorter.py diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py new file mode 100644 index 0000000000..762c8c5c0a --- /dev/null +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -0,0 +1,190 @@ +from .si_based import ComponentsBasedSorter + +from spikeinterface.core import load_extractor, BaseRecording, get_noise_levels, extract_waveforms, NumpySorting +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore + +import numpy as np + + +import pickle +import json + + +class SimpleSorter(ComponentsBasedSorter): + """ + Implementation of a very simple sorter usefull for teaching. + The idea is quite old school: + * detect peaks + * project waveforms with SVD or PCA + * apply a well known clustering algos from scikit-learn + + No template matching. No auto cleaning. + + Mainly usefull for few channels (1 to 8), teaching and testing. + """ + sorter_name = "simple" + + handle_multi_segment = True + + + _default_params = { + "apply_preprocessing": False, + "waveforms": {"ms_before": 1.0, "ms_after": 1.5}, + "filtering": {"freq_min": 300, "freq_max": 8000.0}, + "detection": {"peak_sign": "neg", "detect_threshold": 5.0, "exclude_sweep_ms": 0.4}, + "features": {"n_components": 3}, + "clustering" :{ + "method": "hdbscan", + "min_cluster_size": 25, + "allow_single_cluster": True, + "core_dist_n_jobs": -1, + "cluster_selection_method": "leaf", + }, + "job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"}, + } + + @classmethod + def get_sorter_version(cls): + return "1.0" + + @classmethod + def _run_from_folder(cls, sorter_output_folder, params, verbose): + job_kwargs = params["job_kwargs"].copy() + job_kwargs = fix_job_kwargs(job_kwargs) + job_kwargs["progress_bar"] = verbose + + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing + + + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.sortingcomponents.peak_selection import select_peaks + from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection + from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + PeakRetriever, + ) + + from sklearn.decomposition import TruncatedSVD + + + + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + num_chans = recording_raw.get_num_channels() + sampling_frequency = recording_raw.get_sampling_frequency() + + # preprocessing + if params["apply_preprocessing"]: + recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32") + recording = zscore(recording) + noise_levels = np.ones(num_chans, dtype="float32") + else: + recording = recording_raw + noise_levels = get_noise_levels(recording, return_scaled=False) + + + # detection + detection_params = params["detection"].copy() + detection_params["noise_levels"] = noise_levels + peaks = detect_peaks(recording, method="locally_exclusive", **detection_params, **job_kwargs) + + if verbose: + print("We found %d peaks in total" % len(peaks)) + + ms_before = params["waveforms"]["ms_before"] + ms_after = params["waveforms"]["ms_after"] + + # SVD for time compression + few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) + + wfs = few_wfs[:, :, 0] + tsvd = TruncatedSVD(params["features"]["n_components"]) + tsvd.fit(wfs) + + model_folder = sorter_output_folder / "tsvd_model" + + model_folder.mkdir(exist_ok=True) + with open(model_folder / "pca_model.pkl", "wb") as f: + pickle.dump(tsvd, f) + + model_params = { + "ms_before": ms_before, + "ms_after": ms_after, + "sampling_frequency": float(sampling_frequency), + } + with open(model_folder / "params.json", "w") as f: + json.dump(model_params, f) + + # features + + features_folder = sorter_output_folder / "features" + node0 = PeakRetriever(recording, peaks) + + node1 = ExtractDenseWaveforms( + recording, + parents=[node0], + return_output=False, + ms_before=ms_before, + ms_after=ms_after, + ) + + model_folder_path = sorter_output_folder / "tsvd_model" + + node2 = TemporalPCAProjection( + recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path + ) + + pipeline_nodes = [node0, node1, node2] + + output = run_node_pipeline( + recording, + pipeline_nodes, + job_kwargs, + gather_mode="npy", + gather_kwargs=dict(exist_ok=True), + folder=features_folder, + names=["features_tsvd"], + ) + + features_tsvd = np.load(features_folder / "features_tsvd.npy") + features_flat = features_tsvd.reshape(features_tsvd.shape[0], -1) + + # run hdscan for clustering + + clust_params = params["clustering"].copy() + clust_method = clust_params.pop("method", "hdbscan") + + if clust_method == "hdbscan": + import hdbscan + out = hdbscan.hdbscan(features_flat, **clust_params) + peak_labels = out[0] + elif clust_method in ("kmeans"): + from sklearn.cluster import MiniBatchKMeans + peak_labels = MiniBatchKMeans(**clust_params).fit_predict(features_flat) + elif clust_method in ("mean_shift"): + from sklearn.cluster import MeanShift + peak_labels = MeanShift().fit_predict(features_flat) + elif clust_method in ("affinity_propagation"): + from sklearn.cluster import AffinityPropagation + peak_labels = AffinityPropagation().fit_predict(features_flat) + elif clust_method in ("gaussian_mixture"): + from sklearn.mixture import GaussianMixture + # n_components = clust_params["n_clusters"] + peak_labels = GaussianMixture(**clust_params).fit_predict(features_flat) + else: + raise ValueError(f"simple_sorter : unkown clustering method {clust_method}") + + + np.save(features_folder / "peak_labels.npy", peak_labels) + + # keep positive labels + keep = peak_labels >= 0 + sorting_final = NumpySorting.from_times_labels(peaks["sample_index"][keep], peak_labels[keep], sampling_frequency) + sorting_final = sorting_final.save(folder=sorter_output_folder / "sorting") + + return sorting_final diff --git a/src/spikeinterface/sorters/internal/tests/test_simplesorter.py b/src/spikeinterface/sorters/internal/tests/test_simplesorter.py new file mode 100644 index 0000000000..f3764806ab --- /dev/null +++ b/src/spikeinterface/sorters/internal/tests/test_simplesorter.py @@ -0,0 +1,15 @@ +import unittest + +from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite + +from spikeinterface.sorters import SimpleSorter + + +class SimpleSorterSorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): + SorterClass = SimpleSorter + + +if __name__ == "__main__": + test = SimpleSorterSorterCommonTestSuite() + test.setUp() + test.test_with_run() diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index 47557423f6..4b68c5f420 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -21,6 +21,7 @@ # based on spikeinertface.sortingcomponents from .internal.spyking_circus2 import Spykingcircus2Sorter from .internal.tridesclous2 import Tridesclous2Sorter +from .internal.simplesorter import SimpleSorter sorter_full_list = [ # external @@ -44,6 +45,7 @@ # internal Spykingcircus2Sorter, Tridesclous2Sorter, + SimpleSorter, ] sorter_dict = {s.sorter_name: s for s in sorter_full_list} diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 029a7f44b0..1a276b351b 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -191,11 +191,13 @@ class TemporalPCAProjection(TemporalPCBaseNode): """ def __init__( - self, recording: BaseRecording, parents: List[PipelineNode], model_folder_path: str, return_output=True + self, recording: BaseRecording, parents: List[PipelineNode], model_folder_path: str, dtype="float32", return_output=True ): TemporalPCBaseNode.__init__( self, recording=recording, parents=parents, return_output=return_output, model_folder_path=model_folder_path ) + self.n_components = self.pca_model.n_components + self.dtype = np.dtype(dtype) def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: """ @@ -218,12 +220,14 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) """ num_channels = waveforms.shape[2] + if waveforms.shape[0] > 0: + temporal_waveforms = to_temporal_representation(waveforms) + projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) + projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) + else: + projected_waveforms = np.zeros((0, self.n_components, num_channels), dtype=self.dtype) - temporal_waveforms = to_temporal_representation(waveforms) - projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) - projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) - - return projected_waveforms + return projected_waveforms.astype(self.dtype, copy=False) class TemporalPCADenoising(TemporalPCBaseNode): @@ -274,9 +278,12 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) """ num_channels = waveforms.shape[2] - temporal_waveform = to_temporal_representation(waveforms) - projected_temporal_waveforms = self.pca_model.transform(temporal_waveform) - temporal_denoised_waveforms = self.pca_model.inverse_transform(projected_temporal_waveforms) - denoised_waveforms = from_temporal_representation(temporal_denoised_waveforms, num_channels) + if waveforms.shape[0] > 0: + temporal_waveform = to_temporal_representation(waveforms) + projected_temporal_waveforms = self.pca_model.transform(temporal_waveform) + temporal_denoised_waveforms = self.pca_model.inverse_transform(projected_temporal_waveforms) + denoised_waveforms = from_temporal_representation(temporal_denoised_waveforms, num_channels) + else: + denoised_waveforms = np.zeros_like(waveforms) return denoised_waveforms From 586cf76b227dcd2f3a8f6002292e1b70293e9c02 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 28 Feb 2024 14:19:55 +0100 Subject: [PATCH 159/192] yep --- .../sorters/internal/simplesorter.py | 20 ++++++++++--------- .../waveforms/temporal_pca.py | 7 ++++++- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 762c8c5c0a..8de4b0e6ae 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -18,23 +18,23 @@ class SimpleSorter(ComponentsBasedSorter): * detect peaks * project waveforms with SVD or PCA * apply a well known clustering algos from scikit-learn - + No template matching. No auto cleaning. Mainly usefull for few channels (1 to 8), teaching and testing. """ + sorter_name = "simple" handle_multi_segment = True - _default_params = { "apply_preprocessing": False, "waveforms": {"ms_before": 1.0, "ms_after": 1.5}, "filtering": {"freq_min": 300, "freq_max": 8000.0}, "detection": {"peak_sign": "neg", "detect_threshold": 5.0, "exclude_sweep_ms": 0.4}, "features": {"n_components": 3}, - "clustering" :{ + "clustering": { "method": "hdbscan", "min_cluster_size": 25, "allow_single_cluster": True, @@ -57,7 +57,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing - from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection @@ -69,8 +68,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from sklearn.decomposition import TruncatedSVD - - recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) num_chans = recording_raw.get_num_channels() sampling_frequency = recording_raw.get_sampling_frequency() @@ -84,7 +81,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = recording_raw noise_levels = get_noise_levels(recording, return_scaled=False) - # detection detection_params = params["detection"].copy() detection_params["noise_levels"] = noise_levels @@ -161,30 +157,36 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if clust_method == "hdbscan": import hdbscan + out = hdbscan.hdbscan(features_flat, **clust_params) peak_labels = out[0] elif clust_method in ("kmeans"): from sklearn.cluster import MiniBatchKMeans + peak_labels = MiniBatchKMeans(**clust_params).fit_predict(features_flat) elif clust_method in ("mean_shift"): from sklearn.cluster import MeanShift + peak_labels = MeanShift().fit_predict(features_flat) elif clust_method in ("affinity_propagation"): from sklearn.cluster import AffinityPropagation + peak_labels = AffinityPropagation().fit_predict(features_flat) elif clust_method in ("gaussian_mixture"): from sklearn.mixture import GaussianMixture + # n_components = clust_params["n_clusters"] peak_labels = GaussianMixture(**clust_params).fit_predict(features_flat) else: raise ValueError(f"simple_sorter : unkown clustering method {clust_method}") - np.save(features_folder / "peak_labels.npy", peak_labels) # keep positive labels keep = peak_labels >= 0 - sorting_final = NumpySorting.from_times_labels(peaks["sample_index"][keep], peak_labels[keep], sampling_frequency) + sorting_final = NumpySorting.from_times_labels( + peaks["sample_index"][keep], peak_labels[keep], sampling_frequency + ) sorting_final = sorting_final.save(folder=sorter_output_folder / "sorting") return sorting_final diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 1a276b351b..b092fde5ce 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -191,7 +191,12 @@ class TemporalPCAProjection(TemporalPCBaseNode): """ def __init__( - self, recording: BaseRecording, parents: List[PipelineNode], model_folder_path: str, dtype="float32", return_output=True + self, + recording: BaseRecording, + parents: List[PipelineNode], + model_folder_path: str, + dtype="float32", + return_output=True, ): TemporalPCBaseNode.__init__( self, recording=recording, parents=parents, return_output=return_output, model_folder_path=model_folder_path From a9aa2f45aa4488cf9418f43cc20e3fa3fc509485 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 28 Feb 2024 15:35:39 +0100 Subject: [PATCH 160/192] typos --- .../sortingcomponents/peak_detection.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index f1b7e1b9ef..29df40e8c0 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -397,10 +397,10 @@ def check_params( if noise_levels is None: noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) - abs_threholds = noise_levels * detect_threshold + abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) - return (peak_sign, abs_threholds, exclude_sweep_size) + return (peak_sign, abs_thresholds, exclude_sweep_size) @classmethod def get_method_margin(cls, *args): @@ -408,12 +408,12 @@ def get_method_margin(cls, *args): return exclude_sweep_size @classmethod - def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size): + def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size): traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :] length = traces_center.shape[0] if peak_sign in ("pos", "both"): - peak_mask = traces_center > abs_threholds[None, :] + peak_mask = traces_center > abs_thresholds[None, :] for i in range(exclude_sweep_size): peak_mask &= traces_center > traces[i : i + length, :] peak_mask &= ( @@ -424,7 +424,7 @@ def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size): if peak_sign == "both": peak_mask_pos = peak_mask.copy() - peak_mask = traces_center < -abs_threholds[None, :] + peak_mask = traces_center < -abs_thresholds[None, :] for i in range(exclude_sweep_size): peak_mask &= traces_center < traces[i : i + length, :] peak_mask &= ( @@ -489,10 +489,10 @@ def check_params( if noise_levels is None: noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) - abs_threholds = noise_levels * detect_threshold + abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) - return (peak_sign, abs_threholds, exclude_sweep_size, device, return_tensor) + return (peak_sign, abs_thresholds, exclude_sweep_size, device, return_tensor) @classmethod def get_method_margin(cls, *args): @@ -500,8 +500,8 @@ def get_method_margin(cls, *args): return exclude_sweep_size @classmethod - def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, device, return_tensor): - sample_inds, chan_inds = _torch_detect_peaks(traces, peak_sign, abs_threholds, exclude_sweep_size, None, device) + def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, device, return_tensor): + sample_inds, chan_inds = _torch_detect_peaks(traces, peak_sign, abs_thresholds, exclude_sweep_size, None, device) if not return_tensor: sample_inds = np.array(sample_inds.cpu()) chan_inds = np.array(chan_inds.cpu()) @@ -555,23 +555,23 @@ def get_method_margin(cls, *args): return exclude_sweep_size @classmethod - def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, neighbours_mask): + def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask): assert HAVE_NUMBA, "You need to install numba" traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :] if peak_sign in ("pos", "both"): - peak_mask = traces_center > abs_threholds[None, :] + peak_mask = traces_center > abs_thresholds[None, :] peak_mask = _numba_detect_peak_pos( - traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask ) if peak_sign in ("neg", "both"): if peak_sign == "both": peak_mask_pos = peak_mask.copy() - peak_mask = traces_center < -abs_threholds[None, :] + peak_mask = traces_center < -abs_thresholds[None, :] peak_mask = _numba_detect_peak_neg( - traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask ) if peak_sign == "both": @@ -641,9 +641,9 @@ def get_method_margin(cls, *args): return exclude_sweep_size @classmethod - def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, device, return_tensor, neighbor_idxs): + def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, device, return_tensor, neighbor_idxs): sample_inds, chan_inds = _torch_detect_peaks( - traces, peak_sign, abs_threholds, exclude_sweep_size, neighbor_idxs, device + traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbor_idxs, device ) if not return_tensor and isinstance(sample_inds, torch.Tensor) and isinstance(chan_inds, torch.Tensor): sample_inds = np.array(sample_inds.cpu()) @@ -655,7 +655,7 @@ def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, devi @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_pos( - traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask ): num_chans = traces_center.shape[1] for chan_ind in range(num_chans): @@ -680,7 +680,7 @@ def _numba_detect_peak_pos( @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_neg( - traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask ): num_chans = traces_center.shape[1] for chan_ind in range(num_chans): @@ -857,12 +857,12 @@ def check_params( assert peak_sign in ("both", "neg", "pos") if noise_levels is None: noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs) - abs_threholds = noise_levels * detect_threshold + abs_thresholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) channel_distance = get_channel_distances(recording) neighbours_mask = channel_distance <= radius_um - executor = OpenCLDetectPeakExecutor(abs_threholds, exclude_sweep_size, neighbours_mask, peak_sign) + executor = OpenCLDetectPeakExecutor(abs_thresholds, exclude_sweep_size, neighbours_mask, peak_sign) return (executor,) @@ -879,12 +879,12 @@ def detect_peaks(cls, traces, executor): class OpenCLDetectPeakExecutor: - def __init__(self, abs_threholds, exclude_sweep_size, neighbours_mask, peak_sign): + def __init__(self, abs_thresholds, exclude_sweep_size, neighbours_mask, peak_sign): import pyopencl self.chunk_size = None - self.abs_threholds = abs_threholds.astype("float32") + self.abs_thresholds = abs_thresholds.astype("float32") self.exclude_sweep_size = exclude_sweep_size self.neighbours_mask = neighbours_mask.astype("uint8") self.peak_sign = peak_sign @@ -909,7 +909,7 @@ def create_buffers_and_compile(self, chunk_size): self.neighbours_mask_cl = pyopencl.Buffer( self.ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=self.neighbours_mask ) - self.abs_threholds_cl = pyopencl.Buffer(self.ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=self.abs_threholds) + self.abs_thresholds_cl = pyopencl.Buffer(self.ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=self.abs_thresholds) num_channels = self.neighbours_mask.shape[0] self.traces_cl = pyopencl.Buffer(self.ctx, mf.READ_WRITE, size=int(chunk_size * num_channels * 4)) @@ -935,7 +935,7 @@ def create_buffers_and_compile(self, chunk_size): self.kern_detect_peaks = getattr(self.opencl_prg, "detect_peaks") self.kern_detect_peaks.set_args( - self.traces_cl, self.neighbours_mask_cl, self.abs_threholds_cl, self.peaks_cl, self.num_peaks_cl + self.traces_cl, self.neighbours_mask_cl, self.abs_thresholds_cl, self.peaks_cl, self.num_peaks_cl ) s = self.chunk_size - 2 * self.exclude_sweep_size @@ -989,7 +989,7 @@ def detect_peak(self, traces): //in __global float *traces, __global uchar *neighbours_mask, - __global float *abs_threholds, + __global float *abs_thresholds, //out __global st_peak *peaks, volatile __global int *num_peaks @@ -1023,11 +1023,11 @@ def detect_peak(self, traces): v = traces[index]; if(peak_sign==1){ - if (v>abs_threholds[chan]){peak=1;} + if (v>abs_thresholds[chan]){peak=1;} else {peak=0;} } else if(peak_sign==-1){ - if (v<-abs_threholds[chan]){peak=1;} + if (v<-abs_thresholds[chan]){peak=1;} else {peak=0;} } From c77c08a604aa1bc0b901d6d6ab346cd4734744b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:36:02 +0000 Subject: [PATCH 161/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/peak_detection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 29df40e8c0..ac176f0ca0 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -501,7 +501,9 @@ def get_method_margin(cls, *args): @classmethod def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, device, return_tensor): - sample_inds, chan_inds = _torch_detect_peaks(traces, peak_sign, abs_thresholds, exclude_sweep_size, None, device) + sample_inds, chan_inds = _torch_detect_peaks( + traces, peak_sign, abs_thresholds, exclude_sweep_size, None, device + ) if not return_tensor: sample_inds = np.array(sample_inds.cpu()) chan_inds = np.array(chan_inds.cpu()) From bea6c1b83862f18a60371db43056ec0a36094395 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 28 Feb 2024 16:14:50 +0100 Subject: [PATCH 162/192] Conflict --- .../sortingcomponents/waveforms/temporal_pca.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 8dcd573cfb..74e4ce5466 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -229,13 +229,6 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) temporal_waveforms = to_temporal_representation(waveforms) projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) - else: - projected_waveforms = np.zeros((0, self.n_components, num_channels), dtype=self.dtype) - - if len(waveforms) > 0: - temporal_waveforms = to_temporal_representation(waveforms) - projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) - projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) else: projected_waveforms = np.zeros((0, self.pca_model.n_components, num_channels), dtype=np.float32) return projected_waveforms.astype(self.dtype, copy=False) From 6142697b9bba6498874b5ac07c71c6997e0d2e3c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 28 Feb 2024 16:15:24 +0100 Subject: [PATCH 163/192] Conflict --- src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 74e4ce5466..ed7f6936de 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -230,7 +230,7 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) else: - projected_waveforms = np.zeros((0, self.pca_model.n_components, num_channels), dtype=np.float32) + projected_waveforms = np.zeros((0, self.n_components, num_channels), dtype=np.float32) return projected_waveforms.astype(self.dtype, copy=False) From f82d76001a393c7e80cf666cc1021013516bd1e3 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 28 Feb 2024 16:15:52 +0100 Subject: [PATCH 164/192] Conflict --- src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index ed7f6936de..3a16ef1843 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -230,7 +230,7 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) projected_temporal_waveforms = self.pca_model.transform(temporal_waveforms) projected_waveforms = from_temporal_representation(projected_temporal_waveforms, num_channels) else: - projected_waveforms = np.zeros((0, self.n_components, num_channels), dtype=np.float32) + projected_waveforms = np.zeros((0, self.n_components, num_channels), dtype=self.dtype) return projected_waveforms.astype(self.dtype, copy=False) From fcfe8f8d2745de8b79d7a5cf7604d45b736cff25 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 29 Feb 2024 09:19:30 +0100 Subject: [PATCH 165/192] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ed06b679b4..a7e1e28c79 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -28,7 +28,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 50}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "sparsity": {"method": "ptp", "threshold": 0.25}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index f4059f7d35..054c357d0b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -55,7 +55,7 @@ class CircusClustering: "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "radius_um": 50, + "radius_um": 100, "n_svd": [5, 10], "ms_before": 0.5, "ms_after": 0.5, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 9e17df3e49..2e2c080994 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -51,7 +51,7 @@ class RandomProjectionClustering: "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "radius_um": 50, + "radius_um": 100, "nb_projections": 10, "ms_before": 0.5, "ms_after": 0.5, From 9c3d5ca8724dd93bf1732bd805298fddb3909dd3 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 29 Feb 2024 08:47:39 +0000 Subject: [PATCH 166/192] Fix sparsity check for old waveform extractor --- .../core/waveforms_extractor_backwards_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index f10454e085..aff620c4c5 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -362,7 +362,7 @@ def _read_old_waveforms_extractor_binary(folder): params = json.load(f) sparsity_file = folder / "sparsity.json" - if params_file.exists(): + if sparsity_file.exists(): with open(sparsity_file, "r") as f: sparsity_dict = json.load(f) sparsity = ChannelSparsity.from_dict(sparsity_dict) From 4477f7496e5a758b80fdb4fb0c99f0f2a90533c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Feb 2024 08:52:05 +0000 Subject: [PATCH 167/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/modules_gallery/core/plot_4_sorting_analyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modules_gallery/core/plot_4_sorting_analyzer.py b/examples/modules_gallery/core/plot_4_sorting_analyzer.py index 864f11ad1d..20dc078197 100644 --- a/examples/modules_gallery/core/plot_4_sorting_analyzer.py +++ b/examples/modules_gallery/core/plot_4_sorting_analyzer.py @@ -18,7 +18,7 @@ * "noise_levels" : compute noise level from traces (usefull to get snr of units) * can be in memory or persistent to disk (2 formats binary/npy or zarr) -More extesions are available in `spikeinterface.postprocessing` like "principal_components", "spike_amplitudes", +More extesions are available in `spikeinterface.postprocessing` like "principal_components", "spike_amplitudes", "unit_lcations", ... From c473277c32f56f0f3b54a0a1aa02b12f876ff1aa Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 29 Feb 2024 10:28:17 +0100 Subject: [PATCH 168/192] Playing with simple sorter --- src/spikeinterface/sorters/internal/simplesorter.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 8de4b0e6ae..a18047f787 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -50,9 +50,9 @@ def get_sorter_version(cls): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): - job_kwargs = params["job_kwargs"].copy() + job_kwargs = params["job_kwargs"] job_kwargs = fix_job_kwargs(job_kwargs) - job_kwargs["progress_bar"] = verbose + job_kwargs.update({"verbose": verbose, "progress_bar": verbose}) from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing @@ -144,6 +144,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): gather_mode="npy", gather_kwargs=dict(exist_ok=True), folder=features_folder, + job_name="extracting features", names=["features_tsvd"], ) @@ -157,25 +158,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if clust_method == "hdbscan": import hdbscan - out = hdbscan.hdbscan(features_flat, **clust_params) peak_labels = out[0] elif clust_method in ("kmeans"): from sklearn.cluster import MiniBatchKMeans - peak_labels = MiniBatchKMeans(**clust_params).fit_predict(features_flat) elif clust_method in ("mean_shift"): from sklearn.cluster import MeanShift - peak_labels = MeanShift().fit_predict(features_flat) elif clust_method in ("affinity_propagation"): from sklearn.cluster import AffinityPropagation - peak_labels = AffinityPropagation().fit_predict(features_flat) elif clust_method in ("gaussian_mixture"): from sklearn.mixture import GaussianMixture - - # n_components = clust_params["n_clusters"] peak_labels = GaussianMixture(**clust_params).fit_predict(features_flat) else: raise ValueError(f"simple_sorter : unkown clustering method {clust_method}") From 88676e55fce30f71f482906a2cb70182f428a46a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Feb 2024 09:29:06 +0000 Subject: [PATCH 169/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/simplesorter.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index a18047f787..e6afb745fe 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -158,19 +158,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if clust_method == "hdbscan": import hdbscan + out = hdbscan.hdbscan(features_flat, **clust_params) peak_labels = out[0] elif clust_method in ("kmeans"): from sklearn.cluster import MiniBatchKMeans + peak_labels = MiniBatchKMeans(**clust_params).fit_predict(features_flat) elif clust_method in ("mean_shift"): from sklearn.cluster import MeanShift + peak_labels = MeanShift().fit_predict(features_flat) elif clust_method in ("affinity_propagation"): from sklearn.cluster import AffinityPropagation + peak_labels = AffinityPropagation().fit_predict(features_flat) elif clust_method in ("gaussian_mixture"): from sklearn.mixture import GaussianMixture + peak_labels = GaussianMixture(**clust_params).fit_predict(features_flat) else: raise ValueError(f"simple_sorter : unkown clustering method {clust_method}") From 7527d3e842e6d8e8b9181bf48a8946a15c39e720 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 29 Feb 2024 17:14:27 +0100 Subject: [PATCH 170/192] WIP --- src/spikeinterface/sortingcomponents/matching/circus.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 95ff4b2d40..3d6d2fb2db 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -233,7 +233,10 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - margin = 2 * kwargs["vicinity"] + if kwargs['vicinity'] > 0: + margin = kwargs["vicinity"] + else: + margin = 2 * kwargs["num_samples"] return margin @classmethod From 8755144bd16a72cdc8618f30baa56ddead486dae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:17:53 +0000 Subject: [PATCH 171/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 3d6d2fb2db..ecf63f973e 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -233,7 +233,7 @@ def unserialize_in_worker(cls, kwargs): @classmethod def get_margin(cls, recording, kwargs): - if kwargs['vicinity'] > 0: + if kwargs["vicinity"] > 0: margin = kwargs["vicinity"] else: margin = 2 * kwargs["num_samples"] From 50cca530a1bd29c51e42acfa9dab79a7f2d550e5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 1 Mar 2024 08:35:55 +0100 Subject: [PATCH 172/192] WIP --- .../sorters/internal/spyking_circus2.py | 4 +- .../benchmark/benchmark_tools.py | 2 +- .../clustering/random_projections.py | 9 +- .../sortingcomponents/features_from_peaks.py | 187 ++---------------- 4 files changed, 29 insertions(+), 173 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a7e1e28c79..e087ab5b20 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -30,7 +30,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "sparsity": {"method": "ptp", "threshold": 0.25}, - "filtering": {"freq_min": 150, "dtype": "float32"}, + "filtering": {"freq_min": 150}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { "method": "smart_sampling_amplitudes", @@ -102,7 +102,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - recording_f = highpass_filter(recording, **filtering_params) + recording_f = highpass_filter(recording, **filtering_params, dtype="float32") if num_channels > 1: recording_f = common_reference(recording_f) else: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 5f23fab255..969b72d72a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -315,7 +315,7 @@ def _save_keys(self, saved_keys, folder): with open(folder / f"{k}.pickle", mode="wb") as f: pickle.dump(self.result[k], f) elif format == "sorting": - self.result[k].save(folder=folder / k, format="numpy_folder") + self.result[k].save(folder=folder / k, format="numpy_folder", overwrite=True) elif format == "zarr_templates": self.result[k].to_zarr(folder / k) elif format == "sorting_analyzer": diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 2e2c080994..8ed881fc6d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -24,7 +24,7 @@ from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser -from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature +from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature, RandomProjectionsEnergyFeature from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates @@ -111,12 +111,17 @@ def main_function(cls, recording, peaks, params): nafter = int(params["ms_after"] * fs / 1000) nsamples = nbefore + nafter - node3 = RandomProjectionsFeature( + noise_ptps = np.linalg.norm(np.random.randn(1000, nsamples), axis=1) + noise_threshold = np.mean(noise_ptps) + 3*np.std(noise_ptps) + print(noise_threshold) + + node3 = RandomProjectionsEnergyFeature( recording, parents=[node0, node2], return_output=True, projections=projections, radius_um=params["radius_um"], + noise_threshold=noise_threshold, sparse=True, ) diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 4479319fe3..1b70546964 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -144,50 +144,6 @@ def compute(self, traces, peaks, waveforms): return all_ptps -class PeakToPeakLagsFeature(PipelineNode): - def __init__( - self, - recording, - name="ptp_lag_feature", - return_output=True, - parents=None, - radius_um=150.0, - all_channels=True, - ): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - self.all_channels = all_channels - self.radius_um = radius_um - - self.contact_locations = recording.get_channel_locations() - self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance <= radius_um - - self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels)) - self._dtype = recording.get_dtype() - - def get_dtype(self): - return self._dtype - - def compute(self, traces, peaks, waveforms): - if self.all_channels: - all_maxs = np.argmax(waveforms, axis=1) - all_mins = np.argmin(waveforms, axis=1) - all_lags = all_maxs - all_mins - else: - all_lags = np.zeros(peaks.size) - for main_chan in np.unique(peaks["channel_index"]): - (idx,) = np.nonzero(peaks["channel_index"] == main_chan) - (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) - wfs = waveforms[idx][:, :, chan_inds] - maxs = np.argmax(wfs, axis=1) - mins = np.argmin(wfs, axis=1) - lags = maxs - mins - ptps = np.argmax(np.ptp(wfs, axis=1), axis=1) - all_lags[idx] = lags[np.arange(len(idx)), ptps] - return all_lags - - class RandomProjectionsFeature(PipelineNode): def __init__( self, @@ -196,30 +152,25 @@ def __init__( return_output=True, parents=None, projections=None, - sigmoid=None, radius_um=None, sparse=True, + noise_threshold=None ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.projections = projections - self.sigmoid = sigmoid self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= radius_um self.radius_um = radius_um self.sparse = sparse - self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse)) + self.noise_threshold = noise_threshold + self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold)) self._dtype = recording.get_dtype() def get_dtype(self): return self._dtype - def _sigmoid(self, x): - L, x0, k, b = self.sigmoid - y = L / (1 + np.exp(-k * (x - x0))) + b - return y - def compute(self, traces, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) @@ -232,8 +183,9 @@ def compute(self, traces, peaks, waveforms): else: wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) - if self.sigmoid is not None: - wf_ptp *= self._sigmoid(wf_ptp) + if self.noise_threshold is not None: + local_map = np.median(wf_ptp, axis=0) < self.noise_threshold + wf_ptp[wf_ptp < local_map] = 0 denom = np.sum(wf_ptp, axis=1) mask = denom != 0 @@ -252,17 +204,20 @@ def __init__( projections=None, radius_um=150.0, min_values=None, + sparse=True, + noise_threshold=None ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= radius_um - + self.sparse = sparse + self.noise_threshold = noise_threshold self.projections = projections self.min_values = min_values self.radius_um = radius_um - self._kwargs.update(dict(projections=projections, min_values=min_values, radius_um=radius_um)) + self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold)) self._dtype = recording.get_dtype() def get_dtype(self): @@ -274,10 +229,14 @@ def compute(self, traces, peaks, waveforms): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] - energies = np.linalg.norm(waveforms[idx][:, :, chan_inds], axis=1) - - if self.min_values is not None: - energies = (energies / self.min_values[chan_inds]) ** 4 + if self.sparse: + energies = np.linalg.norm(waveforms[idx][:, :, :len(chan_inds)], axis=1) + else: + energies = np.linalg.norm(waveforms[idx][:, :, chan_inds], axis=1) + + if self.noise_threshold is not None: + local_map = np.median(energies, axis=0) < self.noise_threshold + energies[energies < local_map] = 0 denom = np.sum(energies, axis=1) mask = denom != 0 @@ -286,117 +245,9 @@ def compute(self, traces, peaks, waveforms): return all_projections -class StdPeakToPeakFeature(PipelineNode): - def __init__(self, recording, name="std_ptp_feature", return_output=True, parents=None, radius_um=150.0): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - self.contact_locations = recording.get_channel_locations() - self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance <= radius_um - - self._kwargs.update(dict(radius_um=radius_um)) - - self._dtype = recording.get_dtype() - - def get_dtype(self): - return self._dtype - - def compute(self, traces, peaks, waveforms): - all_ptps = np.zeros(peaks.size) - for main_chan in np.unique(peaks["channel_index"]): - (idx,) = np.nonzero(peaks["channel_index"] == main_chan) - (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) - wfs = waveforms[idx][:, :, chan_inds] - all_ptps[idx] = np.std(np.ptp(wfs, axis=1), axis=1) - return all_ptps - - -class GlobalPeakToPeakFeature(PipelineNode): - def __init__(self, recording, name="global_ptp_feature", return_output=True, parents=None, radius_um=150.0): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - self.contact_locations = recording.get_channel_locations() - self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance <= radius_um - - self._kwargs.update(dict(radius_um=radius_um)) - - self._dtype = recording.get_dtype() - - def get_dtype(self): - return self._dtype - - def compute(self, traces, peaks, waveforms): - all_ptps = np.zeros(peaks.size) - for main_chan in np.unique(peaks["channel_index"]): - (idx,) = np.nonzero(peaks["channel_index"] == main_chan) - (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) - wfs = waveforms[idx][:, :, chan_inds] - all_ptps[idx] = np.max(wfs, axis=(1, 2)) - np.min(wfs, axis=(1, 2)) - return all_ptps - - -class KurtosisPeakToPeakFeature(PipelineNode): - def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, parents=None, radius_um=150.0): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - self.contact_locations = recording.get_channel_locations() - self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance <= radius_um - - self._kwargs.update(dict(radius_um=radius_um)) - - self._dtype = recording.get_dtype() - - def get_dtype(self): - return self._dtype - - def compute(self, traces, peaks, waveforms): - all_ptps = np.zeros(peaks.size) - import scipy - - for main_chan in np.unique(peaks["channel_index"]): - (idx,) = np.nonzero(peaks["channel_index"] == main_chan) - (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) - wfs = waveforms[idx][:, :, chan_inds] - all_ptps[idx] = scipy.stats.kurtosis(np.ptp(wfs, axis=1), axis=1) - return all_ptps - - -class EnergyFeature(PipelineNode): - def __init__(self, recording, name="energy_feature", return_output=True, parents=None, radius_um=50.0): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - self.contact_locations = recording.get_channel_locations() - self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance <= radius_um - - self._kwargs.update(dict(radius_um=radius_um)) - - def get_dtype(self): - return np.dtype("float32") - - def compute(self, traces, peaks, waveforms): - energy = np.zeros(peaks.size, dtype="float32") - for main_chan in np.unique(peaks["channel_index"]): - (idx,) = np.nonzero(peaks["channel_index"] == main_chan) - (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) - - wfs = waveforms[idx][:, :, chan_inds] - energy[idx] = np.linalg.norm(wfs, axis=(1, 2)) / chan_inds.size - return energy - - _features_class = { "amplitude": AmplitudeFeature, "ptp": PeakToPeakFeature, - "center_of_mass": LocalizeCenterOfMass, - "monopolar_triangulation": LocalizeMonopolarTriangulation, - "energy": EnergyFeature, - "std_ptp": StdPeakToPeakFeature, - "kurtosis_ptp": KurtosisPeakToPeakFeature, "random_projections_ptp": RandomProjectionsFeature, "random_projections_energy": RandomProjectionsEnergyFeature, - "ptp_lag": PeakToPeakLagsFeature, - "global_ptp": GlobalPeakToPeakFeature, } From 00340edf54e47e658ab46bd9b3e2e004f9c9653c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Mar 2024 07:36:18 +0000 Subject: [PATCH 173/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../clustering/random_projections.py | 7 +++++-- .../sortingcomponents/features_from_peaks.py | 18 +++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 8ed881fc6d..52197e04ab 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -24,7 +24,10 @@ from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser -from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature, RandomProjectionsEnergyFeature +from spikeinterface.sortingcomponents.features_from_peaks import ( + RandomProjectionsFeature, + RandomProjectionsEnergyFeature, +) from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates @@ -112,7 +115,7 @@ def main_function(cls, recording, peaks, params): nsamples = nbefore + nafter noise_ptps = np.linalg.norm(np.random.randn(1000, nsamples), axis=1) - noise_threshold = np.mean(noise_ptps) + 3*np.std(noise_ptps) + noise_threshold = np.mean(noise_ptps) + 3 * np.std(noise_ptps) print(noise_threshold) node3 = RandomProjectionsEnergyFeature( diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 1b70546964..5cdb62020b 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -154,7 +154,7 @@ def __init__( projections=None, radius_um=None, sparse=True, - noise_threshold=None + noise_threshold=None, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) @@ -165,7 +165,9 @@ def __init__( self.radius_um = radius_um self.sparse = sparse self.noise_threshold = noise_threshold - self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold)) + self._kwargs.update( + dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold) + ) self._dtype = recording.get_dtype() def get_dtype(self): @@ -204,8 +206,8 @@ def __init__( projections=None, radius_um=150.0, min_values=None, - sparse=True, - noise_threshold=None + sparse=True, + noise_threshold=None, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) @@ -217,7 +219,9 @@ def __init__( self.projections = projections self.min_values = min_values self.radius_um = radius_um - self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold)) + self._kwargs.update( + dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold) + ) self._dtype = recording.get_dtype() def get_dtype(self): @@ -230,10 +234,10 @@ def compute(self, traces, peaks, waveforms): (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] if self.sparse: - energies = np.linalg.norm(waveforms[idx][:, :, :len(chan_inds)], axis=1) + energies = np.linalg.norm(waveforms[idx][:, :, : len(chan_inds)], axis=1) else: energies = np.linalg.norm(waveforms[idx][:, :, chan_inds], axis=1) - + if self.noise_threshold is not None: local_map = np.median(energies, axis=0) < self.noise_threshold energies[energies < local_map] = 0 From 8a2a23ca9b64b958bb024742a9ccabe88eccc6bb Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 1 Mar 2024 08:50:40 +0100 Subject: [PATCH 174/192] Cleaning --- .../clustering/random_projections.py | 7 +- .../sortingcomponents/features_from_peaks.py | 80 +++++-------------- 2 files changed, 22 insertions(+), 65 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 8ed881fc6d..4206f4b9fd 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -24,7 +24,7 @@ from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser -from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature, RandomProjectionsEnergyFeature +from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates @@ -113,15 +113,14 @@ def main_function(cls, recording, peaks, params): noise_ptps = np.linalg.norm(np.random.randn(1000, nsamples), axis=1) noise_threshold = np.mean(noise_ptps) + 3*np.std(noise_ptps) - print(noise_threshold) - node3 = RandomProjectionsEnergyFeature( + node3 = RandomProjectionsFeature( recording, parents=[node0, node2], return_output=True, projections=projections, radius_um=params["radius_um"], - noise_threshold=noise_threshold, + noise_threshold=None, sparse=True, ) diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 1b70546964..cc3d305c90 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -149,23 +149,27 @@ def __init__( self, recording, name="random_projections_feature", + feature="ptp", return_output=True, parents=None, projections=None, - radius_um=None, + radius_um=100, sparse=True, noise_threshold=None ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) + assert feature in ['ptp', 'energy'] self.projections = projections + self.feature = feature self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= radius_um self.radius_um = radius_um self.sparse = sparse self.noise_threshold = noise_threshold - self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold)) + self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, + noise_threshold=noise_threshold, feature=feature)) self._dtype = recording.get_dtype() def get_dtype(self): @@ -179,75 +183,29 @@ def compute(self, traces, peaks, waveforms): (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] if self.sparse: - wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1) + if self.feature == 'ptp': + features = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1) + elif self.feature == 'energy': + features = np.linalg.norm(waveforms[idx][:, :, : len(chan_inds)], axis=1) else: - wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) + if self.feature == 'ptp': + features = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) + elif self.feature == 'energy': + features = np.linalg.norm(waveforms[idx][:, :, chan_inds], axis=1) if self.noise_threshold is not None: - local_map = np.median(wf_ptp, axis=0) < self.noise_threshold - wf_ptp[wf_ptp < local_map] = 0 + local_map = np.median(features, axis=0) < self.noise_threshold + features[features < local_map] = 0 - denom = np.sum(wf_ptp, axis=1) + denom = np.sum(features, axis=1) mask = denom != 0 - all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis]) + all_projections[idx[mask]] = np.dot(features[mask], local_projections) / (denom[mask][:, np.newaxis]) return all_projections -class RandomProjectionsEnergyFeature(PipelineNode): - def __init__( - self, - recording, - name="random_projections_energy_feature", - return_output=True, - parents=None, - projections=None, - radius_um=150.0, - min_values=None, - sparse=True, - noise_threshold=None - ): - PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - - self.contact_locations = recording.get_channel_locations() - self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance <= radius_um - self.sparse = sparse - self.noise_threshold = noise_threshold - self.projections = projections - self.min_values = min_values - self.radius_um = radius_um - self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, noise_threshold=noise_threshold)) - self._dtype = recording.get_dtype() - - def get_dtype(self): - return self._dtype - - def compute(self, traces, peaks, waveforms): - all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) - for main_chan in np.unique(peaks["channel_index"]): - (idx,) = np.nonzero(peaks["channel_index"] == main_chan) - (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) - local_projections = self.projections[chan_inds, :] - if self.sparse: - energies = np.linalg.norm(waveforms[idx][:, :, :len(chan_inds)], axis=1) - else: - energies = np.linalg.norm(waveforms[idx][:, :, chan_inds], axis=1) - - if self.noise_threshold is not None: - local_map = np.median(energies, axis=0) < self.noise_threshold - energies[energies < local_map] = 0 - - denom = np.sum(energies, axis=1) - mask = denom != 0 - - all_projections[idx[mask]] = np.dot(energies[mask], local_projections) / (denom[mask][:, np.newaxis]) - return all_projections - - _features_class = { "amplitude": AmplitudeFeature, "ptp": PeakToPeakFeature, - "random_projections_ptp": RandomProjectionsFeature, - "random_projections_energy": RandomProjectionsEnergyFeature, + "random_projections": RandomProjectionsFeature, } From e156f0a92f6e5e0418c1f33fb3a02d29c62349a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Mar 2024 07:52:20 +0000 Subject: [PATCH 175/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../clustering/random_projections.py | 4 ++-- .../sortingcomponents/features_from_peaks.py | 21 ++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index bf18cc05e6..31d25222a9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -111,8 +111,8 @@ def main_function(cls, recording, peaks, params): nafter = int(params["ms_after"] * fs / 1000) nsamples = nbefore + nafter - #noise_ptps = np.linalg.norm(np.random.randn(1000, nsamples), axis=1) - #noise_threshold = np.mean(noise_ptps) + 3 * np.std(noise_ptps) + # noise_ptps = np.linalg.norm(np.random.randn(1000, nsamples), axis=1) + # noise_threshold = np.mean(noise_ptps) + 3 * np.std(noise_ptps) node3 = RandomProjectionsFeature( recording, diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 4bef8aa3ba..8770dec6f9 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -159,7 +159,7 @@ def __init__( ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - assert feature in ['ptp', 'energy'] + assert feature in ["ptp", "energy"] self.projections = projections self.feature = feature self.contact_locations = recording.get_channel_locations() @@ -168,8 +168,15 @@ def __init__( self.radius_um = radius_um self.sparse = sparse self.noise_threshold = noise_threshold - self._kwargs.update(dict(projections=projections, radius_um=radius_um, sparse=sparse, - noise_threshold=noise_threshold, feature=feature)) + self._kwargs.update( + dict( + projections=projections, + radius_um=radius_um, + sparse=sparse, + noise_threshold=noise_threshold, + feature=feature, + ) + ) self._dtype = recording.get_dtype() def get_dtype(self): @@ -183,14 +190,14 @@ def compute(self, traces, peaks, waveforms): (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] if self.sparse: - if self.feature == 'ptp': + if self.feature == "ptp": features = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1) - elif self.feature == 'energy': + elif self.feature == "energy": features = np.linalg.norm(waveforms[idx][:, :, : len(chan_inds)], axis=1) else: - if self.feature == 'ptp': + if self.feature == "ptp": features = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) - elif self.feature == 'energy': + elif self.feature == "energy": features = np.linalg.norm(waveforms[idx][:, :, chan_inds], axis=1) if self.noise_threshold is not None: From 1a5eac346e544aa42213bacaab2c99f9cfe00e03 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 1 Mar 2024 09:28:19 +0100 Subject: [PATCH 176/192] Add the cache for simple sorter --- .../sorters/internal/simplesorter.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index e6afb745fe..3f2ada3085 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -2,6 +2,7 @@ from spikeinterface.core import load_extractor, BaseRecording, get_noise_levels, extract_waveforms, NumpySorting from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.sortingcomponents.tools import cache_preprocessing from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore import numpy as np @@ -41,6 +42,7 @@ class SimpleSorter(ComponentsBasedSorter): "core_dist_n_jobs": -1, "cluster_selection_method": "leaf", }, + "cache_preprocessing": {"mode": None, "memory_limit": 0.5, "delete_cache": True}, "job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"}, } @@ -81,6 +83,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = recording_raw noise_levels = get_noise_levels(recording, return_scaled=False) + recording = cache_preprocessing(recording, **job_kwargs, **params["cache_preprocessing"]) + # detection detection_params = params["detection"].copy() detection_params["noise_levels"] = noise_levels @@ -182,6 +186,25 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(features_folder / "peak_labels.npy", peak_labels) + folder_to_delete = None + + if "mode" in params["cache_preprocessing"]: + cache_mode = params["cache_preprocessing"]["mode"] + else: + cache_mode = "memory" + + if "delete_cache" in params["cache_preprocessing"]: + delete_cache = params["cache_preprocessing"] + else: + delete_cache = True + + if cache_mode in ["folder", "zarr"] and delete_cache: + folder_to_delete = recording._kwargs["folder_path"] + + del recording + if folder_to_delete is not None: + shutil.rmtree(folder_to_delete) + # keep positive labels keep = peak_labels >= 0 sorting_final = NumpySorting.from_times_labels( From 9c3108c369f15ef040ff67ed8f4bc33616ca6837 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 1 Mar 2024 14:17:58 +0100 Subject: [PATCH 177/192] Fixes in benchmarks --- .../benchmark/benchmark_clustering.py | 21 +- .../benchmark/benchmark_matching.py | 8 +- .../benchmark/benchmark_peak_detection.py | 435 ++++++++++++++++++ .../benchmark/benchmark_peak_selection.py | 19 +- 4 files changed, 453 insertions(+), 30 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index b8afc813ab..cc30862180 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -36,8 +36,7 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.indices = indices sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) - sorting_analyzer.compute("random_spikes") - ext = sorting_analyzer.compute("fast_templates") + sorting_analyzer.compute(["random_spikes", "fast_templates"]) extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) @@ -140,10 +139,10 @@ def plot_agreements(self, case_keys=None, figsize=(15, 15)): if case_keys is None: case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): - ax = axs[count] + ax = axs[0, count] ax.set_title(self.cases[key]["label"]) plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) @@ -174,7 +173,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): @@ -201,14 +200,14 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): axs[count].set_title(metric) fig.colorbar(im, ax=axs[count]) label = self.cases[key]["label"] - axs[count].set_title(label) + axs[0, count].set_title(label) def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): @@ -238,11 +237,11 @@ def plot_metrics_vs_snr(self, metric="cosine", case_keys=None, figsize=(15, 5)): to_plot = [] for found, real in zip(inds_2, inds_1): to_plot += [distances[real, found]] - axs[count].plot(snr, to_plot, ".") - axs[count].set_xlabel("snr") - axs[count].set_ylabel(metric) + axs[0, count].plot(snr, to_plot, ".") + axs[0, count].set_xlabel("snr") + axs[0, count].set_ylabel(metric) label = self.cases[key]["label"] - axs[count].set_title(label) + axs[0, count].set_title(label) def plot_comparison_clustering( self, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index ffecbe028f..bb6d0f7683 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -70,10 +70,10 @@ def plot_agreements(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): - ax = axs[count] + ax = axs[0, count] ax.set_title(self.cases[key]["label"]) plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) @@ -103,14 +103,14 @@ def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize) + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): templates_array = self.get_result(key)["templates"].templates_array plot_comparison_collision_by_similarity( self.get_result(key)["gt_collision"], templates_array, - ax=axs[count], + ax=axs[0, count], show_legend=True, mode="lines", good_only=False, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py new file mode 100644 index 0000000000..421255bee2 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -0,0 +1,435 @@ +from __future__ import annotations + +from spikeinterface.preprocessing import bandpass_filter, common_reference +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.core import NumpySorting +from spikeinterface.qualitymetrics import compute_quality_metrics +from spikeinterface.comparison import GroundTruthComparison +from spikeinterface.widgets import ( + plot_probe_map, + plot_agreement_matrix, + plot_comparison_collision_by_similarity, + plot_unit_templates, + plot_unit_waveforms, +) +from spikeinterface.comparison.comparisontools import make_matching_events +from spikeinterface.core import get_noise_levels + +import time +import string, random +import pylab as plt +import os +import numpy as np + +from .benchmark_tools import BenchmarkStudy, Benchmark +from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from spikeinterface.core.template_tools import get_template_extremum_channel + + +class PeakDetectionBenchmark(Benchmark): + + def __init__(self, recording, gt_sorting, params, exhaustive_gt=True): + self.recording = recording + self.gt_sorting = gt_sorting + + sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) + sorting_analyzer.compute(["random_spikes", "fast_templates"]) + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + self.gt_peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + self.params = params + self.exhaustive_gt = exhaustive_gt + self.method = params["method"] + self.method_kwargs = params["method_kwargs"] + self.result = {"gt_peaks" : self.gt_peaks} + + def run(self, **job_kwargs): + peaks = detect_peaks( + self.recording, method=self.method, **self.method_kwargs, **job_kwargs + ) + self.result["peaks"] = peaks + + def compute_result(self, **result_params): + spikes = self.result["peaks"] + self.result["peak_on_channels"] = NumpySorting.from_peaks(spikes, + self.recording.sampling_frequency, + unit_ids=self.recording.channel_ids) + spikes = self.result["gt_peaks"] + self.result["gt_on_channels"] = NumpySorting.from_peaks(spikes, + self.recording.sampling_frequency, + unit_ids=self.recording.channel_ids) + + self.result["gt_comparison"] = GroundTruthComparison( + self.result["gt_on_channels"], self.result["peak_on_channels"], exhaustive_gt=self.exhaustive_gt + ) + + gt_peaks = self.gt_sorting.to_spike_vector() + times1 = self.result["gt_peaks"]["sample_index"] + times2 = self.result["peaks"]["sample_index"] + + print("The gt recording has {} peaks and {} have been detected".format(len(times1), len(times2))) + + matches = make_matching_events(times1, times2, int(0.4 * self.recording.sampling_frequency / 1000)) + self.matches = matches + self.gt_matches = matches["index1"] + + self.deltas = {"labels": [], "channels" : [], "delta": matches["delta_frame"]} + self.deltas["labels"] = gt_peaks["unit_index"][self.gt_matches] + self.deltas["channels"] = self.result["gt_peaks"]["unit_index"][self.gt_matches] + + self.result["sliced_gt_sorting"] = NumpySorting(gt_peaks[self.gt_matches], + self.recording.sampling_frequency, + self.gt_sorting.unit_ids) + + ratio = 100 * len(self.gt_matches) / len(times1) + print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) + + # matches = make_matching_events(times2, times1, int(delta * self.sampling_rate / 1000)) + # self.good_matches = matches["index1"] + + # garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) + # garbage_channels = self.peaks["channel_index"][garbage_matches] + # garbage_peaks = times2[garbage_matches] + # nb_garbage = len(garbage_peaks) + + # ratio = 100 * len(garbage_peaks) / len(times2) + # self.garbage_sorting = NumpySorting.from_times_labels(garbage_peaks, garbage_channels, self.sampling_rate) + + # print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) + + _run_key_saved = [("peaks", "npy"), + ("gt_peaks", "npy")] + + _result_key_saved = [ + ("gt_comparison", "pickle"), + ("sliced_gt_sorting", "sorting"), + ("peak_on_channels", "sorting"), + ("gt_on_channels", "sorting"), + ] + + +class PeakDetectionStudy(BenchmarkStudy): + + benchmark_class = PeakDetectionBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + init_kwargs = self.cases[key]["init_kwargs"] + benchmark = PeakDetectionBenchmark(recording, gt_sorting, params, **init_kwargs) + return benchmark + + def plot_agreements(self, case_keys=None, figsize=(15, 15)): + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + + for count, key in enumerate(case_keys): + ax = axs[0, count] + ax.set_title(self.cases[key]["label"]) + plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + + def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + + for count, k in enumerate(("accuracy", "recall", "precision")): + + ax = axs[count] + for key in case_keys: + label = self.cases[key]["label"] + + analyzer = self.get_sorting_analyzer(key) + metrics = analyzer.get_extension("quality_metrics").get_data() + x = metrics["snr"].values + y = self.get_result(key)["gt_comparison"].get_performance()[k].values + ax.scatter(x, y, marker=".", label=label) + ax.set_title(k) + + if count == 2: + ax.legend() + + +# def run(self, peaks=None, positions=None, delta=0.2): +# t_start = time.time() + +# if peaks is not None: +# self._peaks = peaks + +# nb_peaks = len(self.peaks) + +# if positions is not None: +# self._positions = positions + +# spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0]["sample_index"] +# times2 = self.peaks["sample_index"] + +# print("The gt recording has {} peaks and {} have been detected".format(len(times1[0]), len(times2))) + +# matches = make_matching_events(spikes1["sample_index"], times2, int(delta * self.sampling_rate / 1000)) +# self.matches = matches + +# self.deltas = {"labels": [], "delta": matches["delta_frame"]} +# self.deltas["labels"] = spikes1["unit_index"][matches["index1"]] + +# gt_matches = matches["index1"] +# self.sliced_gt_sorting = NumpySorting(spikes1[gt_matches], self.sampling_rate, self.gt_sorting.unit_ids) + +# ratio = 100 * len(gt_matches) / len(spikes1) +# print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) + +# matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) +# self.good_matches = matches["index1"] + +# garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) +# garbage_channels = self.peaks["channel_index"][garbage_matches] +# garbage_peaks = times2[garbage_matches] +# nb_garbage = len(garbage_peaks) + +# ratio = 100 * len(garbage_peaks) / len(times2) +# self.garbage_sorting = NumpySorting.from_times_labels(garbage_peaks, garbage_channels, self.sampling_rate) + +# print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) + +# self.comp = GroundTruthComparison(self.gt_sorting, self.sliced_gt_sorting, exhaustive_gt=self.exhaustive_gt) + +# for label, sorting in zip( +# ["gt", "full_gt", "garbage"], [self.sliced_gt_sorting, self.gt_sorting, self.garbage_sorting] +# ): +# tmp_folder = os.path.join(self.tmp_folder, label) +# if os.path.exists(tmp_folder): +# import shutil + +# shutil.rmtree(tmp_folder) + +# if not (label == "full_gt" and label in self.waveforms): +# if self.verbose: +# print(f"Extracting waveforms for {label}") + +# self.waveforms[label] = extract_waveforms( +# self.recording, +# sorting, +# tmp_folder, +# load_if_exists=True, +# ms_before=2.5, +# ms_after=3.5, +# max_spikes_per_unit=500, +# return_scaled=False, +# **self.job_kwargs, +# ) + +# self.templates[label] = self.waveforms[label].get_all_templates(mode="median") + +# if self.gt_peaks is None: +# if self.verbose: +# print("Computing gt peaks") +# gt_peaks_ = self.gt_sorting.to_spike_vector() +# self.gt_peaks = np.zeros( +# gt_peaks_.size, +# dtype=[ +# ("sample_index", " Date: Fri, 1 Mar 2024 13:22:08 +0000 Subject: [PATCH 178/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_peak_detection.py | 30 ++++++++----------- .../benchmark/benchmark_peak_selection.py | 6 ++-- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py index 421255bee2..d926d80c48 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -41,23 +41,21 @@ def __init__(self, recording, gt_sorting, params, exhaustive_gt=True): self.exhaustive_gt = exhaustive_gt self.method = params["method"] self.method_kwargs = params["method_kwargs"] - self.result = {"gt_peaks" : self.gt_peaks} + self.result = {"gt_peaks": self.gt_peaks} def run(self, **job_kwargs): - peaks = detect_peaks( - self.recording, method=self.method, **self.method_kwargs, **job_kwargs - ) + peaks = detect_peaks(self.recording, method=self.method, **self.method_kwargs, **job_kwargs) self.result["peaks"] = peaks def compute_result(self, **result_params): spikes = self.result["peaks"] - self.result["peak_on_channels"] = NumpySorting.from_peaks(spikes, - self.recording.sampling_frequency, - unit_ids=self.recording.channel_ids) + self.result["peak_on_channels"] = NumpySorting.from_peaks( + spikes, self.recording.sampling_frequency, unit_ids=self.recording.channel_ids + ) spikes = self.result["gt_peaks"] - self.result["gt_on_channels"] = NumpySorting.from_peaks(spikes, - self.recording.sampling_frequency, - unit_ids=self.recording.channel_ids) + self.result["gt_on_channels"] = NumpySorting.from_peaks( + spikes, self.recording.sampling_frequency, unit_ids=self.recording.channel_ids + ) self.result["gt_comparison"] = GroundTruthComparison( self.result["gt_on_channels"], self.result["peak_on_channels"], exhaustive_gt=self.exhaustive_gt @@ -73,13 +71,13 @@ def compute_result(self, **result_params): self.matches = matches self.gt_matches = matches["index1"] - self.deltas = {"labels": [], "channels" : [], "delta": matches["delta_frame"]} + self.deltas = {"labels": [], "channels": [], "delta": matches["delta_frame"]} self.deltas["labels"] = gt_peaks["unit_index"][self.gt_matches] self.deltas["channels"] = self.result["gt_peaks"]["unit_index"][self.gt_matches] - self.result["sliced_gt_sorting"] = NumpySorting(gt_peaks[self.gt_matches], - self.recording.sampling_frequency, - self.gt_sorting.unit_ids) + self.result["sliced_gt_sorting"] = NumpySorting( + gt_peaks[self.gt_matches], self.recording.sampling_frequency, self.gt_sorting.unit_ids + ) ratio = 100 * len(self.gt_matches) / len(times1) print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) @@ -97,8 +95,7 @@ def compute_result(self, **result_params): # print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) - _run_key_saved = [("peaks", "npy"), - ("gt_peaks", "npy")] + _run_key_saved = [("peaks", "npy"), ("gt_peaks", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), @@ -256,7 +253,6 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): # self.garbage_peaks = self.peaks[garbage_matches] - # def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5): # fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(15, 10)) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1469a23e62..b8a3e34a2f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -41,7 +41,7 @@ def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True): self.exhaustive_gt = exhaustive_gt self.method = params["method"] self.method_kwargs = params["method_kwargs"] - self.result = {"gt_peaks" : self.gt_peaks} + self.result = {"gt_peaks": self.gt_peaks} def run(self, **job_kwargs): labels, peak_labels = find_cluster_from_peaks( @@ -82,7 +82,9 @@ def compute_result(self, **result_params): ext = sorting_analyzer.compute("fast_templates") self.result["clustering_templates"] = ext.get_data(outputs="Templates") - _run_key_saved = [("peak_labels", "npy"),] + _run_key_saved = [ + ("peak_labels", "npy"), + ] _result_key_saved = [ ("gt_comparison", "pickle"), From 019c79727b596c40c86b041f7f14af2264012aca Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 1 Mar 2024 14:39:08 +0100 Subject: [PATCH 179/192] Plots for detection --- .../benchmark/benchmark_peak_detection.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py index 421255bee2..016de4ccec 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -34,7 +34,7 @@ def __init__(self, recording, gt_sorting, params, exhaustive_gt=True): self.gt_sorting = gt_sorting sorting_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=False) - sorting_analyzer.compute(["random_spikes", "fast_templates"]) + sorting_analyzer.compute(["random_spikes", "fast_templates", "spike_amplitudes"]) extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") self.gt_peaks = self.gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) self.params = params @@ -42,6 +42,7 @@ def __init__(self, recording, gt_sorting, params, exhaustive_gt=True): self.method = params["method"] self.method_kwargs = params["method_kwargs"] self.result = {"gt_peaks" : self.gt_peaks} + self.result['gt_amplitudes'] = sorting_analyzer.get_extension('spike_amplitudes').get_data() def run(self, **job_kwargs): peaks = detect_peaks( @@ -98,7 +99,8 @@ def compute_result(self, **result_params): # print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) _run_key_saved = [("peaks", "npy"), - ("gt_peaks", "npy")] + ("gt_peaks", "npy"), + ("gt_amplitudes", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), @@ -153,6 +155,23 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): if count == 2: ax.legend() + def plot_detected_amplitudes(self, case_keys=None, figsize=(15,5)): + + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + + for count, key in enumerate(case_keys): + ax = axs[0, count] + data1 = self.get_result(key)['peaks']['amplitude'] + data2 = self.get_result(key)['gt_amplitudes'] + bins = np.linspace(data2.min(), data2.max(), 100) + ax.hist(data1, bins=bins, alpha=0.5, label='detected') + ax.hist(data2, bins=bins, alpha=0.5, label='gt') + ax.set_title(self.cases[key]["label"]) + ax.legend() + # def run(self, peaks=None, positions=None, delta=0.2): # t_start = time.time() From 3a18ec4a81b77109f190d6e6a910630d813d312c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:37:26 +0000 Subject: [PATCH 180/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_peak_detection.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py index 8496f46f92..8a1e370dbf 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -41,8 +41,8 @@ def __init__(self, recording, gt_sorting, params, exhaustive_gt=True): self.exhaustive_gt = exhaustive_gt self.method = params["method"] self.method_kwargs = params["method_kwargs"] - self.result = {"gt_peaks" : self.gt_peaks} - self.result['gt_amplitudes'] = sorting_analyzer.get_extension('spike_amplitudes').get_data() + self.result = {"gt_peaks": self.gt_peaks} + self.result["gt_amplitudes"] = sorting_analyzer.get_extension("spike_amplitudes").get_data() def run(self, **job_kwargs): peaks = detect_peaks(self.recording, method=self.method, **self.method_kwargs, **job_kwargs) @@ -96,9 +96,7 @@ def compute_result(self, **result_params): # print("The peaks have {0:.2f}% of garbage (without gt around)".format(ratio)) - _run_key_saved = [("peaks", "npy"), - ("gt_peaks", "npy"), - ("gt_amplitudes", "npy")] + _run_key_saved = [("peaks", "npy"), ("gt_peaks", "npy"), ("gt_amplitudes", "npy")] _result_key_saved = [ ("gt_comparison", "pickle"), @@ -153,7 +151,7 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): if count == 2: ax.legend() - def plot_detected_amplitudes(self, case_keys=None, figsize=(15,5)): + def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -162,11 +160,11 @@ def plot_detected_amplitudes(self, case_keys=None, figsize=(15,5)): for count, key in enumerate(case_keys): ax = axs[0, count] - data1 = self.get_result(key)['peaks']['amplitude'] - data2 = self.get_result(key)['gt_amplitudes'] + data1 = self.get_result(key)["peaks"]["amplitude"] + data2 = self.get_result(key)["gt_amplitudes"] bins = np.linspace(data2.min(), data2.max(), 100) - ax.hist(data1, bins=bins, alpha=0.5, label='detected') - ax.hist(data2, bins=bins, alpha=0.5, label='gt') + ax.hist(data1, bins=bins, alpha=0.5, label="detected") + ax.hist(data2, bins=bins, alpha=0.5, label="gt") ax.set_title(self.cases[key]["label"]) ax.legend() From 8c9cef5f17176baa2b1285be3f5b205508c94d1c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 08:48:10 +0000 Subject: [PATCH 181/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/modules_gallery/core/plot_4_sorting_analyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modules_gallery/core/plot_4_sorting_analyzer.py b/examples/modules_gallery/core/plot_4_sorting_analyzer.py index 864f11ad1d..20dc078197 100644 --- a/examples/modules_gallery/core/plot_4_sorting_analyzer.py +++ b/examples/modules_gallery/core/plot_4_sorting_analyzer.py @@ -18,7 +18,7 @@ * "noise_levels" : compute noise level from traces (usefull to get snr of units) * can be in memory or persistent to disk (2 formats binary/npy or zarr) -More extesions are available in `spikeinterface.postprocessing` like "principal_components", "spike_amplitudes", +More extesions are available in `spikeinterface.postprocessing` like "principal_components", "spike_amplitudes", "unit_lcations", ... From 38fc6b4bb2ebde16547b24676db2c92a9bea73c4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 14:43:29 +0100 Subject: [PATCH 182/192] oups --- doc/how_to/analyse_neuropixels.rst | 6 +++--- examples/how_to/analyse_neuropixels.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index 255172efc0..36eb0a0f63 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -266,7 +266,7 @@ the ipywydgets interactive ploter .. code:: python %matplotlib widget - si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') + si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') Note that using this ipywydgets make possible to explore diffrents preprocessing chain wihtout to save the entire file to disk. Everything @@ -417,7 +417,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, - detect_threshold=5, local_radius_um=50., **job_kwargs) + detect_threshold=5, radius_um=50., **job_kwargs) peaks @@ -442,7 +442,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs) + peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index 3a936b072c..92ccfca602 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -82,7 +82,7 @@ # # ```python # # %matplotlib widget -# si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') +# si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') # ``` # # Note that using this ipywydgets make possible to explore diffrents preprocessing chain wihtout to save the entire file to disk. @@ -170,13 +170,13 @@ job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, - detect_threshold=5, local_radius_um=50., **job_kwargs) + detect_threshold=5, radius_um=50., **job_kwargs) peaks # + from spikeinterface.sortingcomponents.peak_localization import localize_peaks -peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs) +peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) # - # ### Check for drifts From 70cb10dc7d6d6aae7e6ff3aa19a7d800b7888448 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 13:44:27 +0000 Subject: [PATCH 183/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/how_to/analyse_neuropixels.rst | 47 +++++++++++++------------- examples/how_to/analyse_neuropixels.py | 10 +++--- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index 36eb0a0f63..c045f3e849 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -11,7 +11,7 @@ including custom pre- and post-processing. .. code:: ipython3 import spikeinterface.full as si - + import numpy as np import matplotlib.pyplot as plt from pathlib import Path @@ -19,7 +19,7 @@ including custom pre- and post-processing. .. code:: ipython3 base_folder = Path('/mnt/data/sam/DataSpikeSorting/howto_si/neuropixel_example/') - + spikeglx_folder = base_folder / 'Rec_1_10_11_2021_g0' @@ -54,7 +54,7 @@ We need to specify which one to read: .. parsed-literal:: - SpikeGLXRecordingExtractor: 384 channels - 30.0kHz - 1 segments - 34,145,070 samples + SpikeGLXRecordingExtractor: 384 channels - 30.0kHz - 1 segments - 34,145,070 samples 1,138.15s (18.97 minutes) - int16 dtype - 24.42 GiB @@ -74,11 +74,11 @@ We need to specify which one to read: .dataframe tbody tr th:only-of-type { vertical-align: middle; } - + .dataframe tbody tr th { vertical-align: top; } - + .dataframe thead th { text-align: right; } @@ -236,7 +236,7 @@ Let’s do something similar to the IBL destriping chain (See bad_channel_ids, channel_labels = si.detect_bad_channels(rec1) rec2 = rec1.remove_channels(bad_channel_ids) print('bad_channel_ids', bad_channel_ids) - + rec3 = si.phase_shift(rec2) rec4 = si.common_reference(rec3, operator="median", reference="global") rec = rec4 @@ -252,7 +252,7 @@ Let’s do something similar to the IBL destriping chain (See .. parsed-literal:: - CommonReferenceRecording: 383 channels - 30.0kHz - 1 segments - 34,145,070 samples + CommonReferenceRecording: 383 channels - 30.0kHz - 1 segments - 34,145,070 samples 1,138.15s (18.97 minutes) - int16 dtype - 24.36 GiB @@ -277,7 +277,7 @@ is lazy, so you can change the previsous cell (parameters, step order, # here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) - + si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) @@ -331,7 +331,7 @@ parallelization mechanism of SpikeInterface. .. code:: ipython3 job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) - + rec = rec.save(folder=base_folder / 'preprocess', format='binary', **job_kwargs) .. code:: ipython3 @@ -344,7 +344,7 @@ parallelization mechanism of SpikeInterface. .. parsed-literal:: - BinaryFolderRecording: 383 channels - 30.0kHz - 1 segments - 34,145,070 samples + BinaryFolderRecording: 383 channels - 30.0kHz - 1 segments - 34,145,070 samples 1,138.15s (18.97 minutes) - int16 dtype - 24.36 GiB @@ -414,9 +414,9 @@ Let’s use here the ``locally_exclusive`` method for detection and the .. code:: ipython3 from spikeinterface.sortingcomponents.peak_detection import detect_peaks - + job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) - peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, + peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, detect_threshold=5, radius_um=50., **job_kwargs) peaks @@ -441,7 +441,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the .. code:: ipython3 from spikeinterface.sortingcomponents.peak_localization import localize_peaks - + peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) @@ -489,7 +489,7 @@ documentation for motion estimation and correction for more details. fig, ax = plt.subplots(figsize=(15, 10)) si.plot_probe_map(rec, ax=ax, with_channel_ids=True) ax.set_ylim(-100, 150) - + ax.scatter(peak_locations['x'], peak_locations['y'], color='purple', alpha=0.002) @@ -566,7 +566,7 @@ In this example: # run kilosort2.5 without drift correction params_kilosort2_5 = {'do_correction': False} - + sorting = si.run_sorter('kilosort2_5', rec, output_folder=base_folder / 'kilosort2.5_output', docker_image=True, verbose=True, **params_kilosort2_5) @@ -612,7 +612,7 @@ Note that our object is not persistent to disk because we use .. code:: ipython3 - + analyzer = si.create_sorting_analyzer(sorting, rec, sparse=True, format="memory") analyzer @@ -628,7 +628,7 @@ Note that our object is not persistent to disk because we use .. parsed-literal:: SortingAnalyzer: 383 channels - 31 units - 1 segments - memory - sparse - has recording - Loaded 0 extenstions: + Loaded 0 extenstions: @@ -719,12 +719,12 @@ PCA for their computation. This can be achieved with: .. code:: ipython3 metric_names=['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff'] - - + + # metrics = analyzer.compute("quality_metrics").get_data() # equivalent to metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names) - + metrics @@ -743,11 +743,11 @@ PCA for their computation. This can be achieved with: .dataframe tbody tr th:only-of-type { vertical-align: middle; } - + .dataframe tbody tr th { vertical-align: top; } - + .dataframe thead th { text-align: right; } @@ -1061,7 +1061,7 @@ A very common curation approach is to threshold these metrics to select amplitude_cutoff_thresh = 0.1 isi_violations_ratio_thresh = 1 presence_ratio_thresh = 0.9 - + our_query = f"(amplitude_cutoff < {amplitude_cutoff_thresh}) & (isi_violations_ratio < {isi_violations_ratio_thresh}) & (presence_ratio > {presence_ratio_thresh})" print(our_query) @@ -1138,4 +1138,3 @@ And push the results to sortingview webased viewer .. code:: python si.plot_sorting_summary(analyzer_clean, backend='sortingview') - diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index 92ccfca602..ce5bacdda0 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -159,7 +159,7 @@ # # The two functions (detect + localize): # -# * can be run parallel +# * can be run parallel # * are very fast when the preprocessed recording is already saved (and a bit slower otherwise) # * implement several methods # @@ -169,7 +169,7 @@ from spikeinterface.sortingcomponents.peak_detection import detect_peaks job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) -peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, +peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, detect_threshold=5, radius_um=50., **job_kwargs) peaks @@ -214,7 +214,7 @@ # # Please carwfully read the `spikeinterface.sorters` documentation for more information. # -# In this example: +# In this example: # # * we will run kilosort2.5 # * we apply no drift correction (because we don't have drift) @@ -288,7 +288,7 @@ # # We have a single function `compute_quality_metrics(SortingAnalyzer)` that returns a `pandas.Dataframe` with the desired metrics. # -# Note that this function is also an extension and so can be saved. And so this is equivalent to do : +# Note that this function is also an extension and so can be saved. And so this is equivalent to do : # `metrics = analyzer.compute("quality_metrics").get_data()` # # @@ -349,5 +349,3 @@ # ```python # si.plot_sorting_summary(analyzer_clean, backend='sortingview') # ``` - - From bcdf29526e5b785234e1929482abc68b1f7c3066 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 15:16:26 +0100 Subject: [PATCH 184/192] revert cache_preprocessing in simple_sorter --- .../sorters/internal/simplesorter.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sorters/internal/simplesorter.py b/src/spikeinterface/sorters/internal/simplesorter.py index 3f2ada3085..69487baf6c 100644 --- a/src/spikeinterface/sorters/internal/simplesorter.py +++ b/src/spikeinterface/sorters/internal/simplesorter.py @@ -42,7 +42,7 @@ class SimpleSorter(ComponentsBasedSorter): "core_dist_n_jobs": -1, "cluster_selection_method": "leaf", }, - "cache_preprocessing": {"mode": None, "memory_limit": 0.5, "delete_cache": True}, + # "cache_preprocessing": {"mode": None, "memory_limit": 0.5, "delete_cache": True}, "job_kwargs": {"n_jobs": -1, "chunk_duration": "1s"}, } @@ -57,7 +57,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs.update({"verbose": verbose, "progress_bar": verbose}) from spikeinterface.sortingcomponents.peak_detection import detect_peaks - from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing + from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks @@ -83,7 +83,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = recording_raw noise_levels = get_noise_levels(recording, return_scaled=False) - recording = cache_preprocessing(recording, **job_kwargs, **params["cache_preprocessing"]) + # recording = cache_preprocessing(recording, **job_kwargs, **params["cache_preprocessing"]) # detection detection_params = params["detection"].copy() @@ -186,24 +186,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(features_folder / "peak_labels.npy", peak_labels) - folder_to_delete = None + # folder_to_delete = None - if "mode" in params["cache_preprocessing"]: - cache_mode = params["cache_preprocessing"]["mode"] - else: - cache_mode = "memory" + # if "mode" in params["cache_preprocessing"]: + # cache_mode = params["cache_preprocessing"]["mode"] + # else: + # cache_mode = "memory" - if "delete_cache" in params["cache_preprocessing"]: - delete_cache = params["cache_preprocessing"] - else: - delete_cache = True + # if "delete_cache" in params["cache_preprocessing"]: + # delete_cache = params["cache_preprocessing"] + # else: + # delete_cache = True - if cache_mode in ["folder", "zarr"] and delete_cache: - folder_to_delete = recording._kwargs["folder_path"] + # if cache_mode in ["folder", "zarr"] and delete_cache: + # folder_to_delete = recording._kwargs["folder_path"] - del recording - if folder_to_delete is not None: - shutil.rmtree(folder_to_delete) + # del recording + # if folder_to_delete is not None: + # shutil.rmtree(folder_to_delete) # keep positive labels keep = peak_labels >= 0 From 902106b741f506efb5d7164cd1231bab4fd067b4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 15:30:02 +0100 Subject: [PATCH 185/192] Make pandas weak import for pandas in core --- src/spikeinterface/core/sortinganalyzer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index f1858810fb..802981ccf2 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1331,8 +1331,14 @@ def _save_data(self, **kwargs): if self.sorting_analyzer.is_read_only(): raise ValueError(f"The SortingAnalyzer is read-only saving extension {self.extension_name} is not possible") - if self.format == "binary_folder": + try: + # pandas is a weak dependency for spikeinterface.core import pandas as pd + HAS_PANDAS = True + except: + HAS_PANDAS = False + + if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() for ext_data_name, ext_data in self.data.items(): @@ -1347,7 +1353,7 @@ def _save_data(self, **kwargs): pass else: np.save(data_file, ext_data) - elif isinstance(ext_data, pd.DataFrame): + elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): ext_data.to_csv(extension_folder / f"{ext_data_name}.csv", index=True) else: try: @@ -1357,7 +1363,6 @@ def _save_data(self, **kwargs): raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - import pandas as pd import numcodecs extension_group = self._get_zarr_extension_group(mode="r+") @@ -1375,7 +1380,7 @@ def _save_data(self, **kwargs): ) elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) - elif isinstance(ext_data, pd.DataFrame): + elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): ext_data.to_xarray().to_zarr( store=extension_group.store, group=f"{extension_group.name}/{ext_data_name}", From beb7cf3d069f4e8a3122ce8d1846c5b986a12832 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:37:43 +0000 Subject: [PATCH 186/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 802981ccf2..062ad3b60a 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1334,6 +1334,7 @@ def _save_data(self, **kwargs): try: # pandas is a weak dependency for spikeinterface.core import pandas as pd + HAS_PANDAS = True except: HAS_PANDAS = False From fd35a30481d5a100ddb0e71f9eb406a419125cba Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 18:22:29 +0100 Subject: [PATCH 187/192] oups --- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 8e1be24753..d39b25379d 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -146,7 +146,7 @@ def test_empty_units(sorting_analyzer_simple): assert len(sorting_empty.get_empty_unit_ids()) == 3 sorting_analyzer_empty = create_sorting_analyzer(sorting_empty, sorting_analyzer.recording, format="memory") - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) + sorting_analyzer_empty.compute("random_spikes", max_spikes_per_unit=300, seed=2205) sorting_analyzer_empty.compute("noise_levels") sorting_analyzer_empty.compute("waveforms", **job_kwargs) sorting_analyzer_empty.compute("templates") From e7937757827133f927da884a4633165eb3db8007 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 18:36:42 +0100 Subject: [PATCH 188/192] oups --- src/spikeinterface/widgets/tests/test_widgets.py | 6 +++--- src/spikeinterface/widgets/unit_templates.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 2d228d7d5f..6c15e1dbf6 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -584,7 +584,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_traces() # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_waveforms() - # mytest.test_plot_unit_templates() + mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() # mytest.test_plot_autocorrelograms() # mytest.test_plot_crosscorrelograms() @@ -603,7 +603,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() - mytest.test_plot_sorting_summary() - # plt.show() + # mytest.test_plot_sorting_summary() + plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index a39d5e0f0c..1350bb71a5 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -50,7 +50,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.waveform_extractor.sorting) + v_units_table = generate_unit_table_view(dp.sorting_analyzer.sorting) self.view = vv.Box( direction="horizontal", From f9640e818a0e003a78b603d90d8d90799aa75458 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 20:35:16 +0100 Subject: [PATCH 189/192] more fix in dev --- src/spikeinterface/sorters/tests/test_launcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index 2f2dc583d2..8019f4e620 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -70,6 +70,7 @@ def test_run_sorter_jobs_loop(job_list): print(sortings) +@pytest.mark.skipif(True, reason="tridesclous is already multiprocessing, joblib cannot run it in parralel") def test_run_sorter_jobs_joblib(job_list): if base_output.is_dir(): shutil.rmtree(base_output) From 5b48417a5582904b9b41adeb10a883c344ef0baf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 21:23:38 +0100 Subject: [PATCH 190/192] Fixes after Pierre remove many features methods!!! --- .../clustering/position_and_features.py | 10 ++++++---- .../sortingcomponents/features_from_peaks.py | 1 + .../tests/test_features_from_peaks.py | 6 ++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index c4a999ae92..805a1572fd 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -47,6 +47,8 @@ class PositionAndFeaturesClustering: @classmethod def main_function(cls, recording, peaks, params): + from sklearn.preprocessing import QuantileTransformer + assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" if "n_jobs" in params["job_kwargs"]: @@ -68,22 +70,22 @@ def main_function(cls, recording, peaks, params): position_method = d["peak_localization_kwargs"]["method"] - features_list = [position_method, "ptp", "energy"] + features_list = [position_method, "ptp",] features_params = { position_method: {"radius_um": params["radius_um"]}, "ptp": {"all_channels": False, "radius_um": params["radius_um"]}, - "energy": {"radius_um": params["radius_um"]}, } features_data = compute_features_from_peaks( recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] ) - hdbscan_data = np.zeros((len(peaks), 4), dtype=np.float32) + hdbscan_data = np.zeros((len(peaks), 3), dtype=np.float32) hdbscan_data[:, 0] = features_data[0]["x"] hdbscan_data[:, 1] = features_data[0]["y"] hdbscan_data[:, 2] = features_data[1] - hdbscan_data[:, 3] = features_data[2] + + preprocessing = QuantileTransformer(output_distribution="uniform") hdbscan_data = preprocessing.fit_transform(hdbscan_data) diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 8770dec6f9..40f89068f9 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -215,4 +215,5 @@ def compute(self, traces, peaks, waveforms): "amplitude": AmplitudeFeature, "ptp": PeakToPeakFeature, "random_projections": RandomProjectionsFeature, + "center_of_mass": LocalizeCenterOfMass, } diff --git a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py index 896c4e1e1e..160ba3cb36 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py @@ -26,12 +26,11 @@ def test_features_from_peaks(): **job_kwargs, ) - feature_list = ["amplitude", "ptp", "center_of_mass", "energy"] + feature_list = ["amplitude", "ptp", "center_of_mass",] feature_params = { "amplitude": {"all_channels": False, "peak_sign": "neg"}, "ptp": {"all_channels": False}, "center_of_mass": {"radius_um": 120.0}, - "energy": {"radius_um": 160.0}, } features = compute_features_from_peaks(recording, peaks, feature_list, feature_params=feature_params, **job_kwargs) @@ -45,14 +44,13 @@ def test_features_from_peaks(): # split feature variable job_kwargs["n_jobs"] = 2 - amplitude, ptp, com, energy = compute_features_from_peaks( + amplitude, ptp, com, = compute_features_from_peaks( recording, peaks, feature_list, feature_params=feature_params, **job_kwargs ) assert amplitude.ndim == 1 # because all_channels=False assert ptp.ndim == 1 # because all_channels=False assert com.ndim == 1 assert "x" in com.dtype.fields - assert energy.ndim == 1 # amplitude and peak to peak with multi channels d = {"all_channels": True} From 6bcadd895f4dede2218c9ee252b8236d5ca05e53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 20:24:14 +0000 Subject: [PATCH 191/192] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../clustering/position_and_features.py | 9 +++++---- .../tests/test_features_from_peaks.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index 805a1572fd..3c58b5edb9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -48,7 +48,7 @@ class PositionAndFeaturesClustering: @classmethod def main_function(cls, recording, peaks, params): from sklearn.preprocessing import QuantileTransformer - + assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" if "n_jobs" in params["job_kwargs"]: @@ -70,7 +70,10 @@ def main_function(cls, recording, peaks, params): position_method = d["peak_localization_kwargs"]["method"] - features_list = [position_method, "ptp",] + features_list = [ + position_method, + "ptp", + ] features_params = { position_method: {"radius_um": params["radius_um"]}, "ptp": {"all_channels": False, "radius_um": params["radius_um"]}, @@ -85,8 +88,6 @@ def main_function(cls, recording, peaks, params): hdbscan_data[:, 1] = features_data[0]["y"] hdbscan_data[:, 2] = features_data[1] - - preprocessing = QuantileTransformer(output_distribution="uniform") hdbscan_data = preprocessing.fit_transform(hdbscan_data) diff --git a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py index 160ba3cb36..9bc9fd9ab0 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py @@ -26,7 +26,11 @@ def test_features_from_peaks(): **job_kwargs, ) - feature_list = ["amplitude", "ptp", "center_of_mass",] + feature_list = [ + "amplitude", + "ptp", + "center_of_mass", + ] feature_params = { "amplitude": {"all_channels": False, "peak_sign": "neg"}, "ptp": {"all_channels": False}, @@ -44,9 +48,11 @@ def test_features_from_peaks(): # split feature variable job_kwargs["n_jobs"] = 2 - amplitude, ptp, com, = compute_features_from_peaks( - recording, peaks, feature_list, feature_params=feature_params, **job_kwargs - ) + ( + amplitude, + ptp, + com, + ) = compute_features_from_peaks(recording, peaks, feature_list, feature_params=feature_params, **job_kwargs) assert amplitude.ndim == 1 # because all_channels=False assert ptp.ndim == 1 # because all_channels=False assert com.ndim == 1 From e9418e9a3d91973a7f4529edb8a1d2fa39738404 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 4 Mar 2024 21:25:05 +0100 Subject: [PATCH 192/192] oups remove file push by accident some time ago. Sorry. --- dev_pool.py | 74 ----------------------------------------------------- 1 file changed, 74 deletions(-) delete mode 100644 dev_pool.py diff --git a/dev_pool.py b/dev_pool.py deleted file mode 100644 index 9a9b2ca0f2..0000000000 --- a/dev_pool.py +++ /dev/null @@ -1,74 +0,0 @@ -import multiprocessing -from concurrent.futures import ProcessPoolExecutor - -def f(x): - import os - # global _worker_num - p = multiprocessing.current_process() - print(p, type(p), p.name, p._identity[0], type(p._identity[0]), p.ident) - return x * x - - -def init_worker(lock, array_pid): - print(array_pid, len(array_pid)) - child_process = multiprocessing.current_process() - - lock.acquire() - num_worker = None - for i in range(len(array_pid)): - print(array_pid[i]) - if array_pid[i] == -1: - num_worker = i - array_pid[i] = child_process.ident - break - print(num_worker, child_process.ident) - lock.release() - -num_worker = 6 -lock = multiprocessing.Lock() -array_pid = multiprocessing.Array('i', num_worker) -for i in range(num_worker): - array_pid[i] = -1 - - -# with ProcessPoolExecutor( -# max_workers=4, -# ) as executor: -# print(executor._processes) -# results = executor.map(f, range(6)) - -with ProcessPoolExecutor( - max_workers=4, - initializer=init_worker, - initargs=(lock, array_pid) -) as executor: - print(executor._processes) - results = executor.map(f, range(6)) - -exit() -# global _worker_num -# def set_worker_index(i): -# global _worker_num -# _worker_num = i - -p = multiprocessing.Pool(processes=3) -# children = multiprocessing.active_children() -# for i, child in enumerate(children): -# child.submit(set_worker_index) - -# print(children) -print(p.map(f, range(6))) -p.close() -p.join() - -p = multiprocessing.Pool(processes=3) -print(p.map(f, range(6))) -print(p.map(f, range(6))) - -p.close() -p.join() - - -# print(multiprocessing.current_process()) -# p = multiprocessing.current_process() -# print(p._identity)