diff --git a/.github/scripts/determine_testing_environment.py b/.github/scripts/determine_testing_environment.py index 95ad0afc49..aa85aa2b91 100644 --- a/.github/scripts/determine_testing_environment.py +++ b/.github/scripts/determine_testing_environment.py @@ -31,6 +31,7 @@ sortingcomponents_changed = False generation_changed = False stream_extractors_changed = False +github_actions_changed = False for changed_file in changed_files_in_the_pull_request_paths: @@ -78,9 +79,12 @@ sorters_internal_changed = True else: sorters_changed = True + elif ".github" in changed_file.parts: + if "workflows" in changed_file.parts: + github_actions_changed = True -run_everything = core_changed or pyproject_toml_changed or neobaseextractor_changed +run_everything = core_changed or pyproject_toml_changed or neobaseextractor_changed or github_actions_changed run_generation_tests = run_everything or generation_changed run_extractor_tests = run_everything or extractors_changed or plexon2_changed run_preprocessing_tests = run_everything or preprocessing_changed @@ -96,7 +100,7 @@ run_sorters_test = run_everything or sorters_changed run_internal_sorters_test = run_everything or run_sortingcomponents_tests or sorters_internal_changed -run_streaming_extractors_test = stream_extractors_changed +run_streaming_extractors_test = stream_extractors_changed or github_actions_changed install_plexon_dependencies = plexon2_changed diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index e12cf6805d..dcaec8b272 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -12,6 +12,7 @@ on: env: KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} + KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }} concurrency: # Cancel previous workflows on the same pull request group: ${{ github.workflow }}-${{ github.ref }} @@ -25,7 +26,7 @@ jobs: fail-fast: false matrix: python-version: 
["3.9", "3.12"] # Lower and higher versions we support - os: [macos-13, windows-latest, ubuntu-latest] + os: [macos-latest, windows-latest, ubuntu-latest] steps: - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 6a222f5e25..407c614ebf 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -8,6 +8,7 @@ on: env: KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} + KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }} jobs: full-tests-with-codecov: diff --git a/doc/development/development.rst b/doc/development/development.rst index 1094b466fc..246a2bcb9a 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -189,9 +189,17 @@ so that the user knows what the options are. Miscelleaneous Stylistic Conventions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -#. Avoid using abreviations in variable names (e.g., use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. +#. Avoid using abbreviations in variable names (e.g. use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. #. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. +#. For creating headers to divide sections of code we use the following convention (see issue `#3019 `_): + + +.. 
code:: python + + ######################################### + # A header + ######################################### How to build the documentation diff --git a/doc/releases/0.101.0.rst b/doc/releases/0.101.0.rst index c34cd0dc8e..0e686cca1a 100644 --- a/doc/releases/0.101.0.rst +++ b/doc/releases/0.101.0.rst @@ -3,7 +3,7 @@ SpikeInterface 0.101.0 release notes ------------------------------------ -Estimated: 19th July 2024 +19th July 2024 Main changes: diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst new file mode 100644 index 0000000000..eeb54566a6 --- /dev/null +++ b/doc/releases/0.101.1.rst @@ -0,0 +1,144 @@ +.. _release0.101.1: + +SpikeInterface 0.101.1 release notes +------------------------------------ + +16th September 2024 + +Main changes: + +* Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) +* Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) +* Skip recomputation of quality and template metrics if already computed (#3292) +* Improved `estimate_sparsity` with new `amplitude` method and deprecated `from_ptp` option (#3369) +* Dropped support for Python<3.9 (#3267) + +core: + +* Update the method of creating an empty file with right size when saving binary files (#3408) +* Refactor pandas save load and convert dtypes (#3412) +* Check run info completed only if it exists (back-compatibility) (#3407) +* Fix argument spelling in check for binary compatibility (#3409) +* Fix proposal for channel location when probegroup (#3392) +* Fix time handling test memory (#3379) +* Add `BaseRecording.reset_times()` function (#3363, #3380, #3391) +* Extend `estimate_sparsity` methods and update `from_ptp`` (#3369) +* Add `load_sorting_analyzer_or_waveforms` function (#3352) +* Fix zarr folder suffix handling (#3349) +* Analyzer extension exit status (#3347) +* Lazy loading of zarr timestamps (#3318) +* Enable cloud-loading for analyzer Zarr (#3314, #3351, 
#3378) +* Refactor `set_property` in base (#3287) +* Job kwargs fix (#3259) +* Add `is_filtered` to annotations in `binary.json` (#3245) +* Add check for None in `NoiseGeneratorRecordingSegment`` get_traces() (#3230) + +extractors: + +* Load phy channel_group as group (#3368) +* "quality" property to be read as string instead of object in `BasePhyKilosortSortingExtractor` (#3365) + +preprocessing: + +* Auto-cast recording to float prior to interpolation (#3415) +* Update doc handle drift + better preset (#3232) +* Add causal filtering to filter.py (#3172) + +sorters: + +* Updates to kilosort 4: version >= 4.0.16, `bad_channels`, `clear_cache`, `use_binary_file` (#3339) +* Download apptainer images without docker client (#3335) +* Expose save preprocessing in ks4 (#3276) +* Fix KS2/2.5/3 skip_kilosort_preprocessing (#3265) +* HS: Added lowpass parameter, fixed verbose option (#3262) +* Now exclusive support for HS v0.4 (Lightning) (#3210) + +postprocessing: + +* Add extra protection for template metrics (#3364) +* Protect median against nans in get_prototype_spike (#3270) +* Fix docstring and error for spike_amplitudes (#3269) + +qualitymetrics: + +* Do not delete quality and template metrics on recompute (#3292) +* Refactor quality metrics tests to use fixture (#3249) + + +curation: + +* Clean-up identity merges in `get_potential_auto_merges` (#3346) +* Fix sortingview curation no merge case (#3309) +* Start apply_curation() (#3208) + +widgets: + +* Fix metrics widgets for convert_dtypes (#3417) +* Fix plot motion for multi-segment (#3414) +* Sortingview: only round float properties (#3406) +* Fix widgets tests and add test on unit_table_properties (#3354) +* Allow quality and template metrics in sortingview's unit table (#3299) +* Add subwidget parameters for UnitSummaryWidget (#3242) +* Fix `ipympl`/`widget` backend check (#3238) + +generators: + +* Handle case where channel count changes from probeA to probeB (#3237) + +sortingcomponents: + +* Update doc handle 
drift + better preset (#3232) +* Make `InterpolateMotionRecording`` not JSON-serializable (#3341) + +documentation: + +* Clarify meaning of `delta_time` in `compare_sorter_to_ground_truth` (#3360) +* Added sphinxcontrib-jquery (#3307) +* Adding return type annotations (#3304) +* More docstring updates for multiple modules (#3298) +* Fix sampling frequency repr (#3294) +* Proposal for adding Examples to docstrings (#3279) +* More numpydoc fixes (#3275) +* Fix docstring and error for spike_amplitudes (#3269) +* Fix postprocessing docs (#3268) +* Fix name of principal_components ext in qm docs (take 2) (#3261) +* Update doc handle drift + better preset (#3232) +* Add `int` type to `num_samples` on `InjectTemplatesRecording`. (#3229) + +continuous integration: + +* Fix streaming extractor condition in the CI (#3362) + +packaging: + +* Minor typing fixes (#3374) +* Drop python 3.8 in pyproject.toml (#3267) + +testing: + +* Fix time handling test memory (#3379) +* Fix streaming extractor condition in the CI (#3362) +* Test IBL skip when the setting up the one client fails (#3289) +* Refactor `set_property` in base (#3287) +* Refactor quality metrics tests to use fixture (#3249) +* Add kilosort4 wrapper tests (#3085) +* Test IBL skip when the setting up the one client fails (#3289) +* Add kilosort4 wrapper tests (#3085) + +Contributors: + +* @Djoels +* @JoeZiminski +* @JuanPimientoCaicedo +* @alejoe91 +* @chrishalcrow +* @cwindolf +* @florian6973 +* @h-mayorquin +* @jiumao2 +* @jonahpearl +* @mhhennig +* @rkim48 +* @samuelgarcia +* @tabedzki +* @zm711 diff --git a/doc/scripts/auto-release-notes.sh b/doc/scripts/auto-release-notes.sh index 14bee3dad0..f3818e1e18 100644 --- a/doc/scripts/auto-release-notes.sh +++ b/doc/scripts/auto-release-notes.sh @@ -1,6 +1,6 @@ #!/bin/bash if [ $# -eq 0 ]; then - echo "Usage: $0 START_DATE END_DATE [LABEL] [BRANCH1,BRANCH2] [LIMIT]" + echo "Usage: $0 START_DATE (format: YEAR-MM-DD) END_DATE [LABEL] [BRANCH1,BRANCH2] [LIMIT]" exit 1 fi 
diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 94da5d15fb..c8038387f9 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. toctree:: :maxdepth: 1 + releases/0.101.1.rst releases/0.101.0.rst releases/0.100.8.rst releases/0.100.7.rst @@ -43,6 +44,16 @@ Release notes releases/0.9.1.rst +Version 0.101.1 +=============== + +* Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) +* Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) +* Skip recomputation of quality and template metrics if already computed (#3292) +* Improved `estimate_sparsity` with new `amplitude` method and deprecated `from_ptp` option (#3369) +* Dropped support for Python<3.9 (#3267) + + Version 0.101.0 =============== diff --git a/pyproject.toml b/pyproject.toml index 8309ca89fe..c1c02db8db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,6 @@ preprocessing = [ full = [ "h5py", "pandas", - "xarray", "scipy", "scikit-learn", "networkx", @@ -125,16 +124,16 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_extractors = [ # Functions to download data in neo test suite "pooch>=1.8.2", "datalad>=1.0.2", - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_preprocessing = [ @@ -148,7 +147,6 @@ test = [ "pytest-dependency", "pytest-cov", - "xarray", 
"huggingface_hub", # preprocessing @@ -175,8 +173,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -193,15 +191,14 @@ docs = [ "pandas", # in the modules gallery comparison tutorial "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions - "xarray", # For use of SortingAnalyzer zarr format "networkx", # Download data "pooch>=1.8.2", "datalad>=1.0.2", # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 306c12d516..97fb95b623 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -30,5 +30,5 @@ # This flag must be set to False for release # This avoids using versioning that contains ".dev0" (and this is a better choice) # This is mainly useful when using run_sorter in a container and spikeinterface install -DEV_MODE = True -# DEV_MODE = False +# DEV_MODE = True +DEV_MODE = False diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 82f2ae1890..5e2e9e4014 100644 --- a/src/spikeinterface/core/baserecording.py +++ 
b/src/spikeinterface/core/baserecording.py @@ -503,8 +503,8 @@ def reset_times(self): segment's sampling frequency is set to the recording's sampling frequency. """ for segment_index in range(self.get_num_segments()): + rs = self._recording_segments[segment_index] if self.has_time_vector(segment_index): - rs = self._recording_segments[segment_index] rs.time_vector = None rs.t_start = None rs.sampling_frequency = self.sampling_frequency @@ -608,11 +608,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) - time_vectors = self._get_time_vectors() - if time_vectors is not None: - for segment_index, time_vector in enumerate(time_vectors): - if time_vector is not None: - cached.set_times(time_vector, segment_index=segment_index) + for segment_index in range(self.get_num_segments()): + if self.has_time_vector(segment_index): + # the use of get_times is preferred since timestamps are converted to array + time_vector = self.get_times(segment_index=segment_index) + cached.set_times(time_vector, segment_index=segment_index) return cached @@ -746,6 +746,30 @@ def _select_segments(self, segment_indices): return SelectSegmentRecording(self, segment_indices=segment_indices) + def get_channel_locations( + self, + channel_ids: list | np.ndarray | tuple | None = None, + axes: "xy" | "yz" | "xz" | "xyz" = "xy", + ) -> np.ndarray: + """ + Get the physical locations of specified channels. + + Parameters + ---------- + channel_ids : array-like, optional + The IDs of the channels for which to retrieve locations. If None, retrieves locations + for all available channels. Default is None. + axes : "xy" | "yz" | "xz" | "xyz", default: "xy" + The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy". + + Returns + ------- + np.ndarray + A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels. 
+ The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz"). + """ + return super().get_channel_locations(channel_ids=channel_ids, axes=axes) + def is_binary_compatible(self) -> bool: """ Checks if the recording is "binary" compatible. @@ -768,7 +792,13 @@ def get_binary_description(self): raise NotImplementedError def binary_compatible_with( - self, dtype=None, time_axis=None, file_paths_lenght=None, file_offset=None, file_suffix=None + self, + dtype=None, + time_axis=None, + file_paths_length=None, + file_offset=None, + file_suffix=None, + file_paths_lenght=None, ): """ Check is the recording is binary compatible with some constrain on @@ -779,6 +809,15 @@ def binary_compatible_with( * file_offset * file_suffix """ + + # spelling typo need to fix + if file_paths_lenght is not None: + warnings.warn( + "`file_paths_lenght` is deprecated and will be removed in 0.103.0 please use `file_paths_length`" + ) + if file_paths_length is None: + file_paths_length = file_paths_lenght + if not self.is_binary_compatible(): return False @@ -790,7 +829,7 @@ def binary_compatible_with( if time_axis is not None and time_axis != d["time_axis"]: return False - if file_paths_lenght is not None and file_paths_lenght != len(d["file_paths"]): + if file_paths_length is not None and file_paths_length != len(d["file_paths"]): return False if file_offset is not None and file_offset != d["file_offset"]: diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 428472bf93..310533c96b 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -145,6 +145,11 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False else: raise ValueError("must give Probe or ProbeGroup or list of Probe") + # check that the probe do not overlap + num_probes = len(probegroup.probes) + if num_probes > 1: + 
check_probe_do_not_overlap(probegroup.probes) + # handle not connected channels assert all( probe.device_channel_indices is not None for probe in probegroup.probes @@ -234,7 +239,7 @@ def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False) warning_msg = ( "`set_probes` is now a private function and the public function will be " - "removed in 0.103.0. Please use `set_probe` or `set_probegroups` instead" + "removed in 0.103.0. Please use `set_probe` or `set_probegroup` instead" ) warn(warning_msg, category=DeprecationWarning, stacklevel=2) @@ -344,21 +349,19 @@ def set_channel_locations(self, locations, channel_ids=None): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) - def get_channel_locations(self, channel_ids=None, axes: str = "xy"): + def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - if self.get_property("contact_vector") is not None: - if len(self.get_probes()) == 1: - probe = self.get_probe() - positions = probe.contact_positions[channel_indices] - else: - all_probes = self.get_probes() - # check that multiple probes are non-overlapping - check_probe_do_not_overlap(all_probes) - all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - positions = all_positions[channel_indices] - return select_axes(positions, axes) + contact_vector = self.get_property("contact_vector") + if contact_vector is not None: + # here we bypass the probe reconstruction so this works both for probe and probegroup + ndim = len(axes) + all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") + for i, dim in enumerate(axes): + all_positions[:, i] = contact_vector[dim] + positions = all_positions[channel_indices] + return positions else: locations = 
self.get_property("location") if locations is None: diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index a5279247f5..5240edcee7 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -136,11 +136,8 @@ def divide_segment_into_chunks(num_frames, chunk_size): else: n = num_frames // chunk_size - frame_starts = np.arange(n) * chunk_size - frame_stops = frame_starts + chunk_size - - frame_starts = frame_starts.tolist() - frame_stops = frame_stops.tolist() + frame_starts = [i * chunk_size for i in range(n)] + frame_stops = [frame_start + chunk_size for frame_start in frame_starts] if (num_frames % chunk_size) > 0: frame_starts.append(n * chunk_size) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 0ec5449bae..77d427bc88 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -131,9 +131,12 @@ def write_binary_recording( data_size_bytes = dtype_size_bytes * num_frames * num_channels file_size_bytes = data_size_bytes + byte_offset - file = open(file_path, "wb+") - file.truncate(file_size_bytes) - file.close() + # Create an empty file with file_size_bytes + with open(file_path, "wb+") as file: + # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) + file.seek(file_size_bytes - 1) + file.write(b"\0") + assert Path(file_path).is_file() # use executor (loop or workers) @@ -888,11 +891,10 @@ def check_probe_do_not_overlap(probes): for j in range(i + 1, len(probes)): probe_j = probes[j] - if np.any( np.array( [ - x_bounds_i[0] < cp[0] < x_bounds_i[1] and y_bounds_i[0] < cp[1] < y_bounds_i[1] + x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1] for cp in probe_j.contact_positions ] ) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 49a31738e3..4961db8524 100644 
--- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,6 +11,7 @@ import shutil import warnings import importlib +from packaging.version import parse from time import perf_counter import numpy as np @@ -77,6 +78,9 @@ def create_sorting_analyzer( return_scaled : bool, default: True All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates". This prevent return_scaled being differents from different extensions and having wrong snr for instance. + overwrite: bool, default: False + If True, overwrite the folder if it already exists. + Returns ------- @@ -228,6 +232,8 @@ def __repr__(self) -> str: txt += " - sparse" if self.has_recording(): txt += " - has recording" + if self.has_temporary_recording(): + txt += " - has temporary recording" ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys()) txt += "\n" + ext_txt return txt @@ -346,7 +352,7 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): # used by create and save_as - assert recording is not None, "To create a SortingAnalyzer you need recording not None" + assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" folder = Path(folder) if folder.is_dir(): @@ -486,7 +492,11 @@ def _get_zarr_root(self, mode="r+"): if is_path_remote(str(self.folder)): mode = "r" - zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) + # we open_consolidated only if we are in read mode + if mode in ("r+", "a"): + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options) + else: + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -564,11 +574,27 @@ def create_zarr(cls, folder, sorting, 
recording, sparsity, return_scaled, rec_at recording_info = zarr_root.create_group("extensions") + zarr.consolidate_metadata(zarr_root.store) + @classmethod def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options) + zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) + + si_info = zarr_root.attrs["spikeinterface_info"] + if parse(si_info["version"]) < parse("0.101.1"): + # v0.101.0 did not have a consolidate metadata step after computing extensions. + # Here we try to consolidate the metadata and throw a warning if it fails. + try: + zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options) + zarr.consolidate_metadata(zarr_root_a.store) + except Exception as e: + warnings.warn( + "The zarr store was not properly consolidated prior to v0.101.1. " + "This may lead to unexpected behavior in loading extensions. " + "Please consider re-generating the SortingAnalyzer object." 
+ ) # load internal sorting in memory sorting = NumpySorting.from_sorting( @@ -614,6 +640,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, + storage_options=storage_options, ) return sorting_analyzer @@ -1101,9 +1128,16 @@ def get_probe(self): def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording # this give all channel locations, so no kwargs like channel_ids and axes - all_probes = self.get_probegroup().probes - all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - return all_positions + probegroup = self.get_probegroup() + probe_as_numpy_array = probegroup.to_numpy(complete=True) + # we need to sort by device_channel_indices to ensure the order of locations is correct + probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] + ndim = probegroup.ndim + locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") + # here we only loop through xy because only 2d locations are supported + for i, dim in enumerate(["x", "y"][:ndim]): + locations[:, i] = probe_as_numpy_array[dim] + return locations def channel_ids_to_indices(self, channel_ids) -> np.ndarray: all_channel_ids = list(self.rec_attributes["channel_ids"]) @@ -1189,7 +1223,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar extensions[ext_name] = ext_params self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs) else: - raise ValueError("SortingAnalyzer.compute() need str, dict or list") + raise ValueError("SortingAnalyzer.compute() needs a str, dict or list") def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs) -> "AnalyzerExtension": """ @@ -1323,7 +1357,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job for extension_name, extension_params in 
extensions_with_pipeline.items(): extension_class = get_extension_class(extension_name) - assert self.has_recording(), f"Extension {extension_name} need the recording" + assert ( + self.has_recording() or self.has_temporary_recording() + ), f"Extension {extension_name} requires the recording" for variable_name in extension_class.nodepipeline_variables: result_routage.append((extension_name, variable_name)) @@ -1462,7 +1498,7 @@ def delete_extension(self, extension_name) -> None: if self.format != "memory" and self.has_extension(extension_name): # need a reload to reset the folder ext = self.load_extension(extension_name) - ext.reset() + ext.delete() # remove from dict self.extensions.pop(extension_name, None) @@ -1571,17 +1607,17 @@ def _sort_extensions_by_dependency(extensions): def _get_children_dependencies(extension_name): """ Extension classes have a `depend_on` attribute to declare on which class they - depend. For instance "templates" depend on "waveforms". "waveforms depends on "random_spikes". + depend on. For instance "templates" depends on "waveforms". "waveforms" depends on "random_spikes". - This function is making the reverse way : get all children that depend of a + This function is going the opposite way: it finds all children that depend on a particular extension. - This is recursive so this includes : children and so grand children and great grand children + The implementation is recursive so that the output includes children, grand children, great grand children, etc. - This function is usefull for deleting on recompute. - For instance recompute the "waveforms" need to delete "template" - This make sens if "ms_before" is change in "waveforms" because the template also depends - on this parameters. + This function is useful for deleting existing extensions on recompute. + For instance, recomputing the "waveforms" needs to delete the "templates", since the latter depends on the former. 
+ For this particular example, if we change the "ms_before" parameter of the "waveforms", also the "templates" will + require recomputation as this parameter is inherited. """ names = [] children = _extension_children[extension_name] @@ -1953,12 +1989,14 @@ def load_data(self): if "dict" in ext_data_.attrs: ext_data = ext_data_[0] elif "dataframe" in ext_data_.attrs: - import xarray + import pandas as pd - ext_data = xarray.open_zarr( - ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" - ).to_pandas() - ext_data.index.rename("", inplace=True) + index = ext_data_["index"] + ext_data = pd.DataFrame(index=index) + for col in ext_data_.keys(): + if col != "index": + ext_data.loc[:, col] = ext_data_[col][:] + ext_data = ext_data.convert_dtypes() elif "object" in ext_data_.attrs: ext_data = ext_data_[0] else: @@ -2004,24 +2042,31 @@ def run(self, save=True, **kwargs): # NB: this call to _save_params() also resets the folder or zarr group self._save_params() self._save_importing_provenance() - self._save_run_info() t_start = perf_counter() self._run(**kwargs) t_end = perf_counter() self.run_info["runtime_s"] = t_end - t_start + self.run_info["run_completed"] = True if save and not self.sorting_analyzer.is_read_only(): + self._save_run_info() self._save_data(**kwargs) + if self.format == "zarr": + import zarr - self.run_info["run_completed"] = True - self._save_run_info() + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def save(self, **kwargs): self._save_params() self._save_importing_provenance() - self._save_data(**kwargs) self._save_run_info() + self._save_data(**kwargs) + + if self.format == "zarr": + import zarr + + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def _save_data(self, **kwargs): if self.format == "memory": @@ -2062,7 +2107,7 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - + import zarr import 
numcodecs extension_group = self._get_zarr_extension_group(mode="r+") @@ -2081,12 +2126,12 @@ def _save_data(self, **kwargs): elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=extension_group.store, - group=f"{extension_group.name}/{ext_data_name}", - mode="a", - ) - extension_group[ext_data_name].attrs["dataframe"] = True + df_group = extension_group.create_group(ext_data_name) + # first we save the index + df_group.create_dataset(name="index", data=ext_data.index.to_numpy()) + for col in ext_data.columns: + df_group.create_dataset(name=col, data=ext_data[col].to_numpy()) + df_group.attrs["dataframe"] = True else: # any object try: @@ -2110,8 +2155,35 @@ def _reset_extension_folder(self): elif self.format == "zarr": import zarr - zarr_root = zarr.open(self.folder, mode="r+") - extension_group = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") + _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) + zarr.consolidate_metadata(zarr_root.store) + + def _delete_extension_folder(self): + """ + Delete the extension in a folder (binary or zarr). + """ + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + if extension_folder.is_dir(): + shutil.rmtree(extension_folder) + + elif self.format == "zarr": + import zarr + + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") + if self.extension_name in zarr_root["extensions"]: + del zarr_root["extensions"][self.extension_name] + zarr.consolidate_metadata(zarr_root.store) + + def delete(self): + """ + Delete the extension from the folder or zarr and from the dict. 
+ """ + self._delete_extension_folder() + self.params = None + self.run_info = self._default_run_info_dict() + self.data = dict() def reset(self): """ @@ -2128,7 +2200,7 @@ def set_params(self, save=True, **params): Set parameters for the extension and make it persistent in json. """ - # this ensure data is also deleted and corresponf to params + # this ensure data is also deleted and corresponds to params # this also ensure the group is created self._reset_extension_folder() @@ -2141,7 +2213,6 @@ def set_params(self, save=True, **params): if save: self._save_params() self._save_importing_provenance() - self._save_run_info() def _save_params(self): params_to_save = self.params.copy() @@ -2197,9 +2268,10 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): - assert self.run_info[ - "run_completed" - ], f"You must run the extension {self.extension_name} before retrieving data" + if self.run_info is not None: + assert self.run_info[ + "run_completed" + ], f"You must run the extension {self.extension_name} before retrieving data" assert len(self.data) > 0, "Extension has been run but no data found." return self._get_data(*args, **kwargs) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index a38562ea2c..fd613e1fcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import warnings from .basesorting import BaseSorting @@ -13,28 +14,35 @@ _sparsity_doc = """ method : str * "best_channels" : N best channels with the largest amplitude. Use the "num_channels" argument to specify the - number of channels. - * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um + number of channels. + * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um. 
* "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument - to specify the SNR threshold (in units of noise levels) - * "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of noise levels) + to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. + * "amplitude" : threshold based on the amplitude values on every channel. Use the "threshold" argument + to specify the amplitude threshold (in units of amplitude) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "energy" : threshold based on the expected energy that should be present on the channels, - given their noise levels. Use the "threshold" argument to specify the SNR threshold + given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) - * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). - Use the "by_property" argument to specify the property name. + * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). + In this case the sparsity for each unit is given by the channels that have the same property + value as the unit. Use the "by_property" argument to specify the property name. + * "ptp" : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead. - peak_sign : str - Sign of the template to compute best channels ("neg", "pos", "both") + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. num_channels : int - Number of channels for "best_channels" method + Number of channels for "best_channels" method. radius_um : float - Radius in um for "radius" method + Radius in um for "radius" method. 
threshold : float - Threshold in SNR "threshold" method + Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). + For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods. by_property : object - Property name for "by_property" method + Property name for "by_property" method. """ @@ -277,18 +285,35 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): + def from_best_channels( + cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg", amplitude_mode="extremum" + ): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + num_channels : int + Number of channels for "best_channels" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. 
+ + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity """ from .template_tools import get_template_amplitudes - print(templates_or_sorting_analyzer) mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign) + peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode) for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] @@ -299,7 +324,21 @@ def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_si def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. - Use the "radius_um" argument to specify the radius in um + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + radius_um : float + Radius in um for "radius" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ from .template_tools import get_template_extremum_channel @@ -316,10 +355,38 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, peak_sign="neg"): + def from_snr( + cls, + templates_or_sorting_analyzer, + threshold, + amplitude_mode="extremum", + peak_sign="neg", + noise_levels=None, + ): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. 
Use the "threshold" argument to specify the SNR threshold. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "snr" method (in units of noise levels). + noise_levels : np.array | None, default: None + Noise levels required for the "snr" method. You can use the + `get_noise_levels()` function to compute them. + If the input is a `SortingAnalyzer`, the noise levels are automatically retrieved + if the `noise_levels` extension is present. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute amplitudes. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates for the "snr" method. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer @@ -338,13 +405,13 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None + assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode="extremum", return_scaled=return_scaled + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=return_scaled ) for unit_ind, unit_id in enumerate(unit_ids): @@ -356,38 +423,81 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ Construct sparsity from a thresholds 
based on template peak-to-peak values. - Use the "threshold" argument to specify the SNR threshold. + Use the "threshold" argument to specify the peak-to-peak threshold. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "ptp" method (in units of amplitude). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ + warnings.warn( + "The 'ptp' method is deprecated and will be removed in version 0.103.0. " + "Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.", + DeprecationWarning, + ) + return cls.from_snr( + templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels + ) - assert ( - templates_or_sorting_analyzer.sparsity is None - ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + @classmethod + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + """ + Construct sparsity from a threshold based on template amplitude. + The amplitude is computed with the specified amplitude mode and it is assumed + that the amplitude is in uV. The input `Templates` or `SortingAnalyzer` object must + have scaled templates. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "amplitude" method (in uV). + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
+ """ from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer from .template import Templates + assert ( + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + unit_ids = templates_or_sorting_analyzer.unit_ids channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] - return_scaled = templates_or_sorting_analyzer.return_scaled + assert templates_or_sorting_analyzer.return_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `return_scaled=True` when computing the templates." + ) elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None - return_scaled = templates_or_sorting_analyzer.is_scaled - - from .template_tools import get_dense_templates_array + assert templates_or_sorting_analyzer.is_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `is_scaled=True` when creating the Templates object." 
+ ) mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) - templates_ptps = np.ptp(templates_array, axis=1) + peak_values = get_template_amplitudes( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=True + ) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero((np.abs(peak_values[unit_id])) >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -396,6 +506,19 @@ def from_energy(cls, sorting_analyzer, threshold): """ Construct sparsity from a threshold based on per channel energy ratio. Use the "threshold" argument to specify the SNR threshold. + This method requires the "waveforms" and "noise_levels" extensions to be computed. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + threshold : float + Threshold for "energy" method (in units of noise levels). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
""" assert sorting_analyzer.sparsity is None, "To compute sparsity with energy you need a dense SortingAnalyzer" @@ -403,7 +526,7 @@ def from_energy(cls, sorting_analyzer, threshold): # noise_levels ext = sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + assert ext is not None, "To compute sparsity from energy you need to compute 'noise_levels' first" noise_levels = ext.data["noise_levels"] # waveforms @@ -421,51 +544,72 @@ def from_energy(cls, sorting_analyzer, threshold): return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) @classmethod - def from_property(cls, sorting_analyzer, by_property): + def from_property(cls, sorting, recording, by_property): """ Construct sparsity witha property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. + + Parameters + ---------- + sorting : Sorting + A Sorting object. + recording : Recording + A Recording object. + by_property : object + Property name for "by_property" method. Both the recording and sorting must have this property set. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
""" # check consistency - assert ( - by_property in sorting_analyzer.recording.get_property_keys() - ), f"Property {by_property} is not a recording property" - assert ( - by_property in sorting_analyzer.sorting.get_property_keys() - ), f"Property {by_property} is not a sorting property" + assert by_property in recording.get_property_keys(), f"Property {by_property} is not a recording property" + assert by_property in sorting.get_property_keys(), f"Property {by_property} is not a sorting property" - mask = np.zeros((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") - rec_by = sorting_analyzer.recording.split_by(by_property) - for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): - unit_property = sorting_analyzer.sorting.get_property(by_property)[unit_ind] + mask = np.zeros((sorting.unit_ids.size, recording.channel_ids.size), dtype="bool") + rec_by = recording.split_by(by_property) + for unit_ind, unit_id in enumerate(sorting.unit_ids): + unit_property = sorting.get_property(by_property)[unit_ind] assert ( unit_property in rec_by.keys() ), f"Unit property {unit_property} cannot be found in the recording properties" - chan_inds = sorting_analyzer.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) + chan_inds = recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) + return cls(mask, sorting.unit_ids, recording.channel_ids) @classmethod def create_dense(cls, sorting_analyzer): """ Create a sparsity object with all selected channel for all units. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + + Returns + ------- + sparsity : ChannelSparsity + The full sparsity. 
""" mask = np.ones((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) def compute_sparsity( - templates_or_sorting_analyzer, - noise_levels=None, - method="radius", - peak_sign="neg", - num_channels=5, - radius_um=100.0, - threshold=5, - by_property=None, -): + templates_or_sorting_analyzer: "Templates | SortingAnalyzer", + noise_levels: np.ndarray | None = None, + method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", + num_channels: int | None = 5, + radius_um: float | None = 100.0, + threshold: float | None = 5, + by_property: str | None = None, + amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> ChannelSparsity: """ - Get channel sparsity (subset of channels) for each template with several methods. + Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. 
Parameters ---------- @@ -491,7 +635,7 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer - if method in ("best_channels", "radius", "snr", "ptp"): + if method in ("best_channels", "radius", "snr", "amplitude", "ptp"): assert isinstance( templates_or_sorting_analyzer, (Templates, SortingAnalyzer) ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" @@ -500,11 +644,6 @@ def compute_sparsity( templates_or_sorting_analyzer, SortingAnalyzer ), f"compute_sparsity(method='{method}') need SortingAnalyzer" - if method in ("snr", "ptp") and isinstance(templates_or_sorting_analyzer, Templates): - assert ( - noise_levels is not None - ), f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" - if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) @@ -514,21 +653,36 @@ def compute_sparsity( elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( - templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, peak_sign=peak_sign - ) - elif method == "ptp": - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, + peak_sign=peak_sign, + amplitude_mode=amplitude_mode, + ) + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( + templates_or_sorting_analyzer, + threshold, + amplitude_mode=amplitude_mode, + peak_sign=peak_sign, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' 
 needs to be given" sparsity = ChannelSparsity.from_energy(templates_or_sorting_analyzer, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(templates_or_sorting_analyzer, by_property) + sparsity = ChannelSparsity.from_property( + templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property + ) + elif method == "ptp": + # TODO: remove after deprecation + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp( + templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -544,16 +698,21 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" = "radius", - peak_sign: str = "neg", + method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, + threshold: float | None = 5, + amplitude_mode: "extremum" | "peak_to_peak" = "extremum", + by_property: str | None = None, + noise_levels: np.ndarray | list | None = None, **job_kwargs, ): """ - Estimate the sparsity without needing a SortingAnalyzer or Templates object - This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it - traverses the recording to compute the average templates for each unit. + Estimate the sparsity without needing a SortingAnalyzer or Templates object. + In case the sparsity method needs templates, they are computed on-the-fly. + For the "snr" method, `noise_levels` must be passed with the `noise_levels` argument. + These can be computed with the `get_noise_levels()` function. 
Contrary to the previous implementation: * all units are computed in one read of recording @@ -561,29 +720,23 @@ def estimate_sparsity( * it doesn't consume too much memory * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. + Parameters ---------- sorting : BaseSorting The sorting recording : BaseRecording The recording - num_spikes_for_sparsity : int, default: 100 How many spikes per units to compute the sparsity ms_before : float, default: 1.0 Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels", default: "radius" - Sparsity method propagated to the `compute_sparsity()` function. - Only "radius" or "best_channels" are implemented - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute best channels - radius_um : float, default: 100.0 - Used for "radius" method - num_channels : int, default: 5 - Used for "best_channels" method - + noise_levels : np.array | None, default: None + Noise levels required for the "snr" and "energy" methods. You can use the + `get_noise_levels()` function to compute them. {} Returns @@ -594,7 +747,10 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channel'" + assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), ( + f"method={method} is not available for `estimate_sparsity()`. 
" + "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)" + ) if recording.get_probes() == 1: # standard case @@ -605,44 +761,81 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) - - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - random_spikes_indices = random_spikes_selection( - sorting, - num_samples, - method="uniform", - max_spikes_per_unit=num_spikes_for_sparsity, - margin_size=max(nbefore, nafter), - seed=2205, - ) - spikes = sorting.to_spike_vector() - spikes = spikes[random_spikes_indices] - - templates_array = estimate_templates_with_accumulator( - recording, - spikes, - sorting.unit_ids, - nbefore, - nafter, - return_scaled=False, - job_name="estimate_sparsity", - **job_kwargs, - ) - templates = Templates( - templates_array=templates_array, - sampling_frequency=recording.sampling_frequency, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=sorting.unit_ids, - probe=probe, - ) + if method != "by_property": + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_sparsity, + margin_size=max(nbefore, nafter), + seed=2205, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name="estimate_sparsity", + 
 **job_kwargs, + ) + templates = Templates( + templates_array=templates_array, + sampling_frequency=recording.sampling_frequency, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=sorting.unit_ids, + probe=probe, + ) - sparsity = compute_sparsity( - templates, method=method, peak_sign=peak_sign, num_channels=num_channels, radius_um=radius_um - ) + if method == "best_channels": + assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" + sparsity = ChannelSparsity.from_best_channels( + templates, num_channels, peak_sign=peak_sign, amplitude_mode=amplitude_mode + ) + elif method == "radius": + assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" + sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) + elif method == "snr": + assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" + assert noise_levels is not None, ( + "For the 'snr' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." + ) + sparsity = ChannelSparsity.from_snr( + templates, + threshold, + noise_levels=noise_levels, + peak_sign=peak_sign, + amplitude_mode=amplitude_mode, + ) + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( + templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign + ) + elif method == "ptp": + # TODO: remove after deprecation + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + assert noise_levels is not None, ( + "For the 'ptp' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." 
+ ) + sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels) + else: + raise ValueError(f"estimate_sparsity() method={method} does not exist") + else: + assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" + sparsity = ChannelSparsity.from_property(sorting, recording, by_property) return sparsity diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index b4d96a3391..626899ab6e 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -79,15 +79,20 @@ def _check_result_extension(sorting_analyzer, extension_name, cache_folder): ) def test_ComputeRandomSpikes(format, sparse, create_cache_folder): cache_folder = create_cache_folder + print("Creating analyzer") sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse) + print("Computing random spikes") ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205) indices = ext.data["random_spikes_indices"] assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size + print("Checking results") _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) + print("Deleting extension") sorting_analyzer.delete_extension("random_spikes") + print("Re-computing random spikes") ext = sorting_analyzer.compute("random_spikes", method="all") indices = ext.data["random_spikes_indices"] assert indices.size == len(sorting_analyzer.sorting.to_spike_vector()) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 9c354510ac..df614978ba 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -10,7 +10,7 @@ import numpy as np from numpy.testing import assert_raises -from probeinterface import 
Probe +from probeinterface import Probe, ProbeGroup, generate_linear_probe from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load_extractor, get_default_zarr_compressor from spikeinterface.core.base import BaseExtractor @@ -298,6 +298,9 @@ def test_BaseRecording(create_cache_folder): assert time_info["time_vector"] is None assert time_info["sampling_frequency"] == rec.sampling_frequency + # resetting time again should be ok + rec.reset_times() + # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) locations_3d = rec_3d.get_property("location") @@ -355,6 +358,34 @@ def test_BaseRecording(create_cache_folder): assert np.allclose(rec_u.get_traces(cast_unsigned=True), rec_i.get_traces().astype("float")) +def test_interleaved_probegroups(): + recording = generate_recording(durations=[1.0], num_channels=16) + + probe1 = generate_linear_probe(num_elec=8, ypitch=20.0) + probe2_overlap = probe1.copy() + + probegroup_overlap = ProbeGroup() + probegroup_overlap.add_probe(probe1) + probegroup_overlap.add_probe(probe2_overlap) + probegroup_overlap.set_global_device_channel_indices(np.arange(16)) + + # setting overlapping probes should raise an error + with pytest.raises(Exception): + recording.set_probegroup(probegroup_overlap) + + probe2 = probe1.copy() + probe2.move([100.0, 100.0]) + probegroup = ProbeGroup() + probegroup.add_probe(probe1) + probegroup.add_probe(probe2) + probegroup.set_global_device_channel_indices(np.random.permutation(16)) + + recording.set_probegroup(probegroup) + probegroup_set = recording.get_probegroup() + # check that the probe group is correctly set, by sorting the device channel indices + assert np.array_equal(probegroup_set.get_global_device_channel_indices()["device_channel_indices"], np.arange(16)) + + def test_rename_channels(): recording = generate_recording(durations=[1.0], num_channels=3) renamed_recording = recording.rename_channels(new_channel_ids=["a", "b", "c"]) @@ -396,4 +427,5 @@ def 
test_time_slice_with_time_vector(): if __name__ == "__main__": - test_BaseRecording() + # test_BaseRecording() + test_interleaved_probegroups() diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 3f45487f4c..5c7e267cc6 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -126,6 +126,8 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): def test_load_without_runtime_info(tmp_path, dataset): + import zarr + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_run_info" @@ -153,6 +155,7 @@ def test_load_without_runtime_info(tmp_path, dataset): root = sorting_analyzer._get_zarr_root(mode="r+") for ext in extensions: del root["extensions"][ext].attrs["run_info"] + zarr.consolidate_metadata(root.store) # should raise a warning for missing run_info with pytest.warns(UserWarning): sorting_analyzer = load_sorting_analyzer(folder, format="auto") @@ -178,6 +181,27 @@ def test_SortingAnalyzer_tmp_recording(dataset): sorting_analyzer.set_temporary_recording(recording_sliced) +def test_SortingAnalyzer_interleaved_probegroup(dataset): + from probeinterface import generate_linear_probe, ProbeGroup + + recording, sorting = dataset + num_channels = recording.get_num_channels() + probe1 = generate_linear_probe(num_elec=num_channels // 2, ypitch=20.0) + probe2 = probe1.copy() + probe2.move([100.0, 100.0]) + + probegroup = ProbeGroup() + probegroup.add_probe(probe1) + probegroup.add_probe(probe2) + probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) + + recording = recording.set_probegroup(probegroup) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + # check that locations are correct + assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations()) + + def _check_sorting_analyzers(sorting_analyzer, 
original_sorting, cache_folder): register_result_extension(DummyAnalyzerExtension) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a192d90502..ace869df8c 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -3,7 +3,7 @@ import numpy as np import json -from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, Templates +from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, get_noise_levels from spikeinterface.core.core_tools import check_json from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer @@ -86,7 +86,7 @@ def test_sparsify_waveforms(): num_active_channels = len(non_zero_indices) assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) - # Test round-trip (note that this is loosy) + # Test round-trip (note that this is lossy) unit_id = unit_ids[unit_id] non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) @@ -195,6 +195,82 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + # by_property + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="by_property", + by_property="group", + progress_bar=True, + n_jobs=1, + ) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + + # amplitude + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="amplitude", + threshold=5, + amplitude_mode="peak_to_peak", + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + + # snr: fails without noise levels + with pytest.raises(AssertionError): + sparsity = estimate_sparsity( + 
sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + # snr: works with noise levels + noise_levels = get_noise_levels(recording) + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + # ptp: just run it + print(noise_levels) + + with pytest.warns(DeprecationWarning): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + def test_compute_sparsity(): recording, sorting = get_dataset() @@ -212,9 +288,14 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") - sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) + sparsity = compute_sparsity( + sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" + ) + sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") @@ -222,7 +303,10 @@ def test_compute_sparsity(): sparsity = 
compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") - sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) + sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) if __name__ == "__main__": diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index a50a56bf85..5c7584ecd8 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -536,6 +536,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting): ext = ComputeRandomSpikes(sorting_analyzer) ext.params = dict() ext.data = dict(random_spikes_indices=random_spikes_indices) + ext.run_info = None sorting_analyzer.extensions["random_spikes"] = ext ext = ComputeWaveforms(sorting_analyzer) @@ -545,6 +546,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting): dtype=params["dtype"], ) ext.data["waveforms"] = waveforms + ext.run_info = None sorting_analyzer.extensions["waveforms"] = ext # templates saved dense @@ -559,6 +561,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting): ext.params = dict(ms_before=params["ms_before"], ms_after=params["ms_after"], operators=list(templates.keys())) for mode, arr in templates.items(): ext.data[mode] = arr + ext.run_info = None sorting_analyzer.extensions["templates"] = ext for old_name, new_name in old_extension_to_new_class_map.items(): @@ -631,6 +634,7 @@ def _read_old_waveforms_extractor_binary(folder, 
sorting): ext.set_params(**updated_params, save=False) if ext.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() + ext.run_info = None sorting_analyzer.extensions[new_name] = ext diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index acd7ebe8ad..3f73161218 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -234,7 +234,7 @@ class BlackrockSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = BlackrockSortingExtractor downloads = ["blackrock"] entities = [ - "blackrock/FileSpec2.3001.nev", + dict(file_path=local_folder / "blackrock/FileSpec2.3001.nev", sampling_frequency=30_000.0), dict(file_path=local_folder / "blackrock/blackrock_2_1/l101210-001.nev", sampling_frequency=30_000.0), ] @@ -278,8 +278,8 @@ class Spike2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): @pytest.mark.skipif( - version.parse(platform.python_version()) >= version.parse("3.10"), - reason="Sonpy only testing with Python < 3.10!", + version.parse(platform.python_version()) >= version.parse("3.10") or platform.system() == "Darwin", + reason="Sonpy only testing with Python < 3.10 and not supported on macOS!", ) class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = CedRecordingExtractor @@ -293,6 +293,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(platform.system() == "Darwin", reason="Maxwell plugin not supported on macOS") class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MaxwellRecordingExtractor downloads = ["maxwell"] diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index f1f89403c7..1871c11b85 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ 
b/src/spikeinterface/postprocessing/principal_component.py @@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) p = self.params - we = self.sorting_analyzer - sorting = we.sorting + sorting_analyzer = self.sorting_analyzer + sorting = sorting_analyzer.sorting assert ( - we.has_recording() - ), "To compute PCA projections for all spikes, the waveform extractor needs the recording" - recording = we.recording + sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording() + ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording" + recording = sorting_analyzer.recording # assert sorting.get_num_segments() == 1 assert p["mode"] in ("by_channel_local", "by_channel_global") @@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): sparsity = self.sorting_analyzer.sparsity if sparsity is None: - sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} - max_channels_per_template = we.get_num_channels() + num_channels = recording.get_num_channels() + sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids} + max_channels_per_template = num_channels else: sparse_channels_indices = sparsity.unit_id_to_channel_indices max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) @@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): return pca_models def _fit_by_channel_global(self, progress_bar): - # we = self.sorting_analyzer p = self.params - # unit_ids = we.unit_ids unit_ids = self.sorting_analyzer.unit_ids # there is one unique PCA accross channels diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 45ba55dee4..306e9594b8 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py 
+++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -62,6 +62,8 @@ class ComputeTemplateMetrics(AnalyzerExtension): For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics + delete_existing_metrics : bool, default: False + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. metrics_kwargs : dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 @@ -109,9 +111,12 @@ def _set_params( sparsity=None, metrics_kwargs=None, include_multi_channel_metrics=False, + delete_existing_metrics=False, **other_kwargs, ): + import pandas as pd + # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() if include_multi_channel_metrics or ( metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) @@ -139,12 +144,36 @@ def _set_params( metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) + metrics_to_compute = metric_names + tm_extension = self.sorting_analyzer.get_extension("template_metrics") + if delete_existing_metrics is False and tm_extension is not None: + + existing_params = tm_extension.params["metrics_kwargs"] + # checks that existing metrics were calculated using the same params + if existing_params != metrics_kwargs_: + warnings.warn( + f"The parameters used to calculate the previous template metrics are different" + f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " + f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..." 
+ ) + tm_extension.params["metric_names"] = [] + existing_metric_names = [] + else: + existing_metric_names = tm_extension.params["metric_names"] + + existing_metric_names_propogated = [ + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute + ] + metric_names = metrics_to_compute + existing_metric_names_propogated + params = dict( - metric_names=[str(name) for name in np.unique(metric_names)], + metric_names=metric_names, sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), metrics_kwargs=metrics_kwargs_, + delete_existing_metrics=delete_existing_metrics, + metrics_to_compute=metrics_to_compute, ) return params @@ -158,6 +187,7 @@ def _merge_extension_data( ): import pandas as pd + metric_names = self.params["metric_names"] old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -166,19 +196,20 @@ def _merge_extension_data( metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + metrics.loc[new_unit_ids, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs + ) new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute template metrics. 
""" import pandas as pd from scipy.signal import resample_poly - metric_names = self.params["metric_names"] sparsity = self.params["sparsity"] peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] @@ -287,13 +318,37 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") value = np.nan template_metrics.at[index, metric_name] = value + + # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns + # (in case of NaN values) + template_metrics = template_metrics.convert_dtypes() return template_metrics def _run(self, verbose=False): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose + + delete_existing_metrics = self.params["delete_existing_metrics"] + metrics_to_compute = self.params["metrics_to_compute"] + + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute ) + existing_metrics = [] + tm_extension = self.sorting_analyzer.get_extension("template_metrics") + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): + existing_metrics = tm_extension.params["metric_names"] + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metrics_to_compute): + computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + + self.data["metrics"] = computed_metrics + def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 3945e71881..2207b98da6 100644 --- 
a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -5,7 +5,7 @@ import numpy as np from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer from spikeinterface.core import estimate_sparsity @@ -116,6 +116,8 @@ def _check_one(self, sorting_analyzer, extension_class, params): with the passed parameters, and check the output is not empty, the extension exists and `select_units()` method works. """ + import pandas as pd + if extension_class.need_job_kwargs: job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) else: @@ -138,6 +140,26 @@ def _check_one(self, sorting_analyzer, extension_class, params): merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0) assert len(merged.unit_ids) == num_units_after_merge + # test roundtrip + if sorting_analyzer.format in ("binary_folder", "zarr"): + sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder) + ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name) + for ext_data_name, ext_data_loaded in ext_loaded.data.items(): + if isinstance(ext_data_loaded, np.ndarray): + assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) + elif isinstance(ext_data_loaded, pd.DataFrame): + # skip nan values + for col in ext_data_loaded.columns: + np.testing.assert_array_almost_equal( + ext.data[ext_data_name][col].dropna().to_numpy(), + ext_data_loaded[col].dropna().to_numpy(), + decimal=5, + ) + elif isinstance(ext_data_loaded, dict): + assert ext.data[ext_data_name] == ext_data_loaded + else: + continue + def run_extension_tests(self, extension_class, params): """ Convenience function to perform all checks on the extension diff --git a/src/spikeinterface/postprocessing/tests/conftest.py 
b/src/spikeinterface/postprocessing/tests/conftest.py new file mode 100644 index 0000000000..51ac8aa250 --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/conftest.py @@ -0,0 +1,33 @@ +import pytest + +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + + +def _small_sorting_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[2.0], + num_units=10, + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return _small_sorting_analyzer() diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 694aa083cc..5056d4ff2a 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,6 +1,108 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics import pytest +import csv + +from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func + +template_metrics = list(_single_channel_metric_name_to_func.keys()) + + +def test_compute_new_template_metrics(small_sorting_analyzer): + """ + Computes template metrics then computes a subset of template metrics, and checks + that the old template metrics are not deleted. + + Then computes template metrics with new parameters and checks that old metrics + are deleted. 
+ """ + + # calculate just exp_decay + small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) + template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + + assert "exp_decay" in list(template_metric_extension.get_data().keys()) + assert "half_width" not in list(template_metric_extension.get_data().keys()) + + # calculate all template metrics + small_sorting_analyzer.compute("template_metrics") + # calculate just exp_decay - this should not delete any other metrics + small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) + template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + + set(template_metrics) == set(template_metric_extension.get_data().keys()) + + # calculate just exp_decay with delete_existing_metrics + small_sorting_analyzer.compute( + {"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}} + ) + template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + computed_metric_names = template_metric_extension.get_data().keys() + + for metric_name in template_metrics: + if metric_name == "exp_decay": + assert metric_name in computed_metric_names + else: + assert metric_name not in computed_metric_names + + # check that, when parameters are changed, the old metrics are deleted + small_sorting_analyzer.compute( + {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + ) + + +def test_metric_names_in_same_order(small_sorting_analyzer): + """ + Computes sepecified template metrics and checks order is propogated. 
+ """ + specified_metric_names = ["peak_trough_ratio", "num_negative_peaks", "half_width"] + small_sorting_analyzer.compute("template_metrics", metric_names=specified_metric_names) + tm_keys = small_sorting_analyzer.get_extension("template_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == tm_keys[i] + + +def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes template metrics in binary folder format. Then computes subsets of template + metrics and checks if they are saved correctly. + """ + + small_sorting_analyzer.compute("template_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv" + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in template_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in template_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in template_metrics: + if metric_name == "half_width": + assert metric_name in metric_names + else: + assert metric_name not in metric_names class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): diff --git a/src/spikeinterface/preprocessing/motion.py 
b/src/spikeinterface/preprocessing/motion.py index ddb981a944..14c565a290 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -13,6 +13,7 @@ from spikeinterface.core.core_tools import SIJsonEncoder from spikeinterface.core.job_tools import _shared_job_kwargs_doc + motion_options_preset = { # dredge "dredge": { @@ -277,10 +278,11 @@ def correct_motion( This function depends on several modular components of :py:mod:`spikeinterface.sortingcomponents`. - If select_kwargs is None then all peak are used for localized. + If `select_kwargs` is None then all peak are used for localized. The recording must be preprocessed (filter and denoised at least), and we recommend to not use whithening before motion estimation. + Since the motion interpolation requires a "float" recording, the recording is casted to float32 if necessary. Parameters for each step are handled as separate dictionaries. For more information please check the documentation of the following functions: @@ -435,6 +437,8 @@ def correct_motion( t1 = time.perf_counter() run_times["estimate_motion"] = t1 - t0 + if recording.get_dtype().kind != "f": + recording = recording.astype("float32") recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) motion_info = dict( diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 2de31ad750..8dfd41cf88 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -69,6 +69,9 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): return num_spikes +_default_params["num_spikes"] = {} + + def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. 
@@ -98,6 +101,9 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): return firing_rates +_default_params["firing_rate"] = {} + + def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -1550,3 +1556,10 @@ def compute_sd_ratio( sd_ratio[unit_id] = unit_std / std_noise return sd_ratio + + +_default_params["sd_ratio"] = dict( + censored_period_ms=4.0, + correct_for_drift=True, + correct_for_template_itself=True, +) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index cdf6151e95..3b6c6d3e50 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -2,7 +2,6 @@ from __future__ import annotations - import warnings from copy import deepcopy @@ -12,7 +11,12 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names +from .quality_metric_list import ( + compute_pc_metrics, + _misc_metric_name_to_func, + _possible_pc_metric_names, + compute_name_to_column_names, +) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -30,8 +34,10 @@ class ComputeQualityMetrics(AnalyzerExtension): qm_params : dict or None Dictionary with parameters for quality metrics calculation. Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` - skip_pc_metrics : bool + skip_pc_metrics : bool, default: False If True, PC metrics computation is skipped. + delete_existing_metrics : bool, default: False + If True, any quality metrics attached to the `sorting_analyzer` are deleted. 
If False, any metrics which were previously calculated but are not included in `metric_names` are kept. Returns ------- @@ -49,7 +55,17 @@ class ComputeQualityMetrics(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True - def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): + def _set_params( + self, + metric_names=None, + qm_params=None, + peak_sign=None, + seed=None, + skip_pc_metrics=False, + delete_existing_metrics=False, + metrics_to_compute=None, + ): + if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list @@ -71,12 +87,24 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign + metrics_to_compute = metric_names + qm_extension = self.sorting_analyzer.get_extension("quality_metrics") + if delete_existing_metrics is False and qm_extension is not None: + + existing_metric_names = qm_extension.params["metric_names"] + existing_metric_names_propogated = [ + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute + ] + metric_names = metrics_to_compute + existing_metric_names_propogated + params = dict( - metric_names=[str(name) for name in np.unique(metric_names)], + metric_names=metric_names, peak_sign=peak_sign, seed=seed, qm_params=qm_params_, skip_pc_metrics=skip_pc_metrics, + delete_existing_metrics=delete_existing_metrics, + metrics_to_compute=metrics_to_compute, ) return params @@ -91,6 +119,7 @@ def _merge_extension_data( ): import pandas as pd + metric_names = self.params["metric_names"] old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -99,16 +128,19 @@ def _merge_extension_data( metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) metrics.loc[not_new_ids, :] = 
old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + metrics.loc[new_unit_ids, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs + ) new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute quality metrics. """ - metric_names = self.params["metric_names"] + import pandas as pd + qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -132,8 +164,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job non_empty_unit_ids = unit_ids empty_unit_ids = [] - import pandas as pd - metrics = pd.DataFrame(index=unit_ids) # simple metrics not based on PCs @@ -185,13 +215,41 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan + # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns + # (in case of NaN values) + metrics = metrics.convert_dtypes() return metrics def _run(self, verbose=False, **job_kwargs): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs + + metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] + + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, + unit_ids=None, + verbose=verbose, + metric_names=metrics_to_compute, + **job_kwargs, ) + existing_metrics = [] + qm_extension = self.sorting_analyzer.get_extension("quality_metrics") + if ( + delete_existing_metrics is False + and qm_extension is not None + and 
qm_extension.data.get("metrics") is not None + ): + existing_metrics = qm_extension.params["metric_names"] + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metrics_to_compute): + # some metrics names produce data columns with other names. This deals with that. + for column_name in compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = qm_extension.data["metrics"][column_name] + + self.data["metrics"] = computed_metrics + def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 140ad87a8b..375dd320ae 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -53,3 +53,29 @@ "drift": compute_drift_metrics, "sd_ratio": compute_sd_ratio, } + +# a dict converting the name of the metric for computation to the output of that computation +compute_name_to_column_names = { + "num_spikes": ["num_spikes"], + "firing_rate": ["firing_rate"], + "presence_ratio": ["presence_ratio"], + "snr": ["snr"], + "isi_violation": ["isi_violations_ratio", "isi_violations_count"], + "rp_violation": ["rp_violations", "rp_contamination"], + "sliding_rp_violation": ["sliding_rp_violation"], + "amplitude_cutoff": ["amplitude_cutoff"], + "amplitude_median": ["amplitude_median"], + "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"], + "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"], + "firing_range": ["firing_range"], + "drift": ["drift_ptp", "drift_std", "drift_mad"], + "sd_ratio": ["sd_ratio"], + "isolation_distance": ["isolation_distance"], + "l_ratio": ["l_ratio"], + "d_prime": ["d_prime"], + "nearest_neighbor": ["nn_hit_rate", "nn_miss_rate"], + "nn_isolation": ["nn_isolation", "nn_unit_id"], + "nn_noise_overlap": ["nn_noise_overlap"], + "silhouette": ["silhouette"], + "silhouette_full": 
["silhouette_full"], +} diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index bb2a345340..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,8 +5,11 @@ create_sorting_analyzer, ) +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _small_sorting_analyzer(): + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, @@ -33,5 +36,36 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") -def small_sorting_analyzer(): - return _small_sorting_analyzer() +def sorting_analyzer_simple(): + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + + return sorting_analyzer diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index bb222200e9..4c0890b62b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ 
b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -1,6 +1,8 @@ import pytest from pathlib import Path import numpy as np +from copy import deepcopy +import csv from spikeinterface.core import ( NumpySorting, synthetize_spike_train_bad_isi, @@ -41,12 +43,167 @@ compute_quality_metrics, ) + from spikeinterface.core.basesorting import minimum_spike_dtype job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +def test_compute_new_quality_metrics(small_sorting_analyzer): + """ + Computes quality metrics then computes a subset of quality metrics, and checks + that the old quality metrics are not deleted. + """ + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "firing_range": {"bin_size_s": 1}, + } + + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + qm_extension = small_sorting_analyzer.get_extension("quality_metrics") + calculated_metrics = list(qm_extension.get_data().keys()) + + assert calculated_metrics == ["snr"] + + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + ) + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + + quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + + # Check old metrics are not deleted and the new one is added to the data and metadata + assert set(list(quality_metric_extension.get_data().keys())) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + assert set(list(quality_metric_extension.params.get("metric_names"))) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + + # check that, when parameters are changed, the data and metadata are updated + old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": ["snr"], 
"qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + ) + new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + new_snr_data = new_quality_metric_extension.get_data()["snr"].values + + assert np.all(old_snr_data != new_snr_data) + assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + + # check that all quality metrics are deleted when parents are recomputed, even after + # recomputation + extensions_to_compute = { + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + small_sorting_analyzer.compute(extensions_to_compute) + + assert small_sorting_analyzer.get_extension("quality_metrics") is None + + +def test_metric_names_in_same_order(small_sorting_analyzer): + """ + Computes sepecified quality metrics and checks order is propogated. + """ + specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] + small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) + qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == qm_keys[i] + + +def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes quality metrics in binary folder format. Then computes subsets of quality + metrics and checks if they are saved correctly. + """ + + # can't use _misc_metric_name_to_func as some functions compute several qms + # e.g. 
isi_violation and synchrony + quality_metrics = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violations_ratio", + "isi_violations_count", + "rp_contamination", + "rp_violations", + "sliding_rp_violation", + "amplitude_cutoff", + "amplitude_median", + "amplitude_cv_median", + "amplitude_cv_range", + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + "firing_range", + "drift_ptp", + "drift_std", + "drift_mad", + "sd_ratio", + "isolation_distance", + "l_ratio", + "d_prime", + "silhouette", + "nn_hit_rate", + "nn_miss_rate", + ] + + small_sorting_analyzer.compute("quality_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + if metric_name == "snr": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + def test_unit_structure_in_output(small_sorting_analyzer): qm_params = { @@ -129,40 +286,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): 
small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params ) - for metric, metric_1_data in quality_metrics_1.items(): - assert quality_metrics_2[metric][2] == metric_1_data["#3"] - assert quality_metrics_2[metric][7] == metric_1_data["#9"] - assert quality_metrics_2[metric][1] == metric_1_data["#4"] - - -def _sorting_analyzer_simple(): - recording, sorting = generate_ground_truth_recording( - durations=[ - 50.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=2205, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() + for metric, metric_2_data in quality_metrics_2.items(): + assert quality_metrics_1[metric]["#3"] == metric_2_data[2] + assert quality_metrics_1[metric]["#9"] == metric_2_data[7] + assert quality_metrics_1[metric]["#4"] == metric_2_data[1] def _sorting_violation(): @@ -576,6 +703,7 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): test_unit_structure_in_output(_small_sorting_analyzer()) # test_calculate_firing_rate_num_spikes(sorting_analyzer) + # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py 
b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 28869ba5ff..a6415c58e8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -2,7 +2,6 @@ from pathlib import Path import numpy as np - from spikeinterface.core import ( generate_ground_truth_recording, create_sorting_analyzer, @@ -15,54 +14,11 @@ compute_quality_metrics, ) - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def get_sorting_analyzer(seed=2205): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=seed, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - print(sorting_analyzer) # without PCs metrics = compute_quality_metrics( diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 
e73ac2cb6c..2a9fb34267 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -179,7 +179,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): write_prb(probe_filename, pg) if params["use_binary_file"]: - if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + if not recording.binary_compatible_with(time_axis=0, file_paths_length=1): # local copy needed binary_file_path = sorter_output_folder / "recording.dat" write_binary_recording( @@ -235,7 +235,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe_name = "" if params["use_binary_file"] is None: - if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + if recording.binary_compatible_with(time_axis=0, file_paths_length=1): # no copy binary_description = recording.get_binary_description() filename = str(binary_description["file_paths"][0]) @@ -247,7 +247,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): file_object = RecordingExtractorAsArray(recording_extractor=recording) elif params["use_binary_file"]: # here we force the use of a binary file - if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + if recording.binary_compatible_with(time_axis=0, file_paths_length=1): # no copy binary_description = recording.get_binary_description() filename = str(binary_description["file_paths"][0]) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 95d8d3badc..2aff9d296f 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -127,7 +127,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): skip_kilosort_preprocessing = params.get("skip_kilosort_preprocessing", False) if ( - recording.binary_compatible_with(dtype="int16", time_axis=0, file_paths_lenght=1) + 
recording.binary_compatible_with(dtype="int16", time_axis=0, file_paths_length=1) and not skip_kilosort_preprocessing ): # no copy diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4701d76012..c3b3099535 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index aa9b16bb97..71a5f282a8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -27,7 +27,7 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity(gt_templates, noise_levels, method="ptp", threshold=0.25) + sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 2bacf36ac9..b08ee4d9cb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ 
b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -50,7 +50,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, "recursive_depth": 3, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 77d47aec16..f7ca999d53 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,7 +45,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "radius_um": 30, "nb_projections": 10, "feature": "energy", diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 2fbd0e31eb..813e7d7b63 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -235,6 +235,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): values = check_json(metrics.loc[unit_id].to_dict()) values_skip_nans = {} for k, v in values.items(): + # convert_dypes returns NaN as None or np.nan (for float) + if v is None: + continue if np.isnan(v): continue values_skip_nans[k] = v diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 81cda212b2..42e9a20f3c 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -200,18 +200,11 @@ def __init__( if peak_amplitudes is not None: peak_amplitudes = peak_amplitudes[peak_mask] - if recording is not None: - sampling_frequency = recording.sampling_frequency - times = 
recording.get_times(segment_index=segment_index) - else: - times = None - plot_data = dict( peaks=peaks, peak_locations=peak_locations, peak_amplitudes=peak_amplitudes, direction=direction, - times=times, sampling_frequency=sampling_frequency, segment_index=segment_index, depth_lim=depth_lim, @@ -238,10 +231,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - if dp.times is None: + if dp.recording is None: peak_times = dp.peaks["sample_index"] / dp.sampling_frequency else: - peak_times = dp.times[dp.peaks["sample_index"]] + peak_times = dp.recording.sample_index_to_time(dp.peaks["sample_index"], segment_index=dp.segment_index) peak_locs = dp.peak_locations[dp.direction] if dp.scatter_decimate is not None: @@ -340,12 +333,12 @@ def __init__( raise ValueError( "plot drift map : the Motion object is multi-segment you must provide segment_index=XX" ) - - times = recording.get_times() if recording is not None else None + assert recording.get_num_segments() == len( + motion.displacement + ), "The number of segments in the recording must be the same as the number of segments in the motion object" plot_data = dict( sampling_frequency=motion_info["parameters"]["sampling_frequency"], - times=times, segment_index=segment_index, depth_lim=depth_lim, motion_lim=motion_lim, diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index debcd52085..80f58f5ad9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,7 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes"]), + quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), diff --git 
a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 0b2a348edf..755e60ccbf 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections import defaultdict import numpy as np @@ -17,7 +18,7 @@ class UnitSummaryWidget(BaseWidget): """ Plot a unit summary. - If amplitudes are alreday computed they are displayed. + If amplitudes are already computed, they are displayed. Parameters ---------- @@ -30,6 +31,14 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored + subwidget_kwargs : dict or None, default: None + Parameters for the subwidgets in a nested dictionary + unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details) + unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details) + unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes : AmplitudesWidget (see AmplitudesWidget for details) + Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary.
""" # possible_backends = {} @@ -40,21 +49,29 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - radius_um=100, + subwidget_kwargs=None, backend=None, **backend_kwargs, ): - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) + if subwidget_kwargs is None: + subwidget_kwargs = dict() + for kwargs in subwidget_kwargs.values(): + if "unit_colors" in kwargs: + raise ValueError( + "unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary" + ) + plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, + subwidget_kwargs=subwidget_kwargs, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -70,6 +87,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors = dp.unit_colors sparsity = dp.sparsity + # defaultdict returns empty dict if key not found in subwidget_kwargs + subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs) + unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"] + unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"] + unitwaveformdensitymapwidget_kwargs = subwidget_kwargs["unit_waveform_density_map"] + autocorrelogramswidget_kwargs = subwidget_kwargs["autocorrelograms"] + amplitudeswidget_kwargs = subwidget_kwargs["amplitudes"] + # force the figure without axes if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (18, 7) @@ -99,6 +124,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, + **unitlocationswidget_kwargs, ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -121,6 +147,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, + **unitwaveformswidget_kwargs, ) ax2.set_title(None) @@ -134,6 +161,7 @@ def 
plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, + **unitwaveformdensitymapwidget_kwargs, ) ax3.set_ylabel(None) @@ -145,6 +173,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, + **autocorrelogramswidget_kwargs, ) ax4.set_title(None) @@ -162,6 +191,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, + **amplitudeswidget_kwargs, ) fig.suptitle(f"unit_id: {dp.unit_id}") diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 7a9dc47826..a6cc562ba2 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -106,9 +106,9 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() else: warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") continue @@ -137,16 +137,17 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() - # Check for NaN values + # Check for NaN values and round floats val0 = np.array(property_values[0]) if val0.dtype.kind == "f": if np.isnan(property_values[ui]): continue - values[prop_name] = np.format_float_positional(property_values[ui], precision=4, fractional=False) + property_values[ui] = 
np.format_float_positional(property_values[ui], precision=4, fractional=False) + values[prop_name] = property_values[ui] ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores)