diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 67f1e52011..e95edb968c 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -74,6 +74,8 @@ dtype (unless specified otherwise): Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`. +When converting from a :code:`float` to an :code:`int`, the value will first be rounded to the nearest integer. + Available preprocessing ----------------------- diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 0d08922543..23d13c0afe 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -21,7 +21,7 @@ # This is to separate names when the key are tuples when saving folders -_key_separator = " ## " +_key_separator = "_##_" class GroundTruthStudy: diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 1a8674697a..7cde209a8d 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -316,35 +316,69 @@ def to_dict( recursive: bool = False, ) -> dict: """ - Make a nested serialized dictionary out of the extractor. The dictionary produced can be used to re-initialize - an extractor using load_extractor_from_dict(dump_dict) + Construct a nested dictionary representation of the extractor. + + This method facilitates the serialization of the extractor instance by converting it + to a dictionary. The resulting dictionary can be used to re-initialize the extractor + through the `load_extractor_from_dict` function. + + Examples + -------- + >>> dump_dict = original_extractor.to_dict() + >>> reloaded_extractor = load_extractor_from_dict(dump_dict) Parameters ---------- - include_annotations: bool, default: False - If True, all annotations are added to the dict - include_properties: bool, default: False - If True, all properties are added to the dict - relative_to: str, Path, or None, default: None - If not None, files and folders are serialized relative to this path - Used in waveform extractor to maintain relative paths to binary files even if the - containing folder / diretory is moved - folder_metadata: str, Path, or None - Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties. - recursive: bool, default: False - If True, all dicitionaries in the kwargs are expanded with `to_dict` as well + include_annotations : bool, default: False + Whether to include all annotations in the dictionary + include_properties : bool, default: False + Whether to include all properties in the dictionary, by default False. + relative_to : Union[str, Path, None], default: None + If provided, file and folder paths will be made relative to this path, + enabling portability in folder formats such as the waveform extractor, + by default None. + folder_metadata : Union[str, Path, None], default: None + Path to a folder containing additional metadata files (e.g., probe information in BaseRecording) + in numpy `npy` format, by default None. + recursive : bool, default: False + If True, recursively apply `to_dict` to dictionaries within the kwargs, by default False. + + Raises + ------ + ValueError + If `relative_to` is specified while `recursive` is False. Returns ------- - dump_dict: dict - A dictionary representation of the extractor. 
+            dict
+            A dictionary representation of the extractor, with the following structure:
+            {
+                "class": <the full import path of the class>,
+                "module": <the root module name>, (e.g. 'spikeinterface'),
+                "kwargs": <the kwargs needed to re-instantiate the extractor>,
+                "version": <the version of the module>,
+                "relative_paths": <whether file paths are relative>,
+                "annotations": <the annotations, if included>,
+                "properties": <the properties, if included>,
+                "folder_metadata": <the folder metadata path, if any>
+            }
+
+        Notes
+        -----
+        - The `relative_to` argument only has an effect if `recursive` is set to True.
+        - The `folder_metadata` argument will be made relative to `relative_to` if both are specified.
+        - The `version` field in the resulting dictionary reflects the version of the module
+          from which the extractor class originates.
+        - The full class attribute above is the full import of the class, e.g.
+          'spikeinterface.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor'
+        - The module is usually 'spikeinterface', but can be different for custom extractors such as those of
+          SpikeForest or any other project that inherits the Extractor class from spikeinterface.
         """
-        kwargs = self._kwargs
-
         if relative_to and not recursive:
             raise ValueError("`relative_to` is only possible when `recursive=True`")
 
+        kwargs = self._kwargs
         if recursive:
             to_dict_kwargs = dict(
                 include_annotations=include_annotations,
@@ -366,27 +400,24 @@ def to_dict(
                 new_kwargs[name] = transform_extractors_to_dict(value)
 
             kwargs = new_kwargs
-        class_name = str(type(self)).replace("<class '", "").replace("'>", "")
+
+        module_import_path = self.__class__.__module__
+        class_name_no_path = self.__class__.__name__
+        class_name = f"{module_import_path}.{class_name_no_path}"  # e.g. 'spikeinterface.core.generate.AClass'
         module = class_name.split(".")[0]
-        imported_module = importlib.import_module(module)
-        try:
-            version = imported_module.__version__
-        except AttributeError:
-            version = "unknown"
+        imported_module = importlib.import_module(module)
+        module_version = getattr(imported_module, "__version__", "unknown")
 
         dump_dict = {
             "class": class_name,
             "module": module,
             "kwargs": kwargs,
-            "version": version,
+            "version": module_version,
             "relative_paths": (relative_to is not None),
         }
 
-        try:
-            dump_dict["version"] = imported_module.__version__
-        except AttributeError:
-            dump_dict["version"] = "unknown"
+        dump_dict["version"] = module_version  # Can be spikeinterface, spikeforest, etc.
 
         if include_annotations:
             dump_dict["annotations"] = self._annotations
@@ -805,7 +836,7 @@ def save_to_folder(self, name=None, folder=None, overwrite=False, verbose=True,
           * explicit sub-folder, implicit base-folder : `extractor.save(name="extarctor_name")`
           * generated: `extractor.save()`
 
-        The second option saves to subfolder "extarctor_name" in
+        The second option saves to subfolder "extractor_name" in
         "get_global_tmp_folder()". You can set the global tmp folder with:
         "set_global_tmp_folder("path-to-global-folder")"
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 526911814c..95e2aa93d4 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -57,13 +57,13 @@ def add_sorting_segment(self, sorting_segment):
         self._sorting_segments.append(sorting_segment)
         sorting_segment.set_parent_extractor(self)
 
-    def get_sampling_frequency(self):
+    def get_sampling_frequency(self) -> float:
         return self._sampling_frequency
 
-    def get_num_segments(self):
+    def get_num_segments(self) -> int:
         return len(self._sorting_segments)
 
-    def get_num_samples(self, segment_index=None):
+    def get_num_samples(self, segment_index=None) -> int:
         """Returns the number of samples of the associated recording for a segment.
Parameters @@ -82,7 +82,7 @@ def get_num_samples(self, segment_index=None): ), "This methods requires an associated recording. Call self.register_recording() first." return self._recording.get_num_samples(segment_index=segment_index) - def get_total_samples(self): + def get_total_samples(self) -> int: """Returns the total number of samples of the associated recording. Returns @@ -322,9 +322,11 @@ def count_num_spikes_per_unit(self, outputs="dict"): else: raise ValueError("count_num_spikes_per_unit() output must be 'dict' or 'array'") - def count_total_num_spikes(self): + def count_total_num_spikes(self) -> int: """ - Get total number of spikes summed across segment and units. + Get total number of spikes in the sorting. + + This is the sum of all spikes in all segments across all units. Returns ------- @@ -333,9 +335,10 @@ def count_total_num_spikes(self): """ return self.to_spike_vector().size - def select_units(self, unit_ids, renamed_unit_ids=None): + def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting: """ - Selects a subset of units + Returns a new sorting object which contains only a selected subset of units. + Parameters ---------- @@ -354,9 +357,30 @@ def select_units(self, unit_ids, renamed_unit_ids=None): sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids) return sub_sorting - def remove_units(self, remove_unit_ids): + def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting: + """ + Returns a new sorting object with renamed units. + + + Parameters + ---------- + new_unit_ids : numpy.array or list + List of new names for unit ids. + They should map positionally to the existing unit ids. + + Returns + ------- + BaseSorting + Sorting object with renamed units + """ + from spikeinterface import UnitsSelectionSorting + + sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids) + return sub_sorting + + def remove_units(self, remove_unit_ids) -> BaseSorting: """ - Removes a subset of units + Returns a new sorting object with contains only a selected subset of units. Parameters ---------- @@ -366,7 +390,7 @@ def remove_units(self, remove_unit_ids): Returns ------- BaseSorting - Sorting object without removed units + Sorting without the removed units """ from spikeinterface import UnitsSelectionSorting @@ -376,7 +400,8 @@ def remove_units(self, remove_unit_ids): def remove_empty_units(self): """ - Removes units with empty spike trains + Returns a new sorting object which contains only units with at least one spike. + For multi-segments, a unit is considered empty if it contains no spikes in all segments. Returns ------- @@ -387,16 +412,12 @@ def remove_empty_units(self): return self.select_units(non_empty_units) def get_non_empty_unit_ids(self): - non_empty_units = [] - for segment_index in range(self.get_num_segments()): - for unit in self.get_unit_ids(): - if len(self.get_unit_spike_train(unit, segment_index=segment_index)) > 0: - non_empty_units.append(unit) - non_empty_units = np.unique(non_empty_units) - return non_empty_units + num_spikes_per_unit = self.count_num_spikes_per_unit() + + return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] != 0]) def get_empty_unit_ids(self): - unit_ids = self.get_unit_ids() + unit_ids = self.unit_ids empty_units = unit_ids[~np.isin(unit_ids, self.get_non_empty_unit_ids())] return empty_units @@ -412,7 +433,7 @@ def get_all_spike_trains(self, outputs="unit_id"): """ Return all spike trains concatenated. 
- This is deprecated use sorting.to_spike_vector() instead + This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead """ warnings.warn( @@ -452,7 +473,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac Construct a unique structured numpy vector concatenating all spikes with several fields: sample_index, unit_index, segment_index. - See also `get_all_spike_trains()` Parameters ---------- diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 69e043b640..1c8661d12d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1336,15 +1336,60 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None): +def generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=20.0, + max_iteration=100, + distance_strict=False, + seed=None, +): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") - for dim in (0, 1): - lim0 = np.min(channel_locations[:, dim]) - margin_um - lim1 = np.max(channel_locations[:, dim]) + margin_um - units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units) + + minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um + minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um + + units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units) + units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units) units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units) + if minimum_distance is not None: + solution_found = False + renew_inds = None + for i in range(max_iteration): + distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2) + inds0, inds1 = np.nonzero(distances < minimum_distance) + mask = inds0 != inds1 + inds0 = inds0[mask] + inds1 = inds1[mask] + + if inds0.size > 0: + if renew_inds is None: + renew_inds = np.unique(inds0) + else: + # random only bad ones in the previous set + renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] + + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) + units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) + units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) + else: + solution_found = True + break + + if not solution_found: + if distance_strict: + raise ValueError( + f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance" + ) + else: + warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") + return units_locations @@ -1369,7 +1414,7 @@ def generate_ground_truth_recording( upsample_vector=None, generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), generate_templates_kwargs=dict(), 
dtype="float32", seed=None, diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 07a57f7807..3b8b6025ca 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -154,11 +154,8 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nd or a single sparsified waveform (template) with shape (num_samples, num_active_channels). """ - assert_msg = ( - "Waveforms must be dense to sparsify them. " - f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}" - ) - assert self.are_waveforms_dense(waveforms=waveforms), assert_msg + if self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + return waveforms non_zero_indices = self.unit_id_to_channel_indices[unit_id] sparsified_waveforms = waveforms[..., non_zero_indices] @@ -189,16 +186,20 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda """ non_zero_indices = self.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) - assert_msg = ( - "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " - f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels." - ) - assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg + if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + error_message = ( + "Waveforms do not seem to be in the sparsity shape for this unit_id. The number of active channels is " + f"{num_active_channels}, but the waveform has non-zero values outsies of those active channels: \n" + f"{waveforms[..., num_active_channels:]}" + ) + raise ValueError(error_message) densified_shape = waveforms.shape[:-1] + (self.num_channels,) - densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) - densified_waveforms[..., non_zero_indices] = waveforms + densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype) + # Maps the active channels to their original indices + densified_waveforms[..., non_zero_indices] = waveforms[..., :num_active_channels] return densified_waveforms @@ -208,7 +209,21 @@ def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool: non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_active_channels = len(non_zero_indices) - return waveforms.shape[-1] == num_active_channels + + # If any channel is non-zero outside of the active channels, then the waveforms are not sparse + excess_zeros = waveforms[..., num_active_channels:].sum() + + return int(excess_zeros) == 0 + + def sparisfy_templates(self, templates_array: np.ndarray) -> np.ndarray: + max_num_active_channels = self.max_num_active_channels + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_templates = np.zeros(shape=sparisfied_shape, dtype=templates_array.dtype) + for unit_index, unit_id in enumerate(self.unit_ids): + template = templates_array[unit_index, ...] + sparse_templates[unit_index, ...] 
= self.sparsify_waveforms(waveforms=template, unit_id=unit_id)
+
+        return sparse_templates
 
     @classmethod
     def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
new file mode 100644
index 0000000000..e6372c7082
--- /dev/null
+++ b/src/spikeinterface/core/template.py
@@ -0,0 +1,196 @@
+import numpy as np
+import json
+from dataclasses import dataclass, field, astuple
+from .sparsity import ChannelSparsity
+
+
+@dataclass
+class Templates:
+    """
+    A class to represent spike templates, which can be either dense or sparse.
+
+    Parameters
+    ----------
+    templates_array : np.ndarray
+        Array containing the templates data.
+    sampling_frequency : float
+        Sampling frequency of the templates.
+    nbefore : int
+        Number of samples before the spike peak.
+    sparsity_mask : np.ndarray or None, default: None
+        Boolean array indicating the sparsity pattern of the templates.
+        If `None`, the templates are considered dense.
+    channel_ids : np.ndarray, optional, default: None
+        Array of channel IDs. If `None`, defaults to an array of increasing integers.
+    unit_ids : np.ndarray, optional, default: None
+        Array of unit IDs. If `None`, defaults to an array of increasing integers.
+    check_for_consistent_sparsity : bool, optional, default: True
+        When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
+        structure of the sparsity_mask.
+
+    The following attributes are available after construction:
+
+    Attributes
+    ----------
+    num_units : int
+        Number of units in the templates. Automatically determined from `templates_array`.
+    num_samples : int
+        Number of samples per template. Automatically determined from `templates_array`.
+    num_channels : int
+        Number of channels in the templates. Automatically determined from `templates_array` or `sparsity_mask`.
+    nafter : int
+        Number of samples after the spike peak. Calculated as `num_samples - nbefore`.
+    ms_before : float
+        Milliseconds before the spike peak. Calculated from `nbefore` and `sampling_frequency`.
+    ms_after : float
+        Milliseconds after the spike peak. Calculated from `nafter` and `sampling_frequency`.
+    sparsity : ChannelSparsity, optional
+        Object representing the sparsity pattern of the templates. Calculated from `sparsity_mask`.
+        If `None`, the templates are considered dense.
+ """ + + templates_array: np.ndarray + sampling_frequency: float + nbefore: int + + sparsity_mask: np.ndarray = None + channel_ids: np.ndarray = None + unit_ids: np.ndarray = None + + check_for_consistent_sparsity: bool = True + + num_units: int = field(init=False) + num_samples: int = field(init=False) + num_channels: int = field(init=False) + + nafter: int = field(init=False) + ms_before: float = field(init=False) + ms_after: float = field(init=False) + sparsity: ChannelSparsity = field(init=False, default=None) + + def __post_init__(self): + self.num_units, self.num_samples = self.templates_array.shape[:2] + if self.sparsity_mask is None: + self.num_channels = self.templates_array.shape[2] + else: + self.num_channels = self.sparsity_mask.shape[1] + + # Time and frames domain information + self.nafter = self.num_samples - self.nbefore + self.ms_before = self.nbefore / self.sampling_frequency * 1000 + self.ms_after = self.nafter / self.sampling_frequency * 1000 + + # Initialize sparsity object + if self.channel_ids is None: + self.channel_ids = np.arange(self.num_channels) + if self.unit_ids is None: + self.unit_ids = np.arange(self.num_units) + if self.sparsity_mask is not None: + self.sparsity = ChannelSparsity( + mask=self.sparsity_mask, + unit_ids=self.unit_ids, + channel_ids=self.channel_ids, + ) + + # Test that the templates are sparse if a sparsity mask is passed + if self.check_for_consistent_sparsity: + if not self._are_passed_templates_sparse(): + raise ValueError("Sparsity mask passed but the templates are not sparse") + + def get_dense_templates(self) -> np.ndarray: + # Assumes and object without a sparsity mask already has dense templates + if self.sparsity is None: + return self.templates_array + + densified_shape = (self.num_units, self.num_samples, self.num_channels) + dense_waveforms = np.zeros(shape=densified_shape, dtype=self.templates_array.dtype) + + for unit_index, unit_id in enumerate(self.unit_ids): + waveforms = self.templates_array[unit_index, ...] + dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id) + + return dense_waveforms + + def are_templates_sparse(self) -> bool: + return self.sparsity is not None + + def _are_passed_templates_sparse(self) -> bool: + """ + Tests if the templates passed to the init constructor are sparse + """ + are_templates_sparse = True + for unit_index, unit_id in enumerate(self.unit_ids): + waveforms = self.templates_array[unit_index, ...] 
+ are_templates_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) + if not are_templates_sparse: + return False + + return are_templates_sparse + + def to_dict(self): + return { + "templates_array": self.templates_array, + "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask, + "channel_ids": self.channel_ids, + "unit_ids": self.unit_ids, + "sampling_frequency": self.sampling_frequency, + "nbefore": self.nbefore, + } + + @classmethod + def from_dict(cls, data): + return cls( + templates_array=np.asarray(data["templates_array"]), + sparsity_mask=None if data["sparsity_mask"] is None else np.asarray(data["sparsity_mask"]), + channel_ids=np.asarray(data["channel_ids"]), + unit_ids=np.asarray(data["unit_ids"]), + sampling_frequency=data["sampling_frequency"], + nbefore=data["nbefore"], + ) + + def to_json(self): + from spikeinterface.core.core_tools import SIJsonEncoder + + return json.dumps(self.to_dict(), cls=SIJsonEncoder) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) + + def __eq__(self, other): + """ + Necessary to compare templates because they naturally compare objects by equality of their fields + which is not possible for numpy arrays. Therefore, we override the __eq__ method to compare each numpy arrays + using np.array_equal instead + """ + if not isinstance(other, Templates): + return False + + # Convert the instances to tuples + self_tuple = astuple(self) + other_tuple = astuple(other) + + # Compare each field + for s_field, o_field in zip(self_tuple, other_tuple): + if isinstance(s_field, np.ndarray): + if not np.array_equal(s_field, o_field): + return False + + # Compare ChannelSparsity by its mask, unit_ids and channel_ids. + # Maybe ChannelSparsity should have its own __eq__ method + elif isinstance(s_field, ChannelSparsity): + if not isinstance(o_field, ChannelSparsity): + return False + + # Compare ChannelSparsity by its mask, unit_ids and channel_ids + if not np.array_equal(s_field.mask, o_field.mask): + return False + if not np.array_equal(s_field.unit_ids, o_field.unit_ids): + return False + if not np.array_equal(s_field.channel_ids, o_field.channel_ids): + return False + else: + if s_field != o_field: + return False + + return True diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index e6cefbf6b2..9e974387ff 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -22,6 +22,7 @@ ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal +from spikeinterface.core.generate import generate_sorting if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -170,6 +171,18 @@ def test_npy_sorting(): assert_raises(Exception, sorting.register_recording, rec) +def test_rename_units_method(): + num_units = 2 + durations = [1.0, 1.0] + + sorting = generate_sorting(num_units=num_units, durations=durations) + + new_unit_ids = ["a", "b"] + new_sorting = sorting.rename_units(new_unit_ids=new_unit_ids) + + assert np.array_equal(new_sorting.get_unit_ids(), new_unit_ids) + + def test_empty_sorting(): sorting = NumpySorting.from_unit_dict({}, 30000) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9a9c61766f..7b51abcccb 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ 
b/src/spikeinterface/core/tests/test_generate.py @@ -4,6 +4,8 @@ import numpy as np from spikeinterface.core import load_extractor, extract_waveforms + +from probeinterface import generate_multi_columns_probe from spikeinterface.core.generate import ( generate_recording, generate_sorting, @@ -289,6 +291,40 @@ def test_generate_single_fake_waveform(): # plt.show() +def test_generate_unit_locations(): + seed = 0 + + probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) + channel_locations = probe.contact_positions + + num_units = 100 + minimum_distance = 20.0 + unit_locations = generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=minimum_distance, + max_iteration=500, + distance_strict=False, + seed=seed, + ) + distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) + dist_flat = np.triu(distances, k=1).flatten() + dist_flat = dist_flat[dist_flat > 0] + assert np.all(dist_flat > minimum_distance) + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.hist(dist_flat, bins = np.arange(0, 400, 10)) + # fig, ax = plt.subplots() + # from probeinterface.plotting import plot_probe + # plot_probe(probe, ax=ax) + # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20) + # plt.show() + + def test_generate_templates(): seed = 0 @@ -297,7 +333,7 @@ def test_generate_templates(): num_units = 10 margin_um = 15.0 channel_locations = generate_channel_locations(num_chans, num_columns, 20.0) - unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed) + unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed) sampling_frequency = 30000.0 ms_before = 1.0 @@ -436,7 +472,8 @@ def test_generate_ground_truth_recording(): # test_noise_generator_consistency_after_dump(strategy, None) # test_generate_recording() # test_generate_single_fake_waveform() + test_generate_unit_locations() # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - test_generate_sorting_with_spikes_on_borders() + # test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py new file mode 100644 index 0000000000..40bb3f2b34 --- /dev/null +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -0,0 +1,86 @@ +import pytest +import numpy as np +import pickle +from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import ChannelSparsity + + +def generate_test_template(template_type): + num_units = 2 + num_samples = 5 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + sampling_frequency = 30_000 + nbefore = 2 + + if template_type == "dense": + return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) + elif template_type == "sparse": # sparse with sparse templates + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + sparsity = ChannelSparsity( + mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels) + ) + + # Create sparse templates + sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels)) + for unit_index in range(num_units): + 
template = templates_array[unit_index, ...] + sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index) + sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template + + return Templates( + templates_array=sparse_templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + + elif template_type == "sparse_with_dense_templates": # sparse with dense templates + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + + return Templates( + templates_array=templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + + +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_pickle_serialization(template_type, tmp_path): + template = generate_test_template(template_type) + + # Dump to pickle + pkl_path = tmp_path / "templates.pkl" + with open(pkl_path, "wb") as f: + pickle.dump(template, f) + + # Load from pickle + with open(pkl_path, "rb") as f: + template_reloaded = pickle.load(f) + + assert template == template_reloaded + + +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_json_serialization(template_type): + template = generate_test_template(template_type) + + json_str = template.to_json() + template_reloaded_from_json = Templates.from_json(json_str) + + assert template == template_reloaded_from_json + + +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_get_dense_templates(template_type): + template = generate_test_template(template_type) + dense_templates = template.get_dense_templates() + assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) + + +def test_initialization_fail_with_dense_templates(): + with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): + template = generate_test_template(template_type="sparse_with_dense_templates") diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index c97a727340..a81d36139d 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -506,7 +506,7 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) - def get_extension_class(self, extension_name): + def get_extension_class(self, extension_name: str): """ Get extension class from name and check if registered. @@ -525,7 +525,7 @@ def get_extension_class(self, extension_name): ext_class = extensions_dict[extension_name] return ext_class - def is_extension(self, extension_name) -> bool: + def has_extension(self, extension_name: str) -> bool: """ Check if the extension exists in memory or in the folder. @@ -556,7 +556,15 @@ def is_extension(self, extension_name) -> bool: and "params" in self._waveforms_root[extension_name].attrs.keys() ) - def load_extension(self, extension_name): + def is_extension(self, extension_name) -> bool: + warn( + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.has_extension(extension_name) + + def load_extension(self, extension_name: str): """ Load an extension from its name. The module of the extension must be loaded and registered. 
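# Illustrative sketch (not part of the patch): how the new Templates class added in
# src/spikeinterface/core/template.py above can be constructed and round-tripped.
# Shapes and ids are arbitrary; the calls mirror test_template_class.py from this diff.
import numpy as np
from spikeinterface.core.sparsity import ChannelSparsity
from spikeinterface.core.template import Templates

num_units, num_samples, num_channels = 2, 5, 3
dense_array = np.arange(num_units * num_samples * num_channels, dtype="float32").reshape(
    num_units, num_samples, num_channels
)

# Dense templates: no sparsity mask, so num_channels is read from the array itself
dense_templates = Templates(templates_array=dense_array, sampling_frequency=30_000.0, nbefore=2)
assert not dense_templates.are_templates_sparse()

# Sparse templates: the boolean mask says which channels are active for each unit
mask = np.array([[True, False, True], [False, True, False]])
sparsity = ChannelSparsity(mask=mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels))
sparse_array = np.zeros((num_units, num_samples, sparsity.max_num_active_channels), dtype="float32")
for unit_index in range(num_units):
    sparse_waveform = sparsity.sparsify_waveforms(waveforms=dense_array[unit_index], unit_id=unit_index)
    sparse_array[unit_index, :, : sparse_waveform.shape[1]] = sparse_waveform

sparse_templates = Templates(
    templates_array=sparse_array, sparsity_mask=mask, sampling_frequency=30_000.0, nbefore=2
)

# JSON round trip preserves equality thanks to the custom __eq__ and to_json/from_json
assert Templates.from_json(sparse_templates.to_json()) == sparse_templates
assert sparse_templates.get_dense_templates().shape == (num_units, num_samples, num_channels)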
@@ -572,7 +580,7 @@ def load_extension(self, extension_name): The loaded instance of the extension """ if self.folder is not None and extension_name not in self._loaded_extensions: - if self.is_extension(extension_name): + if self.has_extension(extension_name): ext_class = self.get_extension_class(extension_name) ext = ext_class.load(self.folder, self) if extension_name not in self._loaded_extensions: @@ -588,7 +596,7 @@ def delete_extension(self, extension_name) -> None: extension_name: str The extension name. """ - assert self.is_extension(extension_name), f"The extension {extension_name} is not available" + assert self.has_extension(extension_name), f"The extension {extension_name} is not available" del self._loaded_extensions[extension_name] if self.folder is not None and (self.folder / extension_name).is_dir(): shutil.rmtree(self.folder / extension_name) @@ -610,7 +618,7 @@ def get_available_extension_names(self): """ extension_names_in_folder = [] for extension_class in self.extensions: - if self.is_extension(extension_class.extension_name): + if self.has_extension(extension_class.extension_name): extension_names_in_folder.append(extension_class.extension_name) return extension_names_in_folder diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 57a5ab0166..8b14930859 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -51,7 +51,7 @@ def export_report( unit_ids = sorting.unit_ids # load or compute spike_amplitudes - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): spike_amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit") elif force_computation: spike_amplitudes = compute_spike_amplitudes(we, peak_sign=peak_sign, outputs="by_unit", **job_kwargs) @@ -62,7 +62,7 @@ def export_report( ) # load or compute quality_metrics - if we.is_extension("quality_metrics"): + if we.has_extension("quality_metrics"): metrics = we.load_extension("quality_metrics").get_data() elif force_computation: metrics = compute_quality_metrics(we) @@ -73,7 +73,7 @@ def export_report( ) # load or compute correlograms - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): correlograms, bins = we.load_extension("correlograms").get_data() elif force_computation: correlograms, bins = compute_correlograms(we, window_ms=100.0, bin_ms=1.0) @@ -84,7 +84,7 @@ def export_report( ) # pre-compute unit locations if not done - if not we.is_extension("unit_locations"): + if not we.has_extension("unit_locations"): unit_locations = compute_unit_locations(we) output_folder = Path(output_folder).absolute() diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ecc5b316ec..607aa3e846 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -196,7 +196,7 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if waveform_extractor.is_extension("similarity"): + if waveform_extractor.has_extension("similarity"): tmc = waveform_extractor.load_extension("similarity") template_similarity = tmc.get_data() else: @@ -219,7 +219,7 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): sac = 
waveform_extractor.load_extension("spike_amplitudes") amplitudes = sac.get_data(outputs="concatenated") else: @@ -231,7 +231,7 @@ def export_to_phy( np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if waveform_extractor.is_extension("principal_components"): + if waveform_extractor.has_extension("principal_components"): pc = waveform_extractor.load_extension("principal_components") else: pc = compute_principal_components( @@ -264,7 +264,7 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if waveform_extractor.is_extension("quality_metrics"): + if waveform_extractor.has_extension("quality_metrics"): qm = waveform_extractor.load_extension("quality_metrics") qm_data = qm.get_data() for column_name in qm_data.columns: diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index f7b445cdb9..010b22975c 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -566,12 +566,12 @@ def get_unit_spike_train( start_frame = 0 if end_frame is None: end_frame = np.inf - times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] + spike_times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] if self._timestamps is not None: - frames = np.searchsorted(times, self.timestamps).astype("int64") + frames = np.searchsorted(spike_times, self.timestamps).astype("int64") else: - frames = np.round(times * self._sampling_frequency).astype("int64") + frames = np.round(spike_times * self._sampling_frequency).astype("int64") return frames[(frames >= start_frame) & (frames < end_frame)] diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index 71a19f30d3..253ca2e4ce 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -1,9 +1,11 @@ from pathlib import Path +import pickle import pytest import numpy as np import h5py +from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor if hasattr(pytest, "global_test_folder"): @@ -15,7 +17,7 @@ @pytest.mark.ros3_test @pytest.mark.streaming_extractors @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_recording_s3_nwb_ros3(): +def test_recording_s3_nwb_ros3(tmp_path): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -40,9 +42,18 @@ def test_recording_s3_nwb_ros3(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" + tmp_file = tmp_path / "test_ros3_recording.pkl" + with open(tmp_file, "wb") as f: + pickle.dump(rec, f) + + with open(tmp_file, "rb") as f: + reloaded_recording = pickle.load(f) + + check_recordings_equal(rec, reloaded_recording) + @pytest.mark.streaming_extractors -def test_recording_s3_nwb_fsspec(): +def test_recording_s3_nwb_fsspec(tmp_path): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -67,11 +78,20 @@ def 
test_recording_s3_nwb_fsspec(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" + tmp_file = tmp_path / "test_fsspec_recording.pkl" + with open(tmp_file, "wb") as f: + pickle.dump(rec, f) + + with open(tmp_file, "rb") as f: + reloaded_recording = pickle.load(f) + + check_recordings_equal(rec, reloaded_recording) + @pytest.mark.ros3_test @pytest.mark.streaming_extractors @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_sorting_s3_nwb_ros3(): +def test_sorting_s3_nwb_ros3(tmp_path): file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" # we provide the 'sampling_frequency' because the NWB file does not the electrical series sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3") @@ -90,9 +110,18 @@ def test_sorting_s3_nwb_ros3(): assert spike_train.dtype == "int64" assert np.all(spike_train >= 0) + tmp_file = tmp_path / "test_ros3_sorting.pkl" + with open(tmp_file, "wb") as f: + pickle.dump(sort, f) + + with open(tmp_file, "rb") as f: + reloaded_sorting = pickle.load(f) + + check_sortings_equal(reloaded_sorting, sort) + @pytest.mark.streaming_extractors -def test_sorting_s3_nwb_fsspec(): +def test_sorting_s3_nwb_fsspec(tmp_path): file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" # we provide the 'sampling_frequency' because the NWB file does not the electrical series sort = NwbSortingExtractor( @@ -113,6 +142,15 @@ def test_sorting_s3_nwb_fsspec(): assert spike_train.dtype == "int64" assert np.all(spike_train >= 0) + tmp_file = tmp_path / "test_fsspec_sorting.pkl" + with open(tmp_file, "wb") as f: + pickle.dump(sort, f) + + with open(tmp_file, "rb") as f: + reloaded_sorting = pickle.load(f) + + check_sortings_equal(reloaded_sorting, sort) + if __name__ == "__main__": test_recording_s3_nwb_ros3() diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index cf32e79b25..effd87007f 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -750,7 +750,7 @@ def compute_principal_components( >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ - if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): + if load_if_exists and waveform_extractor.has_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: pc = WaveformPrincipalComponent.create(waveform_extractor) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 858af3ee08..f68081dbda 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -60,7 +60,7 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() metrics_kwargs = metrics_kwargs or dict() params = dict( - metric_names=[str(name) for name in metric_names], + metric_names=[str(name) for name in np.unique(metric_names)], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 
b539bbd5d4..2bef246bc2 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -143,7 +143,7 @@ def _test_extension_folder(self, we, in_memory=False): # reload as an extension from we assert self.extension_class.extension_name in we.get_available_extension_names() - assert we.is_extension(self.extension_class.extension_name) + assert we.has_extension(self.extension_class.extension_name) ext = we.load_extension(self.extension_class.extension_name) assert isinstance(ext, self.extension_class) for ext_name in self.extension_data_names: diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 49591d9b89..f5e315b18f 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -135,7 +135,7 @@ def test_project_new(self): from sklearn.decomposition import IncrementalPCA we = self.we1 - if we.is_extension("principal_components"): + if we.has_extension("principal_components"): we.delete_extension("principal_components") we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp") diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 1d6947be79..172c666d62 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -153,6 +153,10 @@ def get_traces(self, start_frame, end_frame, channel_indices): filtered_traces = filtered_traces[left_margin:-right_margin, :] else: filtered_traces = filtered_traces[left_margin:, :] + + if np.issubdtype(self.dtype, np.integer): + filtered_traces = filtered_traces.round() + return filtered_traces.astype(self.dtype) diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index 79b5ba5bc3..325ce82074 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ b/src/spikeinterface/preprocessing/filter_gaussian.py @@ -74,6 +74,9 @@ def get_traces( filtered_fft = traces_fft * (gauss_high - gauss_low)[:, None] filtered_traces = np.real(np.fft.ifft(filtered_fft, axis=0)) + if np.issubdtype(dtype, np.integer): + filtered_traces = filtered_traces.round() + if right_margin > 0: return filtered_traces[left_margin:-right_margin, :].astype(dtype) else: diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 03afada380..f24aff6e79 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -20,6 +20,10 @@ def __init__(self, parent_recording_segment, gain, offset, dtype): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) scaled_traces = traces * self.gain[:, channel_indices] + self.offset[:, channel_indices] + + if np.issubdtype(self._dtype, np.integer): + scaled_traces = scaled_traces.round() + return scaled_traces.astype(self._dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 570ce48a5d..0734dad784 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -103,6 +103,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): 
traces_shift = traces_shift[left_margin:-right_margin, :] if self.tmp_dtype is not None: + if np.issubdtype(self.dtype, np.integer): + traces_shift = traces_shift.round() traces_shift = traces_shift.astype(self.dtype) return traces_shift diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 29617313cf..0edd52cce0 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -201,7 +201,7 @@ def compute_snrs( snrs : dict Computed signal to noise ratio for each unit. """ - if waveform_extractor.is_extension("noise_levels"): + if waveform_extractor.has_extension("noise_levels"): noise_levels = waveform_extractor.load_extension("noise_levels").get_data() else: if random_chunk_kwargs_dict is None: @@ -687,7 +687,7 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension(amplitude_extension): + if waveform_extractor.has_extension(amplitude_extension): sac = waveform_extractor.load_extension(amplitude_extension) amps = sac.get_data(outputs="concatenated") if amplitude_extension == "spike_amplitudes": @@ -803,7 +803,7 @@ def compute_amplitude_cutoffs( spike_amplitudes = None invert_amplitudes = False - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") if amp_calculator._params["peak_sign"] == "pos": @@ -881,7 +881,7 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign) spike_amplitudes = None - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") @@ -974,7 +974,7 @@ def compute_drift_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension("spike_locations"): + if waveform_extractor.has_extension("spike_locations"): locs_calculator = waveform_extractor.load_extension("spike_locations") spike_locations = locs_calculator.get_data(outputs="concatenated") spike_locations_by_unit = locs_calculator.get_data(outputs="by_unit") diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 53309db282..54b1027305 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -42,14 +42,14 @@ def _set_params( if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list - if self.waveform_extractor.is_extension("principal_components"): + if self.waveform_extractor.has_extension("principal_components"): # by default 'nearest_neightbor' is removed because too slow pc_metrics = _possible_pc_metric_names.copy() pc_metrics.remove("nn_isolation") pc_metrics.remove("nn_noise_overlap") metric_names += pc_metrics # if spike_locations are not available, drift is removed from the list - if not self.waveform_extractor.is_extension("spike_locations"): + if not self.waveform_extractor.has_extension("spike_locations"): if "drift" in 
metric_names: metric_names.remove("drift") @@ -61,7 +61,7 @@ def _set_params( qm_params_[k]["peak_sign"] = peak_sign params = dict( - metric_names=[str(name) for name in metric_names], + metric_names=[str(name) for name in np.unique(metric_names)], sparsity=sparsity, peak_sign=peak_sign, seed=seed, @@ -130,7 +130,7 @@ def _run(self, verbose, **job_kwargs): # metrics based on PCs pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self._params["skip_pc_metrics"]: - if not self.waveform_extractor.is_extension("principal_components"): + if not self.waveform_extractor.has_extension("principal_components"): raise ValueError("waveform_principal_component must be provied") pc_extension = self.waveform_extractor.load_extension("principal_components") pc_metrics = calculate_pc_metrics( @@ -216,7 +216,7 @@ def compute_quality_metrics( metrics: pandas.DataFrame Data frame with the computed metrics """ - if load_if_exists and waveform_extractor.is_extension(QualityMetricCalculator.extension_name): + if load_if_exists and waveform_extractor.has_extension(QualityMetricCalculator.extension_name): qmc = waveform_extractor.load_extension(QualityMetricCalculator.extension_name) else: qmc = QualityMetricCalculator(waveform_extractor) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index eb8317e4df..b601e5d6d8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -88,7 +88,7 @@ def test_metrics(self): we = self.we_long # avoid NaNs - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): we.delete_extension("spike_amplitudes") # without PC diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a16b642dd5..1d4f04a382 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,18 +21,25 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, + "waveforms": { + "max_spikes_per_unit": 200, + "overwrite": True, + "sparse": True, + "method": "energy", + "threshold": 0.25, + }, "filtering": {"freq_min": 150, "dtype": "float32"}, - "detection": {"peak_sign": "neg", "detect_threshold": 5}, + "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "localization": {}, - "clustering": {}, + "clustering": {"legacy": False}, "matching": {}, "apply_preprocessing": True, "shared_memory": True, "job_kwargs": {"n_jobs": -1}, } + handle_multi_segment = True + @classmethod def get_sorter_version(cls): return "2.0" @@ -64,6 +71,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = common_reference(recording_f) else: recording_f = recording + recording_f.annotate(is_filtered=True) # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") @@ -109,8 +117,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["tmp_folder"] = sorter_output_folder / "clustering" clustering_params.update({"noise_levels": 
noise_levels}) + if "legacy" in clustering_params: + legacy = clustering_params.pop("legacy") + else: + legacy = False + + if legacy: + clustering_method = "circus" + else: + clustering_method = "random_projections" + labels, peak_labels = find_cluster_from_peaks( - recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params + recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params ) ## We get the labels for our peaks @@ -138,13 +156,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): waveforms_folder = sorter_output_folder / "waveforms" we = extract_waveforms( - recording_f, sorting, waveforms_folder, mode=mode, **waveforms_params, return_scaled=False + recording_f, + sorting, + waveforms_folder, + return_scaled=False, + precompute_template=["median"], + mode=mode, + **waveforms_params, ) ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_params = params["matching"].copy() matching_params["waveform_extractor"] = we - matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index e256915fa6..6d53414c9f 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -9,12 +9,14 @@ NumpySorting, get_channel_distances, ) -from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer + from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel + import numpy as np import pickle @@ -50,6 +52,8 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "save_array": True, } + handle_multi_segment = True + @classmethod def get_sorter_version(cls): return "2.0" @@ -113,9 +117,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("We kept %d peaks for clustering" % len(peaks)) + ms_before = params["waveforms"]["ms_before"] + ms_after = params["waveforms"]["ms_after"] + # SVD for time compression few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel(recording, few_peaks, **job_kwargs) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) wfs = few_wfs[:, :, 0] tsvd = TruncatedSVD(params["svd"]["n_components"]) @@ -127,8 +136,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): with open(model_folder / "pca_model.pkl", "wb") as f: pickle.dump(tsvd, f) - ms_before = params["waveforms"]["ms_before"] - ms_after = params["waveforms"]["ms_after"] model_params = { "ms_before": ms_before, "ms_after": ms_after, @@ -317,39 +324,3 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorter_output_folder / "sorting") return sorting - - -def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): - """ - Helper function to extractor waveforms at max channel from a peak list - - - """ - n = rec.get_num_channels() - unit_ids = np.arange(n, dtype="int64") - sparsity_mask = np.eye(n, dtype="bool") - - 
spikes = np.zeros( - peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - ) - spikes["sample_index"] = peaks["sample_index"] - spikes["unit_index"] = peaks["channel_index"] - spikes["segment_index"] = peaks["segment_index"] - - nbefore = int(ms_before * rec.sampling_frequency / 1000.0) - nafter = int(ms_after * rec.sampling_frequency / 1000.0) - - all_wfs = extract_waveforms_to_single_buffer( - rec, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - sparsity_mask=sparsity_mask, - copy=True, - **job_kwargs, - ) - - return all_wfs diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 39f46475dc..238b16260c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -12,13 +12,23 @@ HAVE_HDBSCAN = False import random, string, os -from spikeinterface.core import get_global_tmp_folder, get_noise_levels, get_channel_distances +from spikeinterface.core import get_global_tmp_folder, get_channel_distances from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms -from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks +from spikeinterface.sortingcomponents.peak_selection import select_peaks +from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection +from sklearn.decomposition import TruncatedSVD +import pickle, json +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel class CircusClustering: @@ -27,174 +37,128 @@ class CircusClustering: """ _default_params = { - "peak_locations": None, - "peak_localization_kwargs": {"method": "center_of_mass"}, "hdbscan_kwargs": { - "min_cluster_size": 50, + "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1, - "cluster_selection_method": "leaf", + "cluster_selection_method": "eom", }, "cleaning_kwargs": {}, - "tmp_folder": None, + "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, - "n_pca": 10, - "max_spikes_per_unit": 200, - "ms_before": 1.5, - "ms_after": 2.5, - "cleaning_method": "dip", - "waveform_mode": "memmap", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, + "selection_method": "closest_to_centroid", + "n_svd": [5, 10], + "ms_before": 1, + "ms_after": 1, + "random_seed": 42, + "shared_memory": True, + "tmp_folder": None, + "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } - @classmethod - def _check_params(cls, recording, peaks, params): - d = params - params2 = params.copy() - - tmp_folder = params["tmp_folder"] - if params["waveform_mode"] == "memmap": - if tmp_folder is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) - else: - tmp_folder = Path(tmp_folder) - tmp_folder.mkdir() - params2["tmp_folder"] = tmp_folder - elif params["waveform_mode"] == 
"shared_memory": - assert tmp_folder is None, "tmp_folder must be None for shared_memory" - else: - raise ValueError("'waveform_mode' must be 'memmap' or 'shared_memory'") - - return params2 - @classmethod def main_function(cls, recording, peaks, params): - assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" + assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed" - params = cls._check_params(recording, peaks, params) - d = params - - if d["peak_locations"] is None: - from spikeinterface.sortingcomponents.peak_localization import localize_peaks - - peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"]) - else: - peak_locations = d["peak_locations"] + if "n_jobs" in params["job_kwargs"]: + if params["job_kwargs"]["n_jobs"] == -1: + params["job_kwargs"]["n_jobs"] = os.cpu_count() - tmp_folder = d["tmp_folder"] - if tmp_folder is not None: - tmp_folder.mkdir(exist_ok=True) + if "core_dist_n_jobs" in params["hdbscan_kwargs"]: + if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: + params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() - location_keys = ["x", "y"] - locations = np.stack([peak_locations[k] for k in location_keys], axis=1) - - chan_locs = recording.get_channel_locations() + d = params + verbose = d["job_kwargs"]["verbose"] peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - spikes = np.zeros(peaks.size, dtype=peak_dtype) - spikes["sample_index"] = peaks["sample_index"] - spikes["segment_index"] = peaks["segment_index"] - spikes["unit_index"] = peaks["channel_index"] + fs = recording.get_sampling_frequency() + ms_before = params["ms_before"] + ms_after = params["ms_after"] + nbefore = int(ms_before * fs / 1000.0) + nafter = int(ms_after * fs / 1000.0) + num_samples = nbefore + nafter num_chans = recording.get_num_channels() - sparsity_mask = np.zeros((peaks.size, num_chans), dtype="bool") - - unit_inds = range(num_chans) - chan_distances = get_channel_distances(recording) + np.random.seed(d["random_seed"]) - for main_chan in unit_inds: - (closest_chans,) = np.nonzero(chan_distances[main_chan, :] <= params["radius_um"]) - sparsity_mask[main_chan, closest_chans] = True - - if params["waveform_mode"] == "shared_memory": - wf_folder = None + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name else: - assert params["tmp_folder"] is not None, "tmp_folder must be supplied" - wf_folder = params["tmp_folder"] / "sparse_snippets" - wf_folder.mkdir() + tmp_folder = Path(params["tmp_folder"]).absolute() - fs = recording.get_sampling_frequency() - nbefore = int(params["ms_before"] * fs / 1000.0) - nafter = int(params["ms_after"] * fs / 1000.0) - num_samples = nbefore + nafter + tmp_folder.mkdir(parents=True, exist_ok=True) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_inds, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=sparsity_mask, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], + # SVD for time compression + few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"] ) - n_loc = len(location_keys) - import sklearn.decomposition, hdbscan + wfs = 
few_wfs[:, :, 0] + tsvd = TruncatedSVD(params["n_svd"][0]) + tsvd.fit(wfs) - noise_levels = get_noise_levels(recording, return_scaled=False) - - nb_clusters = 0 - peak_labels = np.zeros(len(spikes), dtype=np.int32) + model_folder = tmp_folder / "tsvd_model" - noise = get_random_data_chunks( - recording, - return_scaled=False, - num_chunks_per_segment=params["max_spikes_per_unit"], - chunk_size=nbefore + nafter, - concatenated=False, - seed=None, - ) - noise = np.stack(noise, axis=0) + model_folder.mkdir(exist_ok=True) + with open(model_folder / "pca_model.pkl", "wb") as f: + pickle.dump(tsvd, f) - for main_chan, waveforms in wfs_arrays.items(): - idx = np.where(spikes["unit_index"] == main_chan)[0] - (channels,) = np.nonzero(sparsity_mask[main_chan]) - sub_noise = noise[:, :, channels] + model_params = { + "ms_before": ms_before, + "ms_after": ms_after, + "sampling_frequency": float(fs), + } - if len(waveforms) > 0: - sub_waveforms = waveforms + with open(model_folder / "params.json", "w") as f: + json.dump(model_params, f) - wfs = np.swapaxes(sub_waveforms, 1, 2).reshape(len(sub_waveforms), -1) - noise_wfs = np.swapaxes(sub_noise, 1, 2).reshape(len(sub_noise), -1) + # features + features_folder = model_folder / "features" + node0 = PeakRetriever(recording, peaks) - n_pca = min(d["n_pca"], len(wfs)) - pca = sklearn.decomposition.PCA(n_pca) - - hdbscan_data = np.vstack((wfs, noise_wfs)) - - pca.fit(wfs) - hdbscan_data_pca = pca.transform(hdbscan_data) - clustering = hdbscan.hdbscan(hdbscan_data_pca, **d["hdbscan_kwargs"]) - - noise_labels = clustering[0][len(wfs) :] - valid_labels = clustering[0][: len(wfs)] - - shared_indices = np.intersect1d(np.unique(noise_labels), np.unique(valid_labels)) - for l in shared_indices: - idx_noise = noise_labels == l - idx_valid = valid_labels == l - if np.sum(idx_noise) > np.sum(idx_valid): - valid_labels[idx_valid] = -1 + radius_um = params["radius_um"] + node1 = ExtractSparseWaveforms( + recording, + parents=[node0], + return_output=False, + ms_before=ms_before, + ms_after=ms_after, + radius_um=radius_um, + ) - if np.unique(valid_labels).min() == -1: - valid_labels += 1 + node2 = TemporalPCAProjection( + recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder + ) - for l in np.unique(valid_labels): - idx_valid = valid_labels == l - if np.sum(idx_valid) < d["hdbscan_kwargs"]["min_cluster_size"]: - valid_labels[idx_valid] = -1 + pipeline_nodes = [node0, node1, node2] - peak_labels[idx] = valid_labels + nb_clusters + all_pc_data = run_node_pipeline( + recording, + pipeline_nodes, + params["job_kwargs"], + job_name="extracting features", + ) - labels = np.unique(valid_labels) - labels = labels[labels >= 0] - nb_clusters += len(labels) + peak_labels = -1 * np.ones(len(peaks), dtype=int) + nb_clusters = 0 + for c in np.unique(peaks["channel_index"]): + mask = peaks["channel_index"] == c + tsvd = TruncatedSVD(params["n_svd"][1]) + sub_data = all_pc_data[mask] + hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) + try: + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) + local_labels = clustering[0] + except Exception: + local_labels = -1 * np.ones(len(hdbscan_data)) + valid_clusters = local_labels > -1 + if np.sum(valid_clusters) > 0: + local_labels[valid_clusters] += nb_clusters + peak_labels[mask] = local_labels + nb_clusters += len(np.unique(local_labels[valid_clusters])) labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -202,11 +166,22 @@ def main_function(cls, recording, 
peaks, params): best_spikes = {} nb_spikes = 0 + import sklearn + all_indices = np.arange(0, peak_labels.size) + max_spikes = params["waveforms"]["max_spikes_per_unit"] + selection_method = params["selection_method"] + for unit_ind in labels: mask = peak_labels == unit_ind - best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[: params["max_spikes_per_unit"]] + if selection_method == "closest_to_centroid": + data = all_pc_data[mask].reshape(np.sum(mask), -1) + centroid = np.median(data, axis=0) + distances = sklearn.metrics.pairwise_distances(centroid[np.newaxis, :], data)[0] + best_spikes[unit_ind] = all_indices[mask][np.argsort(distances)[:max_spikes]] + elif selection_method == "random": + best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes] nb_spikes += best_spikes[unit_ind].size spikes = np.zeros(nb_spikes, dtype=peak_dtype) @@ -222,72 +197,58 @@ def main_function(cls, recording, peaks, params): spikes["segment_index"] = peaks[mask]["segment_index"] spikes["unit_index"] = peak_labels[mask] - if params["waveform_mode"] == "shared_memory": - wf_folder = None + if verbose: + print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + + if params["shared_memory"]: + waveform_folder = None + mode = "memory" else: - assert params["tmp_folder"] is not None, "tmp_folder must be supplied" - wf_folder = params["tmp_folder"] / "dense_snippets" - wf_folder.mkdir() - - cleaning_method = params["cleaning_method"] - - print(f"We found {len(labels)} raw clusters, starting to clean with {cleaning_method}...") - - if cleaning_method == "cosine": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates( - wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] - ) - - elif cleaning_method == "dip": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels) - - elif cleaning_method == "matching": - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) - - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - we = extract_waveforms( - recording, - sorting, - tmp_folder, - overwrite=True, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - ) - labels, peak_labels = remove_duplicates_via_matching(we, peak_labels, job_kwargs=params["job_kwargs"]) + waveform_folder = tmp_folder / "waveforms" + mode = "folder" + sorting = sorting.save(folder=sorting_folder) + + we = extract_waveforms( + recording, + sorting, + waveform_folder, + return_scaled=False, + precompute_template=["median"], + mode=mode, + **params["job_kwargs"], + **params["waveforms"], + ) + + cleaning_matching_params 
= params["job_kwargs"].copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in cleaning_matching_params: + cleaning_matching_params.pop(value) + cleaning_matching_params["chunk_duration"] = "100ms" + cleaning_matching_params["n_jobs"] = 1 + cleaning_matching_params["verbose"] = False + cleaning_matching_params["progress_bar"] = False + + cleaning_params = params["cleaning_kwargs"].copy() + cleaning_params["tmp_folder"] = tmp_folder + + labels, peak_labels = remove_duplicates_via_matching( + we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + ) + + del we, sorting + + if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) + else: + if not params["shared_memory"]: + shutil.rmtree(tmp_folder / "waveforms") + shutil.rmtree(tmp_folder / "sorting") - print(f"We kept {len(labels)} non-duplicated clusters...") + if verbose: + print("We kept %d non-duplicated clusters..." % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b4938717f8..629b0b13ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -534,7 +534,6 @@ def remove_duplicates( def remove_duplicates_via_matching( waveform_extractor, - noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, @@ -542,7 +541,6 @@ def remove_duplicates_via_matching( method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates - from spikeinterface import get_noise_levels from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms @@ -595,14 +593,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update( - { - "waveform_extractor": waveform_extractor, - "noise_levels": noise_levels, - "amplitudes": [0.95, 1.05], - "omp_min_sps": 0.05, - } - ) + local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps": 1e-3}) spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 72acd49f4f..dcb84cb6ff 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -12,7 +12,7 @@ HAVE_HDBSCAN = False import random, string, os -from spikeinterface.core import get_global_tmp_folder, get_noise_levels, get_channel_distances, get_random_data_chunks +from spikeinterface.core import get_global_tmp_folder, get_channel_distances, get_random_data_chunks from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip @@ -48,7 +48,6 @@ class RandomProjectionClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, "shared_memory": True, "tmp_folder": None, @@ -77,12 +76,6 @@ def main_function(cls, recording, peaks, params): nafter = 
int(params["ms_after"] * fs / 1000.0) num_samples = nbefore + nafter num_chans = recording.get_num_channels() - - if d["noise_levels"] is None: - noise_levels = get_noise_levels(recording, return_scaled=False) - else: - noise_levels = d["noise_levels"] - np.random.seed(d["random_seed"]) if params["tmp_folder"] is None: @@ -113,32 +106,12 @@ def main_function(cls, recording, peaks, params): nafter = int(params["ms_after"] * fs / 1000) nsamples = nbefore + nafter - import scipy - - x = np.random.randn(100, nsamples, num_chans).astype(np.float32) - x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) - - ptps = np.ptp(x, axis=1) - a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) - ydata = np.cumsum(a) / a.sum() - xdata = b[1:] - - from scipy.optimize import curve_fit - - def sigmoid(x, L, x0, k, b): - y = L / (1 + np.exp(-k * (x - x0))) + b - return y - - p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess - popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) - node3 = RandomProjectionsFeature( recording, parents=[node0, node2], return_output=True, projections=projections, radius_um=params["radius_um"], - sigmoid=None, sparse=True, ) @@ -219,10 +192,11 @@ def sigmoid(x, L, x0, k, b): recording, sorting, waveform_folder, - **params["job_kwargs"], - **params["waveforms"], return_scaled=False, mode=mode, + precompute_template=["median"], + **params["job_kwargs"], + **params["waveforms"], ) cleaning_matching_params = params["job_kwargs"].copy() @@ -238,7 +212,7 @@ def sigmoid(x, L, x0, k, b): cleaning_params["tmp_folder"] = tmp_folder labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) del we, sorting diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ea36b75847..cfdca6f612 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -495,24 +495,22 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): amplitude: tuple (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float - Stopping criteria of the OMP algorithm, in percentage of the norm - noise_levels: array - The noise levels, for every channels. If None, they will be automatically - computed - random_chunk_kwargs: dict - Parameters for computing noise levels, if not provided (sub optimal) + Stopping criterion of the OMP algorithm, expressed as a relative error sparse_kwargs: dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. 
+ rank: int, default: 5 + Number of components used internally by the SVD + vicinity: int + Size of the area surrounding a spike to perform modification (expressed in terms + of template temporal width) ----- """ _default_params = { - "amplitudes": [0.6, 2], - "omp_min_sps": 0.1, + "amplitudes": [0.6, 1.4], + "omp_min_sps": 5e-5, "waveform_extractor": None, - "random_chunk_kwargs": {}, - "noise_levels": None, "rank": 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], @@ -612,10 +610,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() d["vicinity"] *= d["num_samples"] - if d["noise_levels"] is None: - print("CircusOMPPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - if "templates" not in d: d = cls._prepare_templates(d) else: @@ -638,10 +632,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - omp_min_sps = d["omp_min_sps"] - # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) - d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) - + d["stop_criteria"] = d["omp_min_sps"] return d @classmethod @@ -675,7 +666,7 @@ def main_function(cls, traces, d): neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"][:, np.newaxis] + stop_criteria = d["stop_criteria"] vicinity = d["vicinity"] rank = d["rank"] @@ -687,13 +678,13 @@ def main_function(cls, traces, d): # Filter using overlap-and-add convolution if len(ignored_ids) > 0: - mask = ~np.isin(np.arange(num_templates), ignored_ids) - spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] + not_ignored = ~np.isin(np.arange(num_templates), ignored_ids) + spatially_filtered_data = np.matmul(d["spatial"][:, not_ignored, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"][:, not_ignored, :] objective_by_rank = scipy.signal.oaconvolve( - scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid" + scaled_filtered_data, d["temporal"][:, not_ignored, :], axes=2, mode="valid" ) - scalar_products[mask] += np.sum(objective_by_rank, axis=0) + scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0) scalar_products[ignored_ids] = -np.inf else: spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) @@ -704,7 +695,6 @@ def main_function(cls, traces, d): num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) M = np.zeros((num_templates, num_templates), dtype=np.float32) @@ -717,13 +707,17 @@ def main_function(cls, traces, d): neighbors = {} cached_overlaps = {} - is_valid = scalar_products > stop_criteria all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) + else: + new_error = np.linalg.norm(scalar_products) + delta_error = np.inf - while np.any(is_valid): - 
best_amplitude_ind = scalar_products[is_valid].argmax() - best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + while delta_error > stop_criteria: + best_amplitude_ind = scalar_products.argmax() + best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) if num_selection > 0: delta_t = selection[1] - peak_index @@ -818,7 +812,12 @@ def main_function(cls, traces, d): to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add - is_valid = scalar_products > stop_criteria + previous_error = new_error + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) + else: + new_error = np.linalg.norm(scalar_products) + delta_error = np.abs(new_error / previous_error - 1) is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index cd9226d5e8..328e3b715d 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -1,6 +1,7 @@ import numpy as np from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever +from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer def make_multi_method_doc(methods, ident=" "): @@ -18,23 +19,53 @@ def make_multi_method_doc(methods, ident=" "): return doc -def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): - # TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few - # spikes, which means that all traces need to be accesses! 
Please find a better way - nb_peaks = min(len(peaks), nb_peaks) - idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) - peak_retriever = PeakRetriever(recording, peaks[idx]) - - sparse_waveforms = ExtractSparseWaveforms( - recording, - parents=[peak_retriever], - ms_before=ms_before, - ms_after=ms_after, - return_output=True, - radius_um=5, +def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): + """ + Helper function to extract waveforms at the max channel from a peak list + + + """ + n = rec.get_num_channels() + unit_ids = np.arange(n, dtype="int64") + sparsity_mask = np.eye(n, dtype="bool") + + spikes = np.zeros( + peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + ) + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] + + nbefore = int(ms_before * rec.sampling_frequency / 1000.0) + nafter = int(ms_after * rec.sampling_frequency / 1000.0) + + all_wfs = extract_waveforms_to_single_buffer( + rec, + spikes, + unit_ids, + nbefore, + nafter, + mode="shared_memory", + return_scaled=False, + sparsity_mask=sparsity_mask, + copy=True, + **job_kwargs, ) - nbefore = sparse_waveforms.nbefore - waveforms = run_node_pipeline(recording, [peak_retriever, sparse_waveforms], job_kwargs=job_kwargs) + return all_wfs + + +def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): + if peaks.size > nb_peaks: + idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) + some_peaks = peaks[idx] + else: + some_peaks = peaks + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + + waveforms = extract_waveform_at_max_channel( + recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) prototype = np.median(waveforms[:, :, 0] / (waveforms[:, nbefore, 0][:, np.newaxis]), axis=0) return prototype diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index a5d3cb2429..6ff837065b 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -107,7 +107,7 @@ def check_extensions(waveform_extractor, extensions): error_msg = "" raise_error = False for extension in extensions: - if not waveform_extractor.is_extension(extension): + if not waveform_extractor.has_extension(extension): raise_error = True error_msg += ( f"The {extension} waveform extension is required for this widget. 
" diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 35fde07326..aa280ad658 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -80,13 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 3 - if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): + if we.has_extension("correlograms") or we.has_extension("spike_amplitudes"): ncols += 1 - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - if we.is_extension("unit_locations"): + if we.has_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( @@ -129,7 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) AutoCorrelogramsWidget( we, @@ -142,7 +142,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6])
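A note on the CircusOMPSVDPeeler change above: the former stopping threshold derived from noise levels and template norms is replaced by a rule on the relative change of the objective norm between OMP iterations, delta_error = |new_error / previous_error - 1|, compared against omp_min_sps (new default 5e-5). The short Python sketch below illustrates only that stopping logic on a made-up objective array; it is not the actual peeler, which subtracts scaled, overlapping templates from scalar_products rather than zeroing single entries.

import numpy as np

# Illustrative sketch of the relative-error stopping rule (made-up data, not the real peeler).
rng = np.random.default_rng(42)
scalar_products = rng.uniform(0.0, 1.0, size=(100, 10_000)).astype(np.float32)  # hypothetical objective
scalar_products[rng.integers(0, 100, 20), rng.integers(0, 10_000, 20)] = 50.0  # a few dominant matches
omp_min_sps = 5e-5  # new default stopping criterion in the diff

new_error = np.linalg.norm(scalar_products)  # Frobenius norm of the objective
delta_error = np.inf
n_accepted = 0
while delta_error > omp_min_sps:
    best_cluster_ind, peak_index = np.unravel_index(scalar_products.argmax(), scalar_products.shape)
    # Stand-in for accepting this (template, time) pair and removing its contribution.
    scalar_products[best_cluster_ind, peak_index] = 0.0
    n_accepted += 1
    previous_error = new_error
    new_error = np.linalg.norm(scalar_products)
    delta_error = np.abs(new_error / previous_error - 1)

print(f"accepted {n_accepted} candidate spikes before the objective norm stopped changing")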