From 30d1ecce4249a3e645ca09be39799277186e11c6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 11:47:37 +0200 Subject: [PATCH 01/33] Allow to postprocess on read-only waveform folders --- src/spikeinterface/core/waveform_extractor.py | 55 ++++++++++--------- .../tests/common_extension_tests.py | 23 +++++++- .../postprocessing/unit_localization.py | 4 +- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 877c9fb00c..e404e74be4 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -4,6 +4,7 @@ import shutil from typing import Iterable, Literal, Optional import json +import os import numpy as np from copy import deepcopy @@ -87,6 +88,7 @@ def __init__( self._template_cache = {} self._params = {} self._loaded_extensions = dict() + self._is_read_only = False self.sparsity = sparsity self.folder = folder @@ -103,6 +105,8 @@ def __init__( if (self.folder / "params.json").is_file(): with open(str(self.folder / "params.json"), "r") as f: self._params = json.load(f) + if not os.access(self.folder, os.W_OK): + self._is_read_only = True else: # this is in case of in-memory self.format = "memory" @@ -399,6 +403,9 @@ def return_scaled(self) -> bool: def dtype(self): return self._params["dtype"] + def is_read_only(self) -> bool: + return self._is_read_only + def has_recording(self) -> bool: return self._recording is not None @@ -514,18 +521,8 @@ def is_extension(self, extension_name) -> bool: exists: bool Whether the extension exists or not """ - if self.folder is None: - return extension_name in self._loaded_extensions - else: - if self.format == "binary": - return (self.folder / extension_name).is_dir() and ( - self.folder / extension_name / "params.json" - ).is_file() - elif self.format == "zarr": - return ( - extension_name in self._waveforms_root.keys() - and "params" in self._waveforms_root[extension_name].attrs.keys() - ) + # Extensions are always loaded in memory + return extension_name in self._loaded_extensions def load_extension(self, extension_name): """ @@ -1735,20 +1732,28 @@ def __init__(self, waveform_extractor): self.waveform_extractor = waveform_extractor if self.waveform_extractor.folder is not None: - self.folder = self.waveform_extractor.folder - self.format = self.waveform_extractor.format - if self.format == "binary": - self.extension_folder = self.folder / self.extension_name - if not self.extension_folder.is_dir(): - self.extension_folder.mkdir() - else: - import zarr - - zarr_root = zarr.open(self.folder, mode="r+") - if self.extension_name not in zarr_root.keys(): - self.extension_group = zarr_root.create_group(self.extension_name) + if not self.waveform_extractor.is_read_only(): + self.folder = self.waveform_extractor.folder + self.format = self.waveform_extractor.format + if self.format == "binary": + self.extension_folder = self.folder / self.extension_name + if not self.extension_folder.is_dir(): + self.extension_folder.mkdir() else: - self.extension_group = zarr_root[self.extension_name] + import zarr + + zarr_root = zarr.open(self.folder, mode="r+") + if self.extension_name not in zarr_root.keys(): + self.extension_group = zarr_root.create_group(self.extension_name) + else: + self.extension_group = zarr_root[self.extension_name] + else: + warn( + "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None else: self.format = "memory" self.extension_folder = None diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..f44d58470c 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -4,7 +4,7 @@ import shutil from pathlib import Path -from spikeinterface import extract_waveforms, load_extractor, compute_sparsity +from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity from spikeinterface.extractors import toy_example if hasattr(pytest, "global_test_folder"): @@ -76,6 +76,15 @@ def setUp(self): overwrite=True, ) self.we2 = we2 + + # make we read-only + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) + # change permissions (R+X) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) + self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( recording, @@ -97,6 +106,11 @@ def setUp(self): folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True ) + def tearDown(self): + # allow pytest to delete RO folder + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) + def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: extension_function_kwargs_list = [dict()] @@ -177,3 +191,10 @@ def test_extension(self): assert ext_data_mem.equals(ext_data_zarr) else: print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") + + # read-only - Extension is memory only + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 740fdd234b..d2739f69dd 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): + import sklearn.metrics + x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights( template_positions[:, 0] = all_x.flatten() template_positions[:, 1] = all_y.flatten() - import sklearn - # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) nearest_template_mask = dist < radius_um From b8ee13c208cf928573595d941803b11e38278eb0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 15:02:13 +0200 Subject: [PATCH 02/33] Restore extension loading --- src/spikeinterface/core/waveform_extractor.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index e404e74be4..6083732c11 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -521,8 +521,22 @@ def is_extension(self, extension_name) -> bool: exists: bool Whether the extension exists or not """ - # Extensions are always loaded in memory - return extension_name in self._loaded_extensions + if self.folder is None: + return extension_name in self._loaded_extensions + else: + # Extensions already loaded in memory + if extension_name in self._loaded_extensions: + return True + else: + if self.format == "binary": + return (self.folder / extension_name).is_dir() and ( + self.folder / extension_name / "params.json" + ).is_file() + elif self.format == "zarr": + return ( + extension_name in self._waveforms_root.keys() + and "params" in self._waveforms_root[extension_name].attrs.keys() + ) def load_extension(self, extension_name): """ From def525c20a463b625c2f014fd5a84be4f79a00ef Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 Sep 2023 15:38:06 +0200 Subject: [PATCH 03/33] handle re-loading correctly --- src/spikeinterface/core/waveform_extractor.py | 140 ++++++++++-------- 1 file changed, 77 insertions(+), 63 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6083732c11..39d115e22c 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1746,28 +1746,39 @@ def __init__(self, waveform_extractor): self.waveform_extractor = waveform_extractor if self.waveform_extractor.folder is not None: - if not self.waveform_extractor.is_read_only(): - self.folder = self.waveform_extractor.folder - self.format = self.waveform_extractor.format - if self.format == "binary": - self.extension_folder = self.folder / self.extension_name - if not self.extension_folder.is_dir(): + self.folder = self.waveform_extractor.folder + self.format = self.waveform_extractor.format + if self.format == "binary": + self.extension_folder = self.folder / self.extension_name + if not self.extension_folder.is_dir(): + if not self.waveform_extractor.is_read_only(): self.extension_folder.mkdir() - else: - import zarr + else: + raise Exception( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + import zarr - zarr_root = zarr.open(self.folder, mode="r+") - if self.extension_name not in zarr_root.keys(): + mode = "r+" if not self.waveform_extractor.is_read_only() else "r" + zarr_root = zarr.open(self.folder, mode=mode) + if self.extension_name not in zarr_root.keys(): + if not self.waveform_extractor.is_read_only(): self.extension_group = zarr_root.create_group(self.extension_name) else: - self.extension_group = zarr_root[self.extension_name] - else: - warn( - "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." - ) - self.format = "memory" - self.extension_folder = None - self.folder = None + raise Exception( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_group = zarr_root[self.extension_name] else: self.format = "memory" self.extension_folder = None @@ -1882,53 +1893,56 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): - if self.format == "binary": - import pandas as pd - - for ext_data_name, ext_data in self._extension_data.items(): - if isinstance(ext_data, dict): - with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: - json.dump(ext_data, f) - elif isinstance(ext_data, np.ndarray): - np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) - else: - try: - with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: - pickle.dump(ext_data, f) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - elif self.format == "zarr": - from .zarrrecordingextractor import get_default_zarr_compressor - import pandas as pd - import numcodecs - - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() - for ext_data_name, ext_data in self._extension_data.items(): - if ext_data_name in self.extension_group: - del self.extension_group[ext_data_name] - if isinstance(ext_data, dict): - self.extension_group.create_dataset( - name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() - ) - self.extension_group[ext_data_name].attrs["dict"] = True - elif isinstance(ext_data, np.ndarray): - self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a" - ) - self.extension_group[ext_data_name].attrs["dataframe"] = True - else: - try: + if not self.waveform_extractor.is_read_only(): + if self.format == "binary": + import pandas as pd + + for ext_data_name, ext_data in self._extension_data.items(): + if isinstance(ext_data, dict): + with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: + json.dump(ext_data, f) + elif isinstance(ext_data, np.ndarray): + np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) + else: + try: + with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: + pickle.dump(ext_data, f) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") + elif self.format == "zarr": + from .zarrrecordingextractor import get_default_zarr_compressor + import pandas as pd + import numcodecs + + compressor = kwargs.get("compressor", None) + if compressor is None: + compressor = get_default_zarr_compressor() + for ext_data_name, ext_data in self._extension_data.items(): + if ext_data_name in self.extension_group: + del self.extension_group[ext_data_name] + if isinstance(ext_data, dict): self.extension_group.create_dataset( - name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() + name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() + ) + self.extension_group[ext_data_name].attrs["dict"] = True + elif isinstance(ext_data, np.ndarray): + self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_xarray().to_zarr( + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") + self.extension_group[ext_data_name].attrs["dataframe"] = True + else: + try: + self.extension_group.create_dataset( + name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() + ) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") def reset(self): """ From dfa67e681afec0ef741b16e61417c70123c97ef5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 6 Sep 2023 12:08:01 +0200 Subject: [PATCH 04/33] warn instead of raise --- src/spikeinterface/core/waveform_extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 39d115e22c..431440c846 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1754,7 +1754,7 @@ def __init__(self, waveform_extractor): if not self.waveform_extractor.is_read_only(): self.extension_folder.mkdir() else: - raise Exception( + warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." ) @@ -1770,7 +1770,7 @@ def __init__(self, waveform_extractor): if not self.waveform_extractor.is_read_only(): self.extension_group = zarr_root.create_group(self.extension_name) else: - raise Exception( + warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." ) From f60024b0c52e17edfebe02b8170f9ac3d78b053f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 6 Sep 2023 12:24:41 +0200 Subject: [PATCH 05/33] Do not overwrite similarity in Phy if available --- src/spikeinterface/exporters/to_phy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 5615402fdb..c92861a8bf 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -178,7 +178,11 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") + if waveform_extractor.is_extension("similarity"): + tmc = waveform_extractor.load_extension("similarity") + template_similarity = tmc.get_data() + else: + template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") np.save(str(output_folder / "templates.npy"), templates) np.save(str(output_folder / "template_ind.npy"), templates_ind) From fe178c67ac9428477ca146dd6ac453bf1cccfc78 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 11 Sep 2023 10:37:00 +0200 Subject: [PATCH 06/33] Apply suggestions and avoid using chmod on windows --- src/spikeinterface/core/waveform_extractor.py | 111 +++++++++--------- .../tests/common_extension_tests.py | 28 +++-- 2 files changed, 73 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 431440c846..3647e915bf 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1751,9 +1751,7 @@ def __init__(self, waveform_extractor): if self.format == "binary": self.extension_folder = self.folder / self.extension_name if not self.extension_folder.is_dir(): - if not self.waveform_extractor.is_read_only(): - self.extension_folder.mkdir() - else: + if self.waveform_extractor.is_read_only(): warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." @@ -1761,15 +1759,16 @@ def __init__(self, waveform_extractor): self.format = "memory" self.extension_folder = None self.folder = None + else: + self.extension_folder.mkdir() + else: import zarr mode = "r+" if not self.waveform_extractor.is_read_only() else "r" zarr_root = zarr.open(self.folder, mode=mode) if self.extension_name not in zarr_root.keys(): - if not self.waveform_extractor.is_read_only(): - self.extension_group = zarr_root.create_group(self.extension_name) - else: + if self.waveform_extractor.is_read_only(): warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." @@ -1777,6 +1776,8 @@ def __init__(self, waveform_extractor): self.format = "memory" self.extension_folder = None self.folder = None + else: + self.extension_group = zarr_root.create_group(self.extension_name) else: self.extension_group = zarr_root[self.extension_name] else: @@ -1893,56 +1894,58 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): - if not self.waveform_extractor.is_read_only(): - if self.format == "binary": - import pandas as pd - - for ext_data_name, ext_data in self._extension_data.items(): - if isinstance(ext_data, dict): - with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: - json.dump(ext_data, f) - elif isinstance(ext_data, np.ndarray): - np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) - else: - try: - with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: - pickle.dump(ext_data, f) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - elif self.format == "zarr": - from .zarrrecordingextractor import get_default_zarr_compressor - import pandas as pd - import numcodecs - - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() - for ext_data_name, ext_data in self._extension_data.items(): - if ext_data_name in self.extension_group: - del self.extension_group[ext_data_name] - if isinstance(ext_data, dict): + # Only save if not read only + if self.waveform_extractor.is_read_only(): + return + if self.format == "binary": + import pandas as pd + + for ext_data_name, ext_data in self._extension_data.items(): + if isinstance(ext_data, dict): + with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: + json.dump(ext_data, f) + elif isinstance(ext_data, np.ndarray): + np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) + else: + try: + with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: + pickle.dump(ext_data, f) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") + elif self.format == "zarr": + from .zarrrecordingextractor import get_default_zarr_compressor + import pandas as pd + import numcodecs + + compressor = kwargs.get("compressor", None) + if compressor is None: + compressor = get_default_zarr_compressor() + for ext_data_name, ext_data in self._extension_data.items(): + if ext_data_name in self.extension_group: + del self.extension_group[ext_data_name] + if isinstance(ext_data, dict): + self.extension_group.create_dataset( + name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() + ) + self.extension_group[ext_data_name].attrs["dict"] = True + elif isinstance(ext_data, np.ndarray): + self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_xarray().to_zarr( + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", + ) + self.extension_group[ext_data_name].attrs["dataframe"] = True + else: + try: self.extension_group.create_dataset( - name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() - ) - self.extension_group[ext_data_name].attrs["dict"] = True - elif isinstance(ext_data, np.ndarray): - self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=self.extension_group.store, - group=f"{self.extension_group.name}/{ext_data_name}", - mode="a", + name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() ) - self.extension_group[ext_data_name].attrs["dataframe"] = True - else: - try: - self.extension_group.create_dataset( - name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() - ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") + except: + raise Exception(f"Could not save {ext_data_name} as extension data") def reset(self): """ diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index f44d58470c..f7272ddefe 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd import shutil +import platform from pathlib import Path from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity @@ -78,12 +79,13 @@ def setUp(self): self.we2 = we2 # make we read-only - we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" - if not we_ro_folder.is_dir(): - shutil.copytree(we2.folder, we_ro_folder) + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) # change permissions (R+X) - we_ro_folder.chmod(0o555) - self.we_ro = load_waveforms(we_ro_folder) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( @@ -108,8 +110,9 @@ def setUp(self): def tearDown(self): # allow pytest to delete RO folder - we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" - we_ro_folder.chmod(0o777) + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: @@ -193,8 +196,9 @@ def test_extension(self): print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") # read-only - Extension is memory only - _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) - assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() - ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) - assert ext_ro.format == "memory" - assert ext_ro.extension_folder is None + if platform.system() != "Windows": + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None From f013828bf4cc1363518fdc0e7940cfac07555149 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 Sep 2023 17:20:49 +0200 Subject: [PATCH 07/33] Allow MergeUnitsSorting to handle tuples --- src/spikeinterface/curation/mergeunitssorting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..6baa68b0da 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting): ---------- parent_sorting: Recording The sorting object - units_to_merge: list of lists + units_to_merge: list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids: None or list @@ -24,6 +24,7 @@ class MergeUnitsSorting(BaseSorting): Default: 'keep' delta_time_ms: float or None Number of ms to consider for duplicated spikes. None won't check for duplications + Returns ------- sorting: Sorting @@ -33,7 +34,7 @@ class MergeUnitsSorting(BaseSorting): def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): self._parent_sorting = parent_sorting - if not isinstance(units_to_merge[0], list): + if not isinstance(units_to_merge[0], (list, tuple)): # keep backward compatibility : the previous behavior was only one merge units_to_merge = [units_to_merge] From b78257cf7217764de00be0eac72b56deb499e1bd Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 11:51:55 +0200 Subject: [PATCH 08/33] Speed up searchsorted calls --- src/spikeinterface/core/basesorting.py | 3 +-- src/spikeinterface/core/generate.py | 7 +++---- src/spikeinterface/core/node_pipeline.py | 12 ++++-------- src/spikeinterface/core/numpyextractors.py | 3 +-- src/spikeinterface/core/segmentutils.py | 6 ++---- src/spikeinterface/core/waveform_tools.py | 15 +++++---------- .../curation/remove_duplicated_spikes.py | 3 +-- .../postprocessing/amplitude_scalings.py | 3 +-- .../postprocessing/principal_component.py | 3 +-- .../postprocessing/spike_amplitudes.py | 4 +--- .../postprocessing/spike_locations.py | 3 +-- src/spikeinterface/qualitymetrics/misc_metrics.py | 6 ++---- .../sortingcomponents/motion_interpolation.py | 3 +-- 13 files changed, 24 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..eb141abde4 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if not concatenated: spikes_ = [] for segment_index in range(self.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes_.append(spikes[s0:s1]) spikes = spikes_ diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..56a2bb4f48 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1109,8 +1109,7 @@ def __init__( num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="left") - end = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="right") + start, end = np.searchsorted(self.spike_vector["segment_index"], [segment_index, segment_index+1], side="left") spikes = self.spike_vector[start:end] amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None @@ -1208,8 +1207,8 @@ def get_traces( else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") - end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") + start, end = np.searchsorted(self.spike_vector["sample_index"], [start_frame - self.templates.shape[1], + end_frame + self.templates.shape[1] + 1], side="left") for i in range(start, end): spike = self.spike_vector[i] diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index b11f40a441..5627eba518 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -111,8 +111,7 @@ def __init__(self, recording, peaks): # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(peaks["segment_index"], segment_index) - i1 = np.searchsorted(peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -125,8 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -183,8 +181,7 @@ def __init__( # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(self.peaks["segment_index"], segment_index) - i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -197,8 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 97f22615df..d5663156c7 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): if self.spikes_in_seg is None: # the slicing of segment is done only once the first time # this fasten the constructor a lot - s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") - s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") + s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1]) self.spikes_in_seg = self.spikes[s0:s1] unit_index = self.unit_ids.index(unit_id) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index f70c45bfe5..85e36cf7a5 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # Return (0 * num_channels) array of correct dtype return self.parent_segments[0].get_traces(0, 0, channel_indices) - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) @@ -469,8 +468,7 @@ def get_unit_spike_train( if end_frame is None: end_frame = self.get_num_samples() - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index da8e3d64b6..0ac20b9fec 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -344,15 +344,13 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes with the correct segment_index # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) # slice in absolut in spikes vector l0 = i0 + s0 @@ -562,8 +560,7 @@ def _init_worker_distribute_single_buffer( # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices @@ -590,8 +587,7 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) # slice in absolut in spikes vector l0 = i0 + s0 @@ -685,8 +681,7 @@ def has_exceeding_spikes(recording, sorting): """ spike_vector = sorting.to_spike_vector() for segment_index in range(recording.get_num_segments()): - start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index) - end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind] if len(spike_vector_seg) > 0: if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 04af69b37a..3badaa9402 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -82,8 +82,7 @@ def get_unit_spike_train( if end_frame == None: end_frame = spike_train[-1] if len(spike_train) > 0 else 0 - start = np.searchsorted(spike_train, start_frame, side="left") - end = np.searchsorted(spike_train, end_frame, side="right") + start, end = np.searchsorted(spike_train, [start_frame, end + 1], side="left") return spike_train[start:end] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..bb97f246d9 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -99,8 +99,7 @@ def _run(self, **job_kwargs): # precompute segment slice segment_slices = [] for segment_index in range(we.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index) - i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append(slice(i0, i1)) # and run diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..ce1c3bd5a0 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) if i0 != i1: # protect from spikes on border : spike_time<0 or spike_time>seg_size diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..fd6078b9b0 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): d = np.diff(spike_times) assert np.all(d >= 0) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) - + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) n_spikes = i1 - i0 amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype()) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..5f23e25b32 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -77,8 +77,7 @@ def get_data(self, outputs="concatenated"): elif outputs == "by_unit": locations_by_unit = [] for segment_index in range(self.waveform_extractor.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left") - i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right") + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes = self.spikes[i0:i1] locations = self._extension_data["spike_locations"][i0:i1] diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..01701e4f65 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -848,16 +848,14 @@ def compute_drift_metrics( spike_vector = sorting.to_spike_vector() # retrieve spikes in segment - i0 = np.searchsorted(spike_vector["segment_index"], segment_index) - i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spikes_in_segment = spike_vector[i0:i1] spike_locations_in_segment = spike_locations[i0:i1] # compute median positions (if less than min_spikes_per_interval, median position is 0) median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index b4a44105e4..1f6c348574 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -155,8 +155,7 @@ def interpolate_motion_on_traces( **spatial_interpolation_kwargs, ) - i0 = np.searchsorted(bin_inds, bin_ind, side="left") - i1 = np.searchsorted(bin_inds, bin_ind, side="right") + i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1] side="left") # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing From 164430c83cf66221bed677198fa8d468a8781c1d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:54:35 +0000 Subject: [PATCH 09/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 11 ++++++++--- src/spikeinterface/core/waveform_tools.py | 8 ++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 56a2bb4f48..6f85e76f1f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1109,7 +1109,9 @@ def __init__( num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start, end = np.searchsorted(self.spike_vector["segment_index"], [segment_index, segment_index+1], side="left") + start, end = np.searchsorted( + self.spike_vector["segment_index"], [segment_index, segment_index + 1], side="left" + ) spikes = self.spike_vector[start:end] amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None @@ -1207,8 +1209,11 @@ def get_traces( else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start, end = np.searchsorted(self.spike_vector["sample_index"], [start_frame - self.templates.shape[1], - end_frame + self.templates.shape[1] + 1], side="left") + start, end = np.searchsorted( + self.spike_vector["sample_index"], + [start_frame - self.templates.shape[1], end_frame + self.templates.shape[1] + 1], + side="left", + ) for i in range(start, end): spike = self.spike_vector[i] diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 0ac20b9fec..a2f1296e31 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -350,7 +350,9 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -587,7 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0, i1 = np.searchsorted(in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 From 426f395c6cb210b016b119225af540fd968fb30f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 12:38:50 +0200 Subject: [PATCH 10/33] Removed unnecessary else --- src/spikeinterface/core/waveform_extractor.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 3647e915bf..6881ab3ec5 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -523,20 +523,20 @@ def is_extension(self, extension_name) -> bool: """ if self.folder is None: return extension_name in self._loaded_extensions + + if extension_name in self._loaded_extensions: + # extension already loaded in memory + return True else: - # Extensions already loaded in memory - if extension_name in self._loaded_extensions: - return True - else: - if self.format == "binary": - return (self.folder / extension_name).is_dir() and ( - self.folder / extension_name / "params.json" - ).is_file() - elif self.format == "zarr": - return ( - extension_name in self._waveforms_root.keys() - and "params" in self._waveforms_root[extension_name].attrs.keys() - ) + if self.format == "binary": + return (self.folder / extension_name).is_dir() and ( + self.folder / extension_name / "params.json" + ).is_file() + elif self.format == "zarr": + return ( + extension_name in self._waveforms_root.keys() + and "params" in self._waveforms_root[extension_name].attrs.keys() + ) def load_extension(self, extension_name): """ From 9ad5f56907a848b757977e8dc2316445f867e269 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 13:01:43 +0200 Subject: [PATCH 11/33] Update src/spikeinterface/sortingcomponents/motion_interpolation.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 1f6c348574..18bb4f5a99 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -155,7 +155,7 @@ def interpolate_motion_on_traces( **spatial_interpolation_kwargs, ) - i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1] side="left") + i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1], side="left") # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing From 9c6e6c1cef249d0382c6c441cdd7d2a7b0194cb1 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 13:30:36 +0200 Subject: [PATCH 12/33] Typos while copy/paste --- src/spikeinterface/core/node_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 5627eba518..651804c995 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -124,7 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -194,7 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0, i1 = np.searchsorted(peaks_in_segment["segment_index"], [start_frame, end_frame]) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces From 646455a1054bf4cebed133c3197e8598ef75e59f Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 13:38:26 +0200 Subject: [PATCH 13/33] Some more searchsorted --- .../postprocessing/amplitude_scalings.py | 15 +++++---------- .../widgets/_legacy_mpl_widgets/activity.py | 3 +-- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index bb97f246d9..73e75870f9 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -316,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) if i0 != i1: local_spikes = spikes_in_segment[i0:i1] @@ -334,8 +333,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) - i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + i0_margin, i1_margin = np.searchsorted(spikes_in_segment["sample_index"], [start_frame - left, end_frame + right]) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices @@ -461,14 +459,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] # find the possible spikes per and post within delta_collision_samples - consecutive_window_pre = np.searchsorted( + consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - spike["sample_index"] - delta_collision_samples, - ) - consecutive_window_post = np.searchsorted( - spikes_w_margin["sample_index"], - spike["sample_index"] + delta_collision_samples, + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples] ) + # exclude the spike itself (it is included in the collision_spikes by construction) pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py index 939475c17d..9715b7ea87 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py @@ -95,8 +95,7 @@ def plot(self): num_frames = int(duration / self.bin_duration_s) def animate_func(i): - i0 = np.searchsorted(peaks["sample_index"], bin_size * i) - i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1)) + i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) return artists From 4410d6e8d06a8f3db8004846152be90bf04b8615 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:40:20 +0000 Subject: [PATCH 14/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/postprocessing/amplitude_scalings.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 73e75870f9..d4446e2289 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -333,7 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin, i1_margin = np.searchsorted(spikes_in_segment["sample_index"], [start_frame - left, end_frame + right]) + i0_margin, i1_margin = np.searchsorted( + spikes_in_segment["sample_index"], [start_frame - left, end_frame + right] + ) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices @@ -461,7 +463,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ # find the possible spikes per and post within delta_collision_samples consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples] + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples], ) # exclude the spike itself (it is included in the collision_spikes by construction) From 334f178aaafc0cccbc81db9821749691b7d67da6 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 15 Sep 2023 13:52:20 +0200 Subject: [PATCH 15/33] Fix --- src/spikeinterface/curation/remove_duplicated_spikes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 3badaa9402..d01ca1f6a1 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -82,7 +82,7 @@ def get_unit_spike_train( if end_frame == None: end_frame = spike_train[-1] if len(spike_train) > 0 else 0 - start, end = np.searchsorted(spike_train, [start_frame, end + 1], side="left") + start, end = np.searchsorted(spike_train, [start_frame, end_frame + 1], side="left") return spike_train[start:end] From 1ac47ffd3c2525b4fa406937b7d2391ee759e4ea Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 15 Sep 2023 14:12:04 +0200 Subject: [PATCH 16/33] in1d to isin --- src/spikeinterface/comparison/basecomparison.py | 4 ++-- src/spikeinterface/comparison/comparisontools.py | 2 +- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/basesnippets.py | 2 +- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/generate.py | 4 ++-- src/spikeinterface/core/tests/test_sparsity.py | 2 +- src/spikeinterface/curation/mergeunitssorting.py | 4 ++-- src/spikeinterface/extractors/bids.py | 2 +- .../postprocessing/amplitude_scalings.py | 4 ++-- src/spikeinterface/postprocessing/spike_amplitudes.py | 4 ++-- src/spikeinterface/postprocessing/spike_locations.py | 4 ++-- .../preprocessing/interpolate_bad_channels.py | 2 +- src/spikeinterface/qualitymetrics/misc_metrics.py | 2 +- src/spikeinterface/qualitymetrics/pca_metrics.py | 10 +++++----- .../sortingcomponents/benchmark/benchmark_matching.py | 4 ++-- .../benchmark/benchmark_peak_selection.py | 8 ++++---- .../sortingcomponents/clustering/clustering_tools.py | 10 +++++----- .../sortingcomponents/clustering/sliding_hdbscan.py | 10 +++++----- .../widgets/unit_waveforms_density_map.py | 2 +- 20 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 79c784491a..6f45f1497d 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~np.in1d(indexes, order1)] + possible = indexes[~isin(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~np.in1d(indexes, order1)] + remain = indexes[~isin(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index db45e2b25b..eb7b5c703c 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~isin(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index af4970a4ad..8c4a2941a0 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 737087abc1..7fd0823fc0 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..423f974220 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~isin(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..44d62818f9 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -166,7 +166,7 @@ def generate_sorting( ) if empty_units is not None: - keep = ~np.in1d(labels, empty_units) + keep = ~isin(labels, empty_units) times = times[keep] labels = labels[keep] @@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a6b94c9b84..61c4179652 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(np.in1d(v, channel_ids)) + assert np.all(isin(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..ccbaa32e7b 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -59,7 +59,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(isin(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +68,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(isin(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 02e7d5677d..9de272c56e 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = np.in1d(channel_ids, recording_channel_ids) + keep = isin(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..af618cf4db 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -47,9 +47,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = isin(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..729dbd12bb 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(isin(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = isin(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..eb3f1255c8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = isin(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index e634d55e7f..5773b6a2ef 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..a51bfe9164 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 59000211d4..0702c8f35a 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[isin(all_labels, neighbor_unit_ids)] + pcs = all_pcs[isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) + :, :, isin(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) + :, :, isin(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..ee8ace42ee 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1514a63dd4..ca18db58d6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) + garbage_matches = ~isin(np.arange(len(times2)), self.good_matches) garbage_channels = self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = isin(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = isin(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..fb45e5fc3a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~isin(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[isin(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~isin(active_labels_with_noise, persistent_clusters)] = -1 active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, isin(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, isin(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index aeec14158f..0f1d503bdf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(np.in1d(local_chan_inds, wf_chans)) + assert np.all(isin(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, isin(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(np.in1d(wanted_chans, wf_chans)) + assert np.all(isin(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, isin(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index e8a6868e92..2515d844eb 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -103,7 +103,7 @@ def __init__( if same_axis and not np.array_equal(chan_inds, shared_chan_inds): # add more channels if necessary wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float) - mask = np.in1d(shared_chan_inds, chan_inds) + mask = isin(shared_chan_inds, chan_inds) wfs_[:, :, mask] = wfs wfs_[:, :, ~mask] = np.nan wfs = wfs_ From e947e09a9c3d397ceabfd8eae50ba8a5ed345cf5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 15 Sep 2023 14:20:32 +0200 Subject: [PATCH 17/33] Revert "in1d to isin" This reverts commit 1ac47ffd3c2525b4fa406937b7d2391ee759e4ea. --- src/spikeinterface/comparison/basecomparison.py | 4 ++-- src/spikeinterface/comparison/comparisontools.py | 2 +- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/basesnippets.py | 2 +- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/generate.py | 4 ++-- src/spikeinterface/core/tests/test_sparsity.py | 2 +- src/spikeinterface/curation/mergeunitssorting.py | 4 ++-- src/spikeinterface/extractors/bids.py | 2 +- .../postprocessing/amplitude_scalings.py | 4 ++-- src/spikeinterface/postprocessing/spike_amplitudes.py | 4 ++-- src/spikeinterface/postprocessing/spike_locations.py | 4 ++-- .../preprocessing/interpolate_bad_channels.py | 2 +- src/spikeinterface/qualitymetrics/misc_metrics.py | 2 +- src/spikeinterface/qualitymetrics/pca_metrics.py | 10 +++++----- .../sortingcomponents/benchmark/benchmark_matching.py | 4 ++-- .../benchmark/benchmark_peak_selection.py | 8 ++++---- .../sortingcomponents/clustering/clustering_tools.py | 10 +++++----- .../sortingcomponents/clustering/sliding_hdbscan.py | 10 +++++----- .../widgets/unit_waveforms_density_map.py | 2 +- 20 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 6f45f1497d..79c784491a 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~isin(indexes, order1)] + possible = indexes[~np.in1d(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~isin(indexes, order1)] + remain = indexes[~np.in1d(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index eb7b5c703c..db45e2b25b 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~isin(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 8c4a2941a0..af4970a4ad 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~isin(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 7fd0823fc0..737087abc1 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~isin(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 423f974220..52f71c2399 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~isin(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 44d62818f9..401c498f03 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -166,7 +166,7 @@ def generate_sorting( ) if empty_units is not None: - keep = ~isin(labels, empty_units) + keep = ~np.in1d(labels, empty_units) times = times[keep] labels = labels[keep] @@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~isin(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 61c4179652..a6b94c9b84 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(isin(v, channel_ids)) + assert np.all(np.in1d(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index ccbaa32e7b..264ac3a56d 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -59,7 +59,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(isin(new_unit_ids, keep_unit_ids)): + if np.any(np.in1d(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +68,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(isin(new_unit_ids, keep_unit_ids)): + if np.any(np.in1d(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 9de272c56e..02e7d5677d 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = isin(channel_ids, recording_channel_ids) + keep = np.in1d(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index af618cf4db..5a0148c5c4 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -47,9 +47,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(isin(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) - spike_mask = isin(self.spikes["unit_index"], unit_inds) + spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 729dbd12bb..62a4e2c320 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(isin(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = isin(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index eb3f1255c8..c6f498f7e8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(isin(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) - spike_mask = isin(self.spikes["unit_index"], unit_inds) + spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 5773b6a2ef..e634d55e7f 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index a51bfe9164..ee28485983 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[isin(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 0702c8f35a..59000211d4 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[isin(all_labels, neighbor_unit_ids)] - pcs = all_pcs[isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] + pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, isin(closest_chans_target_unit, common_channel_idxs) + :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, isin(closest_chans_other_unit, common_channel_idxs) + :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index ee8ace42ee..07c7db155c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(isin(idx_2, idx_1))[0] + intersection = np.where(np.in1d(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(isin(idx_2, idx_1))[0] + intersection = np.where(np.in1d(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index ca18db58d6..1514a63dd4 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~isin(np.arange(len(times2)), self.good_matches) + garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) garbage_channels = self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = isin(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = isin(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index fb45e5fc3a..6edf5af16b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~isin(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[isin(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~isin(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, isin(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, isin(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 0f1d503bdf..aeec14158f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(isin(local_chan_inds, wf_chans)) + assert np.all(np.in1d(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, isin(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(isin(wanted_chans, wf_chans)) + assert np.all(np.in1d(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, isin(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 2515d844eb..e8a6868e92 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -103,7 +103,7 @@ def __init__( if same_axis and not np.array_equal(chan_inds, shared_chan_inds): # add more channels if necessary wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float) - mask = isin(shared_chan_inds, chan_inds) + mask = np.in1d(shared_chan_inds, chan_inds) wfs_[:, :, mask] = wfs wfs_[:, :, ~mask] = np.nan wfs = wfs_ From 5e420f3a847102c145c705dddfb01b140b318ec3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 15 Sep 2023 14:21:53 +0200 Subject: [PATCH 18/33] in1d to isin with correct alias (shame on me) --- src/spikeinterface/comparison/basecomparison.py | 4 ++-- src/spikeinterface/comparison/comparisontools.py | 2 +- src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/basesnippets.py | 2 +- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/generate.py | 4 ++-- src/spikeinterface/core/tests/test_sparsity.py | 2 +- src/spikeinterface/curation/mergeunitssorting.py | 4 ++-- src/spikeinterface/extractors/bids.py | 2 +- .../postprocessing/amplitude_scalings.py | 4 ++-- src/spikeinterface/postprocessing/spike_amplitudes.py | 4 ++-- src/spikeinterface/postprocessing/spike_locations.py | 4 ++-- .../preprocessing/interpolate_bad_channels.py | 2 +- src/spikeinterface/qualitymetrics/misc_metrics.py | 2 +- src/spikeinterface/qualitymetrics/pca_metrics.py | 10 +++++----- .../sortingcomponents/benchmark/benchmark_matching.py | 4 ++-- .../benchmark/benchmark_peak_selection.py | 8 ++++---- .../sortingcomponents/clustering/clustering_tools.py | 10 +++++----- .../sortingcomponents/clustering/sliding_hdbscan.py | 10 +++++----- .../widgets/unit_waveforms_density_map.py | 2 +- 20 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 79c784491a..5af20d79b5 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~np.in1d(indexes, order1)] + possible = indexes[~np.isin(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~np.in1d(indexes, order1)] + remain = indexes[~np.isin(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index db45e2b25b..20ee7910b4 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index af4970a4ad..08f187895b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 737087abc1..f35bc2b266 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..056134a24e 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..07837bcef7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -166,7 +166,7 @@ def generate_sorting( ) if empty_units is not None: - keep = ~np.in1d(labels, empty_units) + keep = ~np.isin(labels, empty_units) times = times[keep] labels = labels[keep] @@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a6b94c9b84..75182bf532 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(np.in1d(v, channel_ids)) + assert np.all(np.isin(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..2d20a58453 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -59,7 +59,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +68,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 02e7d5677d..8b70722652 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = np.in1d(channel_ids, recording_channel_ids) + keep = np.isin(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..5a3542cdf9 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -47,9 +47,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..b6f25cda95 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..4cbe4d665e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index e634d55e7f..95ecd0fe52 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..4e871492f8 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 59000211d4..ed06f7d738 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] + pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) + :, :, np.isin(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) + :, :, np.isin(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..772c99bc0a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1514a63dd4..73497a59fd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) + garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) garbage_channels = self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..23fdbf1979 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[np.isin(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1 active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, np.isin(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, np.isin(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index aeec14158f..08ce9f6791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(np.isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(np.in1d(local_chan_inds, wf_chans)) + assert np.all(np.isin(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(np.in1d(wanted_chans, wf_chans)) + assert np.all(np.isin(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index e8a6868e92..b3391c0712 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -103,7 +103,7 @@ def __init__( if same_axis and not np.array_equal(chan_inds, shared_chan_inds): # add more channels if necessary wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float) - mask = np.in1d(shared_chan_inds, chan_inds) + mask = np.isin(shared_chan_inds, chan_inds) wfs_[:, :, mask] = wfs wfs_[:, :, ~mask] = np.nan wfs = wfs_ From 0bd70dd27b23e799696ef966d9b84a4eac3c3b22 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 15 Sep 2023 16:54:36 +0200 Subject: [PATCH 19/33] detect_bad_channels some recording is not ordered. Add more chunk default computation. --- .../preprocessing/detect_bad_channels.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 0f4800c6e8..35ed2c349b 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -17,8 +17,8 @@ def detect_bad_channels( n_neighbors=11, nyquist_threshold=0.8, direction="y", - chunk_duration_s=0.3, - num_random_chunks=10, + chunk_duration_s=.5, + num_random_chunks=100, welch_window_ms=10.0, highpass_filter_cutoff=300, neighborhood_r2_threshold=0.9, @@ -81,9 +81,10 @@ def detect_bad_channels( highpass_filter_cutoff : float If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300 chunk_duration_s : float - Duration of each chunk, by default 0.3 + Duration of each chunk, by default 0.5 num_random_chunks : int - Number of random chunks, by default 10 + Number of random chunks, by default 100 + Having many chunks is important for reproducibility. welch_window_ms : float Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms neighborhood_r2_threshold : float, default 0.95 @@ -174,20 +175,18 @@ def detect_bad_channels( channel_locations = recording.get_channel_locations() dim = ["x", "y", "z"].index(direction) assert dim < channel_locations.shape[1], f"Direction {direction} is wrong" - locs_depth = channel_locations[:, dim] - if np.array_equal(np.sort(locs_depth), locs_depth): + order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) + if np.all(np.diff(order_f) == 1): + # already ordered order_f = None order_r = None - else: - # sort by x, y to avoid ambiguity - order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) # Create empty channel labels and fill with bad-channel detection estimate for each chunk chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8) for i, random_chunk in enumerate(random_data): - random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk - chunk_channel_labels[:, i] = detect_bad_channels_ibl( + random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk + chunk_labels = detect_bad_channels_ibl( raw=random_chunk_sorted, fs=recording.sampling_frequency, psd_hf_threshold=psd_hf_threshold, @@ -198,11 +197,10 @@ def detect_bad_channels( nyquist_threshold=nyquist_threshold, welch_window_ms=welch_window_ms, ) + chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output. mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False) - if order_r is not None: - mode_channel_labels = mode_channel_labels[order_r] (bad_inds,) = np.where(mode_channel_labels != 0) bad_channel_ids = recording.channel_ids[bad_inds] @@ -306,7 +304,7 @@ def detect_bad_channels_ibl( n_neighbors : int, optional Number of neighbors to compute median fitler, by default 11 nyquist_threshold : float, optional - Threshold on Nyquist frequency to calculate HF noise band, by default 0.8 + Threshold on Nyquist frequency to calcureclate HF noise band, by default 0.8 welch_window_ms: float Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms Returns From 05ad95be8f9811ca86d6905edc13a5b5d4c2251b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 14:55:58 +0000 Subject: [PATCH 20/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/detect_bad_channels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 35ed2c349b..fa61755aba 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -17,7 +17,7 @@ def detect_bad_channels( n_neighbors=11, nyquist_threshold=0.8, direction="y", - chunk_duration_s=.5, + chunk_duration_s=0.5, num_random_chunks=100, welch_window_ms=10.0, highpass_filter_cutoff=300, From bd26723e1cd1a86660abbe23d344cb299f9140ad Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Sun, 17 Sep 2023 09:40:10 -0400 Subject: [PATCH 21/33] fix folder --- .github/workflows/installation-tips-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index 0e522e6baa..b3bf08954d 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -30,4 +30,4 @@ jobs: - name: Test Conda Environment Creation uses: conda-incubator/setup-miniconda@v2.2.0 with: - environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml + environment-file: ./installation_tips/full_spikeinterface_environment_${{ matrix.label }}.yml From c57cfa71fae9e0cc4aada7e72435cb8f3667eecf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 18 Sep 2023 11:03:29 +0200 Subject: [PATCH 22/33] Add an option to flip the order by depth --- src/spikeinterface/core/recording_tools.py | 7 ++++++- src/spikeinterface/core/tests/test_recording_tools.py | 2 ++ src/spikeinterface/preprocessing/depth_order.py | 8 ++++++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index e5901d7ee0..8236671a3b 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -302,7 +302,7 @@ def get_chunk_with_margin( return traces_chunk, left_margin, right_margin -def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): +def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), flip=False): """ Order channels by depth, by first ordering the x-axis, and then the y-axis. @@ -316,6 +316,9 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. Returns ------- @@ -341,6 +344,8 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): assert dim < ndim, "Invalid dimensions!" locations_to_sort += (locations[:, dim],) order_f = np.lexsort(locations_to_sort) + if flip: + order_f = order_f[::-1] order_r = np.argsort(order_f, kind="stable") return order_f, order_r diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 6e92d155fe..1d99b192ee 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -138,11 +138,13 @@ def test_order_channels_by_depth(): order_1d, order_r1d = order_channels_by_depth(rec, dimensions="y") order_2d, order_r2d = order_channels_by_depth(rec, dimensions=("x", "y")) locations_rev = locations_copy[order_1d][order_r1d] + order_2d_fliped, order_r2d_fliped = order_channels_by_depth(rec, dimensions=("x", "y"), flip=True) assert np.array_equal(locations[:, 1], locations_copy[order_1d][:, 1]) assert np.array_equal(locations_copy[order_1d][:, 1], locations_copy[order_2d][:, 1]) assert np.array_equal(locations, locations_copy[order_2d]) assert np.array_equal(locations_copy, locations_copy[order_2d][order_r2d]) + assert np.array_equal(order_2d[::-1], order_2d_fliped) if __name__ == "__main__": diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 0b8d8a730b..b9edded883 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -18,13 +18,16 @@ class DepthOrderRecording(ChannelSliceRecording): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. """ name = "depth_order" installed = True - def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): - order_f, order_r = order_channels_by_depth(parent_recording, channel_ids=channel_ids, dimensions=dimensions) + def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y"), flip=False): + order_f, order_r = order_channels_by_depth(parent_recording, channel_ids=channel_ids, dimensions=dimensions, flip=flip) reordered_channel_ids = parent_recording.channel_ids[order_f] ChannelSliceRecording.__init__( self, @@ -35,6 +38,7 @@ def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): parent_recording=parent_recording, channel_ids=channel_ids, dimensions=dimensions, + flip=flip, ) From ef165cb4a2d43df592a57a2c801c62ebe5ce780b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:03:59 +0000 Subject: [PATCH 23/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/depth_order.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index b9edded883..43c43a5843 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -27,7 +27,9 @@ class DepthOrderRecording(ChannelSliceRecording): installed = True def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y"), flip=False): - order_f, order_r = order_channels_by_depth(parent_recording, channel_ids=channel_ids, dimensions=dimensions, flip=flip) + order_f, order_r = order_channels_by_depth( + parent_recording, channel_ids=channel_ids, dimensions=dimensions, flip=flip + ) reordered_channel_ids = parent_recording.channel_ids[order_f] ChannelSliceRecording.__init__( self, From d431e4ebe817993a74173f414eda139c21a83171 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 18 Sep 2023 17:49:10 +0200 Subject: [PATCH 24/33] Update src/spikeinterface/preprocessing/detect_bad_channels.py Co-authored-by: Alessio Buccino --- src/spikeinterface/preprocessing/detect_bad_channels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index fa61755aba..3c712946eb 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -304,7 +304,7 @@ def detect_bad_channels_ibl( n_neighbors : int, optional Number of neighbors to compute median fitler, by default 11 nyquist_threshold : float, optional - Threshold on Nyquist frequency to calcureclate HF noise band, by default 0.8 + Threshold on Nyquist frequency to calculate HF noise band, by default 0.8 welch_window_ms: float Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms Returns From ef0d66e6cfeea0b1f3392c5a0a8758194a9c884d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 19 Sep 2023 09:27:18 +0200 Subject: [PATCH 25/33] Bringing back right searches --- src/spikeinterface/core/generate.py | 8 +++----- src/spikeinterface/curation/remove_duplicated_spikes.py | 3 ++- src/spikeinterface/postprocessing/spike_locations.py | 3 ++- .../sortingcomponents/motion_interpolation.py | 5 +++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6f85e76f1f..33f3dea923 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1209,11 +1209,9 @@ def get_traces( else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start, end = np.searchsorted( - self.spike_vector["sample_index"], - [start_frame - self.templates.shape[1], end_frame + self.templates.shape[1] + 1], - side="left", - ) + start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") + end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") + for i in range(start, end): spike = self.spike_vector[i] diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index d01ca1f6a1..04af69b37a 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -82,7 +82,8 @@ def get_unit_spike_train( if end_frame == None: end_frame = spike_train[-1] if len(spike_train) > 0 else 0 - start, end = np.searchsorted(spike_train, [start_frame, end_frame + 1], side="left") + start = np.searchsorted(spike_train, start_frame, side="left") + end = np.searchsorted(spike_train, end_frame, side="right") return spike_train[start:end] diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 5f23e25b32..c6f498f7e8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -77,7 +77,8 @@ def get_data(self, outputs="concatenated"): elif outputs == "by_unit": locations_by_unit = [] for segment_index in range(self.waveform_extractor.get_num_segments()): - i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1], side="left") + i0 = np.searchsorted(self.spikes["segment_index"], segment_index, side="left") + i1 = np.searchsorted(self.spikes["segment_index"], segment_index, side="right") spikes = self.spikes[i0:i1] locations = self._extension_data["spike_locations"][i0:i1] diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 18bb4f5a99..9a4cd688c5 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -155,8 +155,9 @@ def interpolate_motion_on_traces( **spatial_interpolation_kwargs, ) - i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1], side="left") - + i0 = np.searchsorted(bin_inds, bin_ind, side="left") + i1 = np.searchsorted(bin_inds, bin_ind, side="right") + # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) From f2d702a7e20f7fb6459a18b17dd9a4881c1fe337 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Sep 2023 07:27:40 +0000 Subject: [PATCH 26/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 1 - src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 33f3dea923..9adda4cb2b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1212,7 +1212,6 @@ def get_traces( start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") - for i in range(start, end): spike = self.spike_vector[i] t = spike["sample_index"] diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 9a4cd688c5..b4a44105e4 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -157,7 +157,7 @@ def interpolate_motion_on_traces( i0 = np.searchsorted(bin_inds, bin_ind, side="left") i1 = np.searchsorted(bin_inds, bin_ind, side="right") - + # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) From 9d07ec2fb467e4bc035f2e36566ea9a2aead772e Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 19 Sep 2023 09:31:02 +0200 Subject: [PATCH 27/33] One more --- src/spikeinterface/core/generate.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9adda4cb2b..401c498f03 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1109,9 +1109,8 @@ def __init__( num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start, end = np.searchsorted( - self.spike_vector["segment_index"], [segment_index, segment_index + 1], side="left" - ) + start = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="left") + end = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="right") spikes = self.spike_vector[start:end] amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None From e88b4b5da0b1d848bd910122a385b3f5fb01dc2c Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 19 Sep 2023 11:04:43 +0200 Subject: [PATCH 28/33] Update src/spikeinterface/preprocessing/depth_order.py --- src/spikeinterface/preprocessing/depth_order.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 43c43a5843..55e34ba5dd 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -18,7 +18,7 @@ class DepthOrderRecording(ChannelSliceRecording): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') - flip: bool, default False + flip: bool, default: False If flip is False then the order is bottom first (starting from tip of the probe). If flip is True then the order is upper first. """ From b202c431a9f5d89bf7a5e92cf62acef64f040241 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 19 Sep 2023 11:05:23 +0200 Subject: [PATCH 29/33] Update src/spikeinterface/core/recording_tools.py --- src/spikeinterface/core/recording_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 8236671a3b..ff9cd99389 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -316,7 +316,7 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') - flip: bool, default False + flip: bool, default: False If flip is False then the order is bottom first (starting from tip of the probe). If flip is True then the order is upper first. From 73395fbd5a420be7d21e4017abcafb3d4a91d5ea Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 19 Sep 2023 11:24:18 +0200 Subject: [PATCH 30/33] Update src/spikeinterface/preprocessing/detect_bad_channels.py --- src/spikeinterface/preprocessing/detect_bad_channels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 3c712946eb..cc4e8601e2 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -17,7 +17,7 @@ def detect_bad_channels( n_neighbors=11, nyquist_threshold=0.8, direction="y", - chunk_duration_s=0.5, + chunk_duration_s=0.3, num_random_chunks=100, welch_window_ms=10.0, highpass_filter_cutoff=300, From 2f4d50a6651d4fc0ba568463df61a350d62ddd33 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:32:16 -0400 Subject: [PATCH 31/33] typo corrections, link corrections --- doc/development/development.rst | 6 +++--- doc/install_sorters.rst | 2 +- doc/modules/sorters.rst | 6 +++--- doc/modules/sortingcomponents.rst | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/doc/development/development.rst b/doc/development/development.rst index f1371639c3..4704b9b1e6 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -14,7 +14,7 @@ There are various ways to contribute to SpikeInterface as a user or developer. S * Writing unit tests to expand code coverage and use case scenarios. * Reporting bugs and issues. -We use a forking workflow _ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: +We use a forking workflow ``_ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: * Fork the SpikeInterface repository. * Create a new branch (e.g., :code:`git switch -c my-contribution`). @@ -22,7 +22,7 @@ We use a forking workflow _ . +While we appreciate all the contributions please be mindful of the cost of reviewing pull requests ``_ . How to run tests locally @@ -201,7 +201,7 @@ Implement a new extractor SpikeInterface already supports over 30 file formats, but the acquisition system you use might not be among the supported formats list (***ref***). Most of the extractord rely on the `NEO `_ package to read information from files. -Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new `neo.rawio `_ class. +Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new :code:``neo.rawio ` class. Once that is done, the new class can be easily wrapped into SpikeInterface as an extension of the :py:class:`~spikeinterface.extractors.neoextractors.neobaseextractors.NeoBaseRecordingExtractor` (for :py:class:`~spikeinterface.core.BaseRecording` objects) or diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index 3fda05848c..10a3185c5c 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -117,7 +117,7 @@ Kilosort2.5 git clone https://github.com/MouseLand/Kilosort # provide installation path by setting the KILOSORT2_5_PATH environment variable - # or using Kilosort2_5Sorter.set_kilosort2_path() + # or using Kilosort2_5Sorter.set_kilosort2_5_path() * See also for Matlab/CUDA: https://www.mathworks.com/help/parallel-computing/gpu-support-by-release.html diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 34ab3d1151..1b27ed442c 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -239,7 +239,7 @@ There are three options: 1. **released PyPi version**: if you installed :code:`spikeinterface` with :code:`pip install spikeinterface`, the latest released version will be installed in the container. -2. **development :code:`main` version**: if you installed :code:`spikeinterface` from source from the cloned repo +2. **development** :code:`main` **version**: if you installed :code:`spikeinterface` from source from the cloned repo (with :code:`pip install .`) or with :code:`pip install git+https://github.com/SpikeInterface/spikeinterface.git`, the current development version from the :code:`main` branch will be installed in the container. @@ -458,7 +458,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: * **Kilosort** :code:`run_sorter('kilosort')` * **Kilosort2** :code:`run_sorter('kilosort2')` * **Kilosort2.5** :code:`run_sorter('kilosort2_5')` -* **Kilosort3** :code:`run_sorter('Kilosort3')` +* **Kilosort3** :code:`run_sorter('kilosort3')` * **PyKilosort** :code:`run_sorter('pykilosort')` * **Klusta** :code:`run_sorter('klusta')` * **Mountainsort4** :code:`run_sorter('mountainsort4')` @@ -474,7 +474,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: Here a list of internal sorter based on `spikeinterface.sortingcomponents`; they are totally experimental for now: -* **Spyking circus2** :code:`run_sorter('spykingcircus2')` +* **Spyking Circus2** :code:`run_sorter('spykingcircus2')` * **Tridesclous2** :code:`run_sorter('tridesclous2')` In 2023, we expect to add many more sorters to this list. diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index aa62ea5b33..422eaea890 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -223,7 +223,7 @@ Here is a short example that depends on the output of "Motion interpolation": **Notes**: * :code:`spatial_interpolation_method` "kriging" or "iwd" do not play a big role. - * :code:`border_mode` is a very important parameter. It controls how to deal with the border because motion causes units on the + * :code:`border_mode` is a very important parameter. It controls dealing with the border because motion causes units on the border to not be present throughout the entire recording. We highly recommend the :code:`border_mode='remove_channels'` because this removes channels on the border that will be impacted by drift. Of course the larger the motion is the more channels are removed. @@ -278,7 +278,7 @@ At the moment, there are five methods implemented: * 'naive': a very naive implemenation used as a reference for benchmarks * 'tridesclous': the algorithm for template matching implemented in Tridesclous * 'circus': the algorithm for template matching implemented in SpyKING-Circus - * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal macthing + * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal matching pursuit) * 'wobble' : an algorithm loosely based on YASS that scales template amplitudes and shifts them in time to match detected spikes From b90e35b9df6bb03bac2a7c3e76e36c79c3f68af3 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Tue, 19 Sep 2023 08:56:48 -0400 Subject: [PATCH 32/33] Update doc/development/development.rst Co-authored-by: Alessio Buccino --- doc/development/development.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/development/development.rst b/doc/development/development.rst index 4704b9b1e6..7656da11ab 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -201,7 +201,7 @@ Implement a new extractor SpikeInterface already supports over 30 file formats, but the acquisition system you use might not be among the supported formats list (***ref***). Most of the extractord rely on the `NEO `_ package to read information from files. -Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new :code:``neo.rawio ` class. +Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new :code:`neo.rawio.BaseRawIO` class (see `example `_). Once that is done, the new class can be easily wrapped into SpikeInterface as an extension of the :py:class:`~spikeinterface.extractors.neoextractors.neobaseextractors.NeoBaseRecordingExtractor` (for :py:class:`~spikeinterface.core.BaseRecording` objects) or From b8023d0733e48b8bc96d50c763753a7da1b3a5d5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Sep 2023 16:16:40 +0200 Subject: [PATCH 33/33] Add read_binary and read_zarr functions to extractord and docs API --- doc/api.rst | 11 ++++++----- src/spikeinterface/extractors/extractorlist.py | 2 ++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 43f79386e6..122c88d01b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -19,6 +19,8 @@ spikeinterface.core .. autofunction:: extract_waveforms .. autofunction:: load_waveforms .. autofunction:: compute_sparsity + .. autoclass:: ChannelSparsity + :members: .. autoclass:: BinaryRecordingExtractor .. autoclass:: ZarrRecordingExtractor .. autoclass:: BinaryFolderRecording @@ -48,10 +50,6 @@ spikeinterface.core .. autofunction:: get_template_extremum_channel .. autofunction:: get_template_extremum_channel_peak_shift .. autofunction:: get_template_extremum_amplitude - -.. - .. autofunction:: read_binary - .. autofunction:: read_zarr .. autofunction:: append_recordings .. autofunction:: concatenate_recordings .. autofunction:: split_recording @@ -59,6 +57,8 @@ spikeinterface.core .. autofunction:: append_sortings .. autofunction:: split_sorting .. autofunction:: select_segment_sorting + .. autofunction:: read_binary + .. autofunction:: read_zarr Low-level ~~~~~~~~~ @@ -67,7 +67,6 @@ Low-level :noindex: .. autoclass:: BaseWaveformExtractorExtension - .. autoclass:: ChannelSparsity .. autoclass:: ChunkRecordingExecutor spikeinterface.extractors @@ -83,6 +82,7 @@ NEO-based .. autofunction:: read_alphaomega_event .. autofunction:: read_axona .. autofunction:: read_biocam + .. autofunction:: read_binary .. autofunction:: read_blackrock .. autofunction:: read_ced .. autofunction:: read_intan @@ -104,6 +104,7 @@ NEO-based .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx .. autofunction:: read_tdt + .. autofunction:: read_zarr Non-NEO-based diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index ebff40fae0..235dd705dc 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -11,6 +11,8 @@ NumpySorting, NpySnippetsExtractor, ZarrRecordingExtractor, + read_binary, + read_zarr, ) # sorting/recording/event from neo