diff --git a/.github/run_tests.sh b/.github/run_tests.sh index 558e0b64d3..02eb6ab8a1 100644 --- a/.github/run_tests.sh +++ b/.github/run_tests.sh @@ -10,5 +10,5 @@ fi pytest -m "$MARKER" -vv -ra --durations=0 --durations-min=0.001 | tee report.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of ${MARKER}" >> $GITHUB_STEP_SUMMARY -python $GITHUB_WORKSPACE/.github/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY +python $GITHUB_WORKSPACE/.github/scripts/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY rm report.txt diff --git a/.github/scripts/README.MD b/.github/scripts/README.MD new file mode 100644 index 0000000000..1d3a622aae --- /dev/null +++ b/.github/scripts/README.MD @@ -0,0 +1,2 @@ +This folder contains test scripts for running in the CI, that are not run as part of the usual +CI because they are too long / heavy. These are run on cron-jobs once per week. diff --git a/.github/build_job_summary.py b/.github/scripts/build_job_summary.py similarity index 100% rename from .github/build_job_summary.py rename to .github/scripts/build_job_summary.py diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py new file mode 100644 index 0000000000..7a6368f3cf --- /dev/null +++ b/.github/scripts/check_kilosort4_releases.py @@ -0,0 +1,30 @@ +import os +import re +from pathlib import Path +import requests +import json +from packaging.version import parse +import spikeinterface + +def get_pypi_versions(package_name): + """ + Make an API call to pypi to retrieve all + available versions of the kilosort package. + """ + url = f"https://pypi.org/pypi/{package_name}/json" + response = requests.get(url) + response.raise_for_status() + data = response.json() + versions = list(sorted(data["releases"].keys())) + # Filter out versions that are less than 4.0.16 + versions = [ver for ver in versions if parse(ver) >= parse("4.0.16")] + return versions + + +if __name__ == "__main__": + # Get all KS4 versions from pipi and write to file. + package_name = "kilosort" + versions = get_pypi_versions(package_name) + with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + print(versions) + json.dump(versions, f) diff --git a/.github/determine_testing_environment.py b/.github/scripts/determine_testing_environment.py similarity index 100% rename from .github/determine_testing_environment.py rename to .github/scripts/determine_testing_environment.py diff --git a/.github/import_test.py b/.github/scripts/import_test.py similarity index 100% rename from .github/import_test.py rename to .github/scripts/import_test.py diff --git a/.github/scripts/kilosort4-latest-version.json b/.github/scripts/kilosort4-latest-version.json new file mode 100644 index 0000000000..03629ff842 --- /dev/null +++ b/.github/scripts/kilosort4-latest-version.json @@ -0,0 +1 @@ +["4.0.10", "4.0.11", "4.0.12", "4.0.5", "4.0.6", "4.0.7", "4.0.8", "4.0.9"] diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py new file mode 100644 index 0000000000..6eeb71f1dd --- /dev/null +++ b/.github/scripts/test_kilosort4_ci.py @@ -0,0 +1,644 @@ +""" +This file tests the SpikeInterface wrapper of the Kilosort4. The general logic +of the tests are: +- Change every exposed parameter one at a time (PARAMS_TO_TEST). Check that + the result of the SpikeInterface wrapper and Kilosort run natively are + the same. The SpikeInterface wrapper is non-trivial and decomposes the + kilosort pipeline to allow additions such as skipping preprocessing. Therefore, + the idea is that is it safer to rely on the output directly rather than + try monkeypatching. One thing can could be better tested is parameter + changes when skipping KS4 preprocessing is true, because this takes a slightly + different path through the kilosort4.py wrapper logic. + This also checks that changing the parameter changes the test output from default + on our test case (otherwise, the test could not detect a failure). + +- Test that kilosort functions called from `kilosort4.py` wrapper have the expected + input signatures + +- Do some tests to check all KS4 parameters are tested against. +""" + +import pytest +import copy +from typing import Any +from inspect import signature + +import numpy as np +import torch + +import spikeinterface.full as si +from spikeinterface.core.testing import check_sortings_equal +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter +from probeinterface.io import write_prb + +import kilosort +from kilosort.parameters import DEFAULT_SETTINGS +from kilosort.run_kilosort import ( + set_files, + initialize_ops, + compute_preprocessing, + compute_drift_correction, + detect_spikes, + cluster_spikes, + save_sorting, + get_run_parameters, +) +from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered + + +RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] +# "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. + +# Setup Params to test #### +PARAMS_TO_TEST_DICT = { + "nblocks": 0, + "do_CAR": False, + "batch_size": 42743, + "Th_universal": 12, + "Th_learned": 14, + "invert_sign": True, + "nt": 93, + "nskip": 1, + "whitening_range": 16, + "highpass_cutoff": 200, + "sig_interp": 5, + "nt0min": 25, + "dmin": 15, + "dminx": 16, + "min_template_size": 15, + "template_sizes": 10, + "nearest_chans": 8, + "nearest_templates": 35, + "max_channel_distance": 5, + "templates_from_data": False, + "n_templates": 10, + "n_pcs": 3, + "Th_single_ch": 4, + "x_centers": 5, + "binning_depth": 1, + "drift_smoothing": [250, 250, 250], + "artifact_threshold": 200, + "ccg_threshold": 1e12, + "acg_threshold": 1e12, + "cluster_downsampling": 2, + "duplicate_spike_ms": 0.3, +} + +PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys()) + +PARAMETERS_NOT_AFFECTING_RESULTS = [ + "artifact_threshold", + "ccg_threshold", + "acg_threshold", + "cluster_downsampling", + "cluster_pcs", + "duplicate_spike_ms", # this is because ground-truth spikes don't have violations +] + +# THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST +# if parse(version("kilosort")) >= parse("4.0.X"): +# PARAMS_TO_TEST_DICT.update( +# [ +# {"new_param": new_value}, +# ] +# ) + + +class TestKilosort4Long: + # Fixtures ###### + @pytest.fixture(scope="session") + def recording_and_paths(self, tmp_path_factory): + """ + Create a ground-truth recording, and save it to binary + so KS4 can run on it. Fixture is set up once and shared between + all tests. + """ + tmp_path = tmp_path_factory.mktemp("kilosort4_tests") + + recording = self._get_ground_truth_recording() + + paths = self._save_ground_truth_recording(recording, tmp_path) + + return (recording, paths) + + @pytest.fixture(scope="session") + def default_kilosort_sorting(self, recording_and_paths): + """ + Because we check each parameter at a time and check the + KS4 and SpikeInterface versions match, if changing the parameter + had no effect as compared to default then the test would not test + anything. Therefore, the default results are run once and stored, + to check changing params indeed changes the results during testing. + This is possibly for nearly all parameters. + """ + recording, paths = recording_and_paths + + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, None, None) + + defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=defaults_ks_output_dir, + ) + + return si.read_kilosort(defaults_ks_output_dir) + + def _get_ground_truth_recording(self): + """ + A ground truth recording chosen to be as small as possible (for speed). + But contain enough information so that changing most parameters + changes the results. + """ + num_channels = 32 + recording, _ = si.generate_ground_truth_recording( + durations=[5], + seed=0, + num_channels=num_channels, + num_units=5, + generate_sorting_kwargs=dict(firing_rates=100, refractory_period_ms=4.0), + ) + return recording + + def _save_ground_truth_recording(self, recording, tmp_path): + """ + Save the recording and its probe to file, so it can be + loaded by KS4. + """ + paths = { + "session_scope_tmp_path": tmp_path, + "recording_path": tmp_path / "my_test_recording", + "probe_path": tmp_path / "my_test_probe.prb", + } + + recording.save(folder=paths["recording_path"], overwrite=True) + + probegroup = recording.get_probegroup() + write_prb(paths["probe_path"].as_posix(), probegroup) + + return paths + + # Tests ###### + def test_params_to_test(self): + """ + Test that all values in PARAMS_TO_TEST_DICT are + different to the default values used in Kilosort, + otherwise there is no point to the test. + """ + for param_key, param_value in PARAMS_TO_TEST_DICT.items(): + if param_key not in RUN_KILOSORT_ARGS: + assert DEFAULT_SETTINGS[param_key] != param_value, ( + f"{param_key} values should be different in test: " + f"{param_value} vs. {DEFAULT_SETTINGS[param_key]}" + ) + + def test_default_settings_all_represented(self): + """ + Test that every entry in DEFAULT_SETTINGS is tested in + PARAMS_TO_TEST, otherwise we are missing settings added + on the KS side. + """ + tested_keys = PARAMS_TO_TEST + additional_non_tested_keys = ["shift", "scale", "save_preprocessed_copy"] + tested_keys += additional_non_tested_keys + + for param_key in DEFAULT_SETTINGS: + if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: + assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + + def test_spikeinterface_defaults_against_kilsort(self): + """ + Here check that all _ + Don't check that every default in KS is exposed in params, + because they change across versions. Instead, this check + is performed here against PARAMS_TO_TEST. + """ + params = copy.deepcopy(Kilosort4Sorter._default_params) + + for key in params.keys(): + # "artifact threshold" is set to `np.inf` if `None` in + # the body of the `Kilosort4Sorter` class. + if key in DEFAULT_SETTINGS and key not in ["artifact_threshold"]: + assert params[key] == DEFAULT_SETTINGS[key], f"{key} is not the same across versions." + + # Testing Arguments ### + def test_set_files_arguments(self): + self._check_arguments( + set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"] + ) + + def test_initialize_ops_arguments(self): + expected_arguments = [ + "settings", + "probe", + "data_dtype", + "do_CAR", + "invert_sign", + "device", + "save_preprocessed_copy", + ] + + self._check_arguments( + initialize_ops, + expected_arguments, + ) + + def test_compute_preprocessing_arguments(self): + self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"]) + + def test_compute_drift_location_arguments(self): + self._check_arguments( + compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"] + ) + + def test_detect_spikes_arguments(self): + self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) + + def test_cluster_spikes_arguments(self): + self._check_arguments( + cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"] + ) + + def test_save_sorting_arguments(self): + expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] + + expected_arguments.append("save_preprocessed_copy") + + self._check_arguments(save_sorting, expected_arguments) + + def test_get_run_parameters(self): + self._check_arguments(get_run_parameters, ["ops"]) + + def test_load_probe_parameters(self): + self._check_arguments(load_probe, ["probe_path"]) + + def test_recording_extractor_as_array_arguments(self): + self._check_arguments(RecordingExtractorAsArray, ["recording_extractor"]) + + def test_binary_filtered_arguments(self): + expected_arguments = [ + "filename", + "n_chan_bin", + "fs", + "NT", + "nt", + "nt0min", + "chan_map", + "hp_filter", + "whiten_mat", + "dshift", + "device", + "do_CAR", + "artifact_threshold", + "invert_sign", + "dtype", + "tmin", + "tmax", + "shift", + "scale", + "file_object", + ] + + self._check_arguments(BinaryFiltered, expected_arguments) + + def _check_arguments(self, object_, expected_arguments): + """ + Check that the argument signature of `object_` is as expected + (i.e. has not changed across kilosort versions). + """ + sig = signature(object_) + obj_arguments = list(sig.parameters.keys()) + assert expected_arguments == obj_arguments + + # Full Test #### + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) + def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp_path, parameter): + """ + Given a recording, paths to raw data, and a parameter to change, + run KS4 natively and within the SpikeInterface wrapper with the + new parameter value (all other values default) and + check the outputs are the same. + """ + recording, paths = recording_and_paths + param_key = parameter + param_value = PARAMS_TO_TEST_DICT[param_key] + + # Setup parameters for KS4 and run it natively + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings( + recording, paths, param_key, param_value + ) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + **run_kilosort_kwargs, + ) + sorting_ks4 = si.read_kilosort(kilosort_output_dir) + + # Setup Parameters for SI and KS4 through SI + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + + sorting_si = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + **spikeinterface_settings, + ) + + # Get the results and check they match + check_sortings_equal(sorting_ks4, sorting_si) + + # Check the ops file in KS4 output is as expected. This is saved on the + # SI side so not an extremely robust addition, but it can't hurt. + ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) + ops = ops.tolist() # strangely this makes a dict + assert ops[param_key] == param_value + + # Finally, check out test parameters actually change the output of + # KS4, ensuring our tests are actually doing something (exxcept for some params). + if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS: + with pytest.raises(AssertionError): + check_sortings_equal(default_kilosort_sorting, sorting_si) + + def test_clear_cache(self,recording_and_paths, tmp_path): + """ + Test clear_cache parameter in kilosort4.run_kilosort + """ + recording, paths = recording_and_paths + + spikeinterface_output_dir = tmp_path / "spikeinterface_output_clear" + sorting_si_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=True + ) + spikeinterface_output_dir = tmp_path / "spikeinterface_output_no_clear" + sorting_si_no_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=False + ) + check_sortings_equal(sorting_si_clear, sorting_si_no_clear) + + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): + """ + Test the SpikeInterface wrappers `do_correction` argument. We set + `nblocks=0` for KS4 native, turning off motion correction. Then + we run KS$ through SpikeInterface with `do_correction=False` but + `nblocks=1` (KS4 default) - checking that `do_correction` overrides + this and the result matches KS4 when run without motion correction. + """ + recording, paths = recording_and_paths + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "nblocks", 0) + + kilosort.run_kilosort( + settings=settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=True, + ) + sorting_ks = si.read_kilosort(kilosort_output_dir) + + spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1) + sorting_si = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_correction=False, + **spikeinterface_settings, + ) + check_sortings_equal(sorting_ks, sorting_si) + + def test_use_binary_file(self, tmp_path): + """ + Test that the SpikeInterface wrapper can run KS4 using a binary file as input or directly + from the recording. + """ + recording = self._get_ground_truth_recording() + recording_bin = recording.save() + + # run with SI wrapper + sorting_ks4 = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_si_wrapper_default", + use_binary_file=None, + remove_existing_folder=True, + ) + sorting_ks4_bin = si.run_sorter( + "kilosort4", + recording_bin, + folder=tmp_path / "ks4_output_bin_default", + use_binary_file=None, + remove_existing_folder=True, + ) + sorting_ks4_force_binary = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_force_bin", + use_binary_file=True, + remove_existing_folder=True, + ) + assert not (tmp_path / "ks4_output_force_bin" / "sorter_output" / "recording.dat").exists() + sorting_ks4_force_non_binary = si.run_sorter( + "kilosort4", + recording_bin, + folder=tmp_path / "ks4_output_force_wrapper", + use_binary_file=False, + remove_existing_folder=True, + ) + # test deleting recording.dat + sorting_ks4_force_binary_keep = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_force_bin_keep", + use_binary_file=True, + delete_recording_dat=False, + remove_existing_folder=True, + ) + assert (tmp_path / "ks4_output_force_bin_keep" / "sorter_output" / "recording.dat").exists() + + check_sortings_equal(sorting_ks4, sorting_ks4_bin) + check_sortings_equal(sorting_ks4, sorting_ks4_force_binary) + check_sortings_equal(sorting_ks4, sorting_ks4_force_non_binary) + + @pytest.mark.parametrize( + "param_to_test", + [ + ("do_CAR", False), + ("batch_size", 42743), + ("Th_learned", 14), + ("dmin", 15), + ("max_channel_distance", 5), + ("n_pcs", 3), + ], + ) + def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, param_to_test): + """ + Test that skipping KS4 preprocessing works as expected. Run + KS4 natively, monkeypatching the relevant preprocessing functions + such that preprocessing is not performed. Then run in SpikeInterface + with `skip_kilosort_preprocessing=True` and check the outputs match. + + Run with a few randomly chosen parameters to check these are propagated + under this condition. + + TODO + ---- + It would be nice to check a few additional parameters here. Screw it! + """ + param_key, param_value = param_to_test + + recording = self._get_ground_truth_recording() + + # We need to filter and whiten the recording here to KS takes forever. + # Do this in a way different to KS. + recording = si.highpass_filter(recording, 300) + recording = si.whiten(recording, mode="local", apply_mean=False) + + paths = self._save_ground_truth_recording(recording, tmp_path) + + kilosort_output_dir = tmp_path / "kilosort_output_dir" + spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" + + def monkeypatch_filter_function(self, X, ops=None, ibatch=None): + """ + This is a direct copy of the kilosort io.BinaryFiltered.filter + function, with hp_filter and whitening matrix code sections, and + comments removed. This is the easiest way to monkeypatch (tried a few approaches) + """ + if self.chan_map is not None: + X = X[self.chan_map] + + if self.invert_sign: + X = X * -1 + + X = X - X.mean(1).unsqueeze(1) + if self.do_CAR: + X = X - torch.median(X, 0)[0] + + if self.hp_filter is not None: + pass + + if self.artifact_threshold < np.inf: + if torch.any(torch.abs(X) >= self.artifact_threshold): + return torch.zeros_like(X) + + if self.whiten_mat is not None: + pass + return X + + monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function) + + ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + ks_settings["nblocks"] = 0 + + # Be explicit here and don't rely on defaults. + do_CAR = param_value if param_key == "do_CAR" else False + + kilosort.run_kilosort( + settings=ks_settings, + probe=ks_format_probe, + data_dtype="float32", + results_dir=kilosort_output_dir, + do_CAR=do_CAR, + ) + + monkeypatch.undo() + si.read_kilosort(kilosort_output_dir) + + # Now, run kilosort through spikeinterface with the same options. + spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) + spikeinterface_settings["nblocks"] = 0 + + do_CAR = False if param_key != "do_CAR" else spikeinterface_settings.pop("do_CAR") + + si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + do_CAR=do_CAR, + skip_kilosort_preprocessing=True, + **spikeinterface_settings, + ) + + # There is a very slight difference caused by the batching between load vs. + # memory file. Because in this test recordings are preprocessed, there are + # some filter edge effects that depend on the chunking in `get_traces()`. + # These are all extremely close (usually just 1 spike, 1 idx different). + results = {} + results["ks"] = {} + results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") + results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") + results["si"] = {} + results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") + results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") + assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + + ##### Helpers ###### + def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): + """ + Function to generate the settings and function inputs to run kilosort. + Note when `binning_depth` is used we need to set `nblocks` high to + get the results to change from default. + + Some settings in KS4 are passed by `settings` dict while others + are through the function, these are split here. + """ + settings = { + "data_dir": paths["recording_path"], + "n_chan_bin": recording.get_num_channels(), + "fs": recording.get_sampling_frequency(), + } + run_kilosort_kwargs = {} + + if param_key is not None: + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + settings.update({param_key: param_value}) + run_kilosort_kwargs = {} + + ks_format_probe = load_probe(paths["probe_path"]) + + return settings, run_kilosort_kwargs, ks_format_probe + + def _get_spikeinterface_settings(self, param_key, param_value): + """ + Generate settings kwargs for running KS4 in SpikeInterface. + See `_get_kilosort_native_settings()` for some details. + """ + settings = {} # copy.deepcopy(DEFAULT_SETTINGS) + + if param_key == "binning_depth": + settings.update({"nblocks": 5}) + + settings.update({param_key: param_value}) + + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: + # settings.pop(name) + + return settings diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 8317d7bec4..e12cf6805d 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -50,7 +50,7 @@ jobs: shell: bash run: | changed_files="${{ steps.changed-files.outputs.all_changed_files }}" - python .github/determine_testing_environment.py $changed_files + python .github/scripts/determine_testing_environment.py $changed_files - name: Display testing environment shell: bash @@ -139,7 +139,7 @@ jobs: - name: Test streaming extractors shell: bash - if: env.RUN_STREAMING_EXTRACTORS_TESTS + if: env.RUN_STREAMING_EXTRACTORS_TESTS == 'true' run: | pip install -e .[streaming_extractors,test_extractors] ./.github/run_tests.sh "streaming_extractors" --no-virtual-env @@ -202,7 +202,7 @@ jobs: shell: bash if: env.RUN_WIDGETS_TESTS == 'true' run: | - pip install -e .[full] + pip install -e .[full,widgets] ./.github/run_tests.sh widgets --no-virtual-env - name: Test exporters diff --git a/.github/workflows/core-test.yml b/.github/workflows/core-test.yml index a513d48f3b..1dbf0f5109 100644 --- a/.github/workflows/core-test.yml +++ b/.github/workflows/core-test.yml @@ -39,7 +39,7 @@ jobs: pip install tabulate echo "# Timing profile of core tests in ${{matrix.os}}" >> $GITHUB_STEP_SUMMARY # Outputs markdown summary to standard output - python ./.github/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY rm report.txt shell: bash # Necessary for pipeline to work on windows diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index ab4a083ae1..6a222f5e25 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -47,7 +47,7 @@ jobs: source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY - python ./.github/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY rm report_full.txt - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_imports.yml b/.github/workflows/test_imports.yml index d39fc37242..a2631f6eb7 100644 --- a/.github/workflows/test_imports.yml +++ b/.github/workflows/test_imports.yml @@ -34,7 +34,7 @@ jobs: echo "## OS: ${{ matrix.os }}" >> $GITHUB_STEP_SUMMARY echo "---" >> $GITHUB_STEP_SUMMARY echo "### Import times when only installing only core dependencies " >> $GITHUB_STEP_SUMMARY - python ./.github/import_test.py >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/import_test.py >> $GITHUB_STEP_SUMMARY shell: bash # Necessary for pipeline to work on windows - name: Install in full mode run: | @@ -44,5 +44,5 @@ jobs: # Add a header to separate the two profiles echo "---" >> $GITHUB_STEP_SUMMARY echo "### Import times when installing full dependencies in " >> $GITHUB_STEP_SUMMARY - python ./.github/import_test.py >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/import_test.py >> $GITHUB_STEP_SUMMARY shell: bash # Necessary for pipeline to work on windows diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml new file mode 100644 index 0000000000..42e6140917 --- /dev/null +++ b/.github/workflows/test_kilosort4.yml @@ -0,0 +1,74 @@ +name: Testing Kilosort4 + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + pull_request: + paths: + - '**/kilosort4.py' + +jobs: + versions: + # Poll Pypi for all released KS4 versions >4.0.16, save to JSON + # and store them in a matrix for the next job. + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.12 + + - name: Install dependencies + run: | + pip install requests packaging + pip install . + + - name: Fetch package versions from PyPI + run: | + python .github/scripts/check_kilosort4_releases.py + shell: bash + + - name: Set matrix data + id: set-matrix + run: | + echo "matrix=$(jq -c . < .github/scripts/kilosort4-latest-version.json)" >> $GITHUB_OUTPUT + + test: + needs: versions + name: ${{ matrix.ks_version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] + os: [ubuntu-latest] + ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install SpikeInterface + run: | + pip install -e .[test] + shell: bash + + - name: Install Kilosort + run: | + pip install kilosort==${{ matrix.ks_version }} + shell: bash + + - name: Run new kilosort4 tests + run: | + pytest .github/scripts/test_kilosort4_ci.py + shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a58e7e8f1..4c4bd68be4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black files: ^src/ diff --git a/conftest.py b/conftest.py index ce5e07b47b..5bf7d74527 100644 --- a/conftest.py +++ b/conftest.py @@ -7,6 +7,7 @@ def create_cache_folder(tmp_path_factory): cache_folder = tmp_path_factory.mktemp("cache_folder") return cache_folder + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location @@ -16,7 +17,11 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - rel_path = Path(item.fspath).relative_to(modules_location) + try: + rel_path = Path(item.fspath).relative_to(modules_location) + except: + continue + module = rel_path.parts[0] if module == "sorters": if "internal" in rel_path.parts: diff --git a/doc/api.rst b/doc/api.rst index 1966b48a37..6bb9b39091 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -73,6 +73,19 @@ Low-level .. autoclass:: ChunkRecordingExecutor + +Back-compatibility with ``WaveformExtractor`` (version < 0.101.0) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: spikeinterface.core + :noindex: + + .. autofunction:: extract_waveforms + .. autofunction:: load_waveforms + .. autofunction:: load_sorting_analyzer_or_waveforms + + + spikeinterface.extractors ------------------------- @@ -171,6 +184,7 @@ spikeinterface.preprocessing .. autofunction:: interpolate_bad_channels .. autofunction:: normalize_by_quantile .. autofunction:: notch_filter + .. autofunction:: causal_filter .. autofunction:: phase_shift .. autofunction:: rectify .. autofunction:: remove_artifacts diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index cc2b064ed0..3c30f248c8 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -208,9 +208,11 @@ For dense waveforms, sparsity can also be passed as an argument. .. code-block:: python - pc = sorting_analyzer.compute(input="principal_components", - n_components=3, - mode="by_channel_local") + pc = sorting_analyzer.compute( + input="principal_components", + n_components=3, + mode="by_channel_local" + ) For more information, see :py:func:`~spikeinterface.postprocessing.compute_principal_components` @@ -243,9 +245,7 @@ each spike. .. code-block:: python - amplitudes = sorting_analyzer.compute(input="spike_amplitudes", - peak_sign="neg", - outputs="concatenated") + amplitudes = sorting_analyzer.compute(input="spike_amplitudes", peak_sign="neg") For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes` @@ -263,15 +263,17 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), .. code-block:: python - spike_locations = sorting_analyzer.compute(input="spike_locations", - ms_before=0.5, - ms_after=0.5, - spike_retriever_kwargs=dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg" - ), - method="center_of_mass") + spike_locations = sorting_analyzer.compute( + input="spike_locations", + ms_before=0.5, + ms_after=0.5, + spike_retriever_kwargs=dict( + channel_from_template=True, + radius_um=50, + peak_sign="neg" + ), + method="center_of_mass" + ) For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_locations` @@ -329,6 +331,12 @@ Optionally, the following multi-channel metrics can be computed by setting: Visualization of template metrics. Image from `ecephys_spike_sorting `_ from the Allen Institute. + +.. code-block:: python + + tm = sorting_analyzer.compute(input="template_metrics", include_multi_channel_metrics=True) + + For more information, see :py:func:`~spikeinterface.postprocessing.compute_template_metrics` @@ -340,10 +348,12 @@ with shape (num_units, num_units, num_bins) with all correlograms for each pair .. code-block:: python - ccg = sorting_analyzer.compute(input="correlograms", - window_ms=50.0, - bin_ms=1.0, - method="auto") + ccg = sorting_analyzer.compute( + input="correlograms", + window_ms=50.0, + bin_ms=1.0, + method="auto" + ) For more information, see :py:func:`~spikeinterface.postprocessing.compute_correlograms` @@ -357,10 +367,12 @@ This extension computes the histograms of inter-spike-intervals. The computed ou .. code-block:: python - isi = sorting_analyer.compute(input="isi_histograms" - window_ms=50.0, - bin_ms=1.0, - method="auto") + isi = sorting_analyer.compute( + input="isi_histograms" + window_ms=50.0, + bin_ms=1.0, + method="auto" + ) For more information, see :py:func:`~spikeinterface.postprocessing.compute_isi_histograms` diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index 04a302c597..f119693203 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -57,7 +57,7 @@ This code snippet shows how to compute quality metrics (with or without principa # with PCs (depends on "pca" in addition to the above metrics) - qm_ext = sorting_analyzer.compute(input={"pca": dict(n_components=5, mode="by_channel_local"), + qm_ext = sorting_analyzer.compute(input={"principal_components": dict(n_components=5, mode="by_channel_local"), "quality_metrics": dict(skip_pc_metrics=False)}) metrics = qm_ext.get_data() assert 'isolation_distance' in metrics.columns diff --git a/pyproject.toml b/pyproject.toml index 67aee92d29..8309ca89fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [project] name = "spikeinterface" -version = "0.101.0" +version = "0.101.1" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.8,<4.0" +requires-python = ">=3.9,<4.0" classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", @@ -125,16 +125,16 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_extractors = [ # Functions to download data in neo test suite "pooch>=1.8.2", "datalad>=1.0.2", - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_preprocessing = [ @@ -175,17 +175,18 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ "Sphinx", - "sphinx_rtd_theme", + "sphinx_rtd_theme>=1.2", "sphinx-gallery", "sphinx-design", "numpydoc", "ipython", + "sphinxcontrib-jquery", # for notebooks in the gallery "MEArec", # Use as an example @@ -199,8 +200,8 @@ docs = [ "datalad>=1.0.2", # for release we need pypi, so this needs to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 97fb95b623..306c12d516 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -30,5 +30,5 @@ # This flag must be set to False for release # This avoids using versioning that contains ".dev0" (and this is a better choice) # This is mainly useful when using run_sorter in a container and spikeinterface install -# DEV_MODE = True -DEV_MODE = False +DEV_MODE = True +# DEV_MODE = False diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index f1d2130d38..3a39f08a7c 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -63,9 +63,9 @@ def compute_subgraphs(self): Computes subgraphs of connected components. Returns ------- - sg_object_names: list + sg_object_names : list List of sorter names for each node in the connected component subgraph - sg_units: list + sg_units : list List of unit ids for each node in the connected component subgraph """ if self.clean_graph is not None: diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 9b455e6200..574bd16093 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -13,9 +13,19 @@ class CollisionGTComparison(GroundTruthComparison): This class needs maintenance and need a bit of refactoring. - - collision_lag : float + Parameters + ---------- + gt_sorting : SortingExtractor + The first sorting for the comparison + collision_lag : float, default 2.0 Collision lag in ms. + tested_sorting : SortingExtractor + The second sorting for the comparison + nbins : int, default : 11 + Number of collision bins + **kwargs : dict + Keyword arguments for `GroundTruthComparison` + """ diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 87d0bf512b..9b5304b0a7 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -3,27 +3,25 @@ """ from __future__ import annotations - - import numpy as np -def count_matching_events(times1, times2, delta=10): +def count_matching_events(times1, times2: np.ndarray | list, delta: int = 10): """ Counts matching events. Parameters ---------- - times1: list + times1 : list List of spike train 1 frames - times2: list + times2 : list List of spike train 2 frames - delta: int + delta : int Number of frames for considering matching events Returns ------- - matching_count: int + matching_count : int Number of matching events """ times_concat = np.concatenate((times1, times2)) @@ -39,22 +37,22 @@ def count_matching_events(times1, times2, delta=10): return len(inds2) + 1 -def compute_agreement_score(num_matches, num1, num2): +def compute_agreement_score(num_matches: int, num1: int, num2: int) -> float: """ Computes agreement score. Parameters ---------- - num_matches: int + num_matches : int Number of matches - num1: int + num1 : int Number of events in spike train 1 - num2: int + num2 : int Number of events in spike train 2 Returns ------- - score: float + score : float Agreement score """ denom = num1 + num2 - num_matches @@ -71,12 +69,12 @@ def do_count_event(sorting): Parameters ---------- - sorting: SortingExtractor + sorting : SortingExtractor A sorting extractor Returns ------- - event_count: pd.Series + event_count : pd.Series Nb of spike by units. """ import pandas as pd @@ -90,14 +88,14 @@ def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, ev Parameters ---------- - times1: array + times1 : array Spike train 1 frames - all_times2: list of array + all_times2 : list of array List of spike trains from sorting 2 Returns ------- - matching_events_count: list + matching_events_count : list List of counts of matching events """ matching_event_counts = np.zeros(len(all_times2), dtype="int64") @@ -232,8 +230,8 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa And the minimum of the two results is taken. Returns ------- - matching_matrix : ndarray - A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + matching_matrix : pd.DataFrame + A 2D pandas DataFrame of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. Notes @@ -337,18 +335,18 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True Parameters ---------- - sorting1: SortingExtractor + sorting1 : SortingExtractor The first sorting extractor - sorting2: SortingExtractor + sorting2 : SortingExtractor The second sorting extractor - delta_frames: int + delta_frames : int Number of frames to consider spikes coincident - ensure_symmetry: bool, default: True + ensure_symmetry : bool, default: True If ensure_symmetry is True, then the algo is run two times by switching sorting1 and sorting2. And the minimum of the two results is taken. Returns ------- - agreement_scores: array (float) + agreement_scores : pd.DataFrame The agreement score matrix. """ import pandas as pd @@ -401,16 +399,16 @@ def make_possible_match(agreement_scores, min_score): Parameters ---------- - agreement_scores: pd.DataFrame + agreement_scores : pd.DataFrame - min_score: float + min_score : float Returns ------- - best_match_12: pd.Series + best_match_12 : dict[NDArray] - best_match_21: pd.Series + best_match_21 : dict[NDArray] """ unit1_ids = np.array(agreement_scores.index) @@ -433,7 +431,7 @@ def make_possible_match(agreement_scores, min_score): return possible_match_12, possible_match_21 -def make_best_match(agreement_scores, min_score): +def make_best_match(agreement_scores, min_score) -> "tuple[pd.Series, pd.Series]": """ Given an agreement matrix and a min_score threshold. return a dict a best match for each units independently of others. @@ -442,16 +440,16 @@ def make_best_match(agreement_scores, min_score): Parameters ---------- - agreement_scores: pd.DataFrame + agreement_scores : pd.DataFrame - min_score: float + min_score : float Returns ------- - best_match_12: pd.Series + best_match_12 : pd.Series - best_match_21: pd.Series + best_match_21 : pd.Series """ import pandas as pd @@ -490,14 +488,14 @@ def make_hungarian_match(agreement_scores, min_score): ---------- agreement_scores: pd.DataFrame - min_score: float + min_score : float Returns ------- - hungarian_match_12: pd.Series + hungarian_match_12 : pd.Series - hungarian_match_21: pd.Series + hungarian_match_21 : pd.Series """ import pandas as pd @@ -541,22 +539,22 @@ def do_score_labels(sorting1, sorting2, delta_frames, unit_map12, label_misclass Parameters ---------- - sorting1: SortingExtractor instance + sorting1 : SortingExtractor instance The ground truth sorting - sorting2: SortingExtractor instance + sorting2 : SortingExtractor instance The tested sorting - delta_frames: int + delta_frames : int Number of frames to consider spikes coincident - unit_map12: pd.Series + unit_map12 : pd.Series Dict of matching from sorting1 to sorting2 - label_misclassification: bool + label_misclassification : bool If True, misclassification errors are labelled Returns ------- - labels_st1: dict of lists of np.array of str + labels_st1 : dict of lists of np.array of str Contain score labels for units of sorting 1 for each segment - labels_st2: dict of lists of np.array of str + labels_st2 : dict of lists of np.array of str Contain score labels for units of sorting 2 for each segment """ unit1_ids = sorting1.get_unit_ids() @@ -647,12 +645,12 @@ def compare_spike_trains(spiketrain1, spiketrain2, delta_frames=10): Parameters ---------- - spiketrain1, spiketrain2: numpy.array + spiketrain1, spiketrain2 : numpy.array Times of spikes for the 2 spike trains. Returns ------- - lab_st1, lab_st2: numpy.array + lab_st1, lab_st2 : numpy.array Label of score for each spike """ lab_st1 = np.array(["UNPAIRED"] * len(spiketrain1)) @@ -684,19 +682,19 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun Parameters ---------- - event_counts1: pd.Series + event_counts1 : pd.Series Number of event per units 1 - event_counts2: pd.Series + event_counts2 : pd.Series Number of event per units 2 - match_12: pd.Series + match_12 : pd.Series Series of matching from sorting1 to sorting2. Can be the hungarian or best match. - match_event_count: pd.DataFrame + match_event_count : pd.DataFrame The match count matrix given by make_match_count_matrix Returns ------- - confusion_matrix: pd.DataFrame + confusion_matrix : pd.DataFrame The confusion matrix index are units1 reordered columns are units2 redordered @@ -746,19 +744,19 @@ def do_count_score(event_counts1, event_counts2, match_12, match_event_count): Parameters ---------- - event_counts1: pd.Series + event_counts1 : pd.Series Number of event per units 1 - event_counts2: pd.Series + event_counts2 : pd.Series Number of event per units 2 - match_12: pd.Series + match_12 : pd.Series Series of matching from sorting1 to sorting2. Can be the hungarian or best match. - match_event_count: pd.DataFrame + match_event_count : pd.DataFrame The match count matrix given by make_match_count_matrix Returns ------- - count_score: pd.DataFrame + count_score : pd.DataFrame A table with one line per GT units and columns are tp/fn/fp/... """ @@ -837,16 +835,16 @@ def make_matching_events(times1, times2, delta): Parameters ---------- - times1: list + times1 : list List of spike train 1 frames - times2: list + times2 : list List of spike train 2 frames - delta: int + delta : int Number of frames for considering matching events Returns ------- - matching_event: numpy array dtype = ["index1", "index2", "delta"] + matching_event : numpy array dtype = ["index1", "index2", "delta"] 1d of collision """ times_concat = np.concatenate((times1, times2)) @@ -894,14 +892,14 @@ def make_collision_events(sorting, delta): Parameters ---------- - sorting: SortingExtractor + sorting : SortingExtractor The sorting extractor object for counting collision events - delta: int + delta : int Number of frames for considering collision events Returns ------- - collision_events: numpy array + collision_events : numpy array dtype = [('index1', 'int64'), ('unit_id1', 'int64'), ('index2', 'int64'), ('unit_id2', 'int64'), ('delta', 'int64')] diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 5a0dd1d3a7..0cafef2c12 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -15,6 +15,21 @@ class CorrelogramGTComparison(GroundTruthComparison): This class needs maintenance and need a bit of refactoring. + Parameters + ---------- + gt_sorting : SortingExtractor + The first sorting for the comparison + tested_sorting : SortingExtractor + The second sorting for the comparison + bin_ms : float, default: 1.0 + Size of bin for correlograms + window_ms : float, default: 100.0 + The window around the spike to compute the correlation in ms. + well_detected_score : float, default: 0.8 + Agreement score above which units are well detected + **kwargs : dict + Keyword arguments for `GroundTruthComparison` + """ def __init__(self, gt_sorting, tested_sorting, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index ba7268b4f0..8929d6983c 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -42,6 +42,11 @@ class GroundTruthStudy: This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. Note that the underlying folder structure is not backward compatible! + + Parameters + ---------- + study_folder : str | Path + Path to folder containing `GroundTruthStudy` """ def __init__(self, study_folder): @@ -370,7 +375,6 @@ def get_metrics(self, key): return metrics def get_units_snr(self, key): - """ """ return self.get_metrics(key)["snr"] def get_performance_by_unit(self, case_keys=None): diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index beb9682e37..657fc73b71 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -48,6 +48,8 @@ class HybridUnitsRecording(InjectTemplatesRecording): injected_sorting_folder : str | Path | None If given, the injected sorting is saved to this folder. It must be specified if injected_sorting is None or not serialisable to file. + seed : int, default: None + Random seed for amplitude_factor Returns ------- diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 499004e32e..f7d9782a07 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -43,7 +43,9 @@ class MultiSortingComparison(BaseMultiComparison, MixinSpikeTrainComparison): - "intersection" : spike trains are the intersection between the spike trains of the best matching two sorters verbose : bool, default: False - if True, output is verbose + If True, output is verbose + do_matching : bool, default: True + If True, the comparison is done when the `MultiSortingComparison` is initialized Returns ------- @@ -318,7 +320,15 @@ class MultiTemplateComparison(BaseMultiComparison, MixinTemplateComparison): chance_score : float, default: 0.3 Minimum agreement score to for a possible match verbose : bool, default: False - if True, output is verbose + If True, output is verbose + do_matching : bool, default: True + If True, the comparison is done when the `MultiSortingComparison` is initialized + support : "dense" | "union" | "intersection", default: "union" + The support to compute the similarity matrix. + num_shifts : int, default: 0 + Number of shifts to use to shift templates to maximize similarity. + similarity_method : "cosine" | "l1" | "l2", default: "cosine" + Method for the similarity matrix. Returns ------- diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 7d5f04dfdd..5c884d82bf 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -224,7 +224,9 @@ class GroundTruthComparison(BasePairSorterComparison): tested_name : : str, default: None The name of sorter 2 delta_time : float, default: 0.4 - Number of ms to consider coincident spikes + Number of ms to consider coincident spikes. + This means that two spikes are considered simultaneous if they are within `delta_time` of each other or + mathematically abs(spike1_time - spike2_time) <= delta_time. match_score : float, default: 0.5 Minimum agreement score to match units chance_score : float, default: 0.1 @@ -263,7 +265,6 @@ def __init__( gt_name=None, tested_name=None, delta_time=0.4, - sampling_frequency=None, match_score=0.5, well_detected_score=0.8, redundant_score=0.2, @@ -425,6 +426,11 @@ def get_performance(self, method="by_unit", output="pandas"): def print_performance(self, method="pooled_with_average"): """ Print performance with the selected method + + Parameters + ---------- + method : "by_unit" | "pooled_with_average", default: "pooled_with_average" + The method to compute performance """ template_txt_performance = _template_txt_performance @@ -449,6 +455,19 @@ def print_summary(self, well_detected_score=None, redundant_score=None, overmerg * how many gt units (one or several) This summary mix several performance metrics. + + Parameters + ---------- + well_detected_score : float, default: None + The agreement score above which tested units + are counted as "well detected". + redundant_score : float, default: None + The agreement score below which tested units + are counted as "false positive"" (and not "redundant"). + overmerged_score : float, default: None + Tested units with 2 or more agreement scores above "overmerged_score" + are counted as "overmerged". + """ txt = _template_summary_part1 @@ -456,12 +475,12 @@ def print_summary(self, well_detected_score=None, redundant_score=None, overmerg num_gt=len(self.unit1_ids), num_tested=len(self.unit2_ids), num_well_detected=self.count_well_detected_units(well_detected_score), - num_redundant=self.count_redundant_units(redundant_score), - num_overmerged=self.count_overmerged_units(overmerged_score), ) if self.exhaustive_gt: txt = txt + _template_summary_part2 + d["num_redundant"] = self.count_redundant_units(redundant_score) + d["num_overmerged"] = self.count_overmerged_units(overmerged_score) d["num_false_positive_units"] = self.count_false_positive_units() d["num_bad"] = self.count_bad_units() @@ -500,6 +519,12 @@ def count_well_detected_units(self, well_detected_score): """ Count how many well detected units. kwargs are the same as get_well_detected_units. + + Parameters + ---------- + well_detected_score : float, default: None + The agreement score above which tested units + are counted as "well detected". """ return len(self.get_well_detected_units(well_detected_score=well_detected_score)) @@ -540,6 +565,12 @@ def get_false_positive_units(self, redundant_score=None): def count_false_positive_units(self, redundant_score=None): """ See get_false_positive_units(). + + Parameters + ---------- + redundant_score : float | None, default: None + The agreement score below which tested units + are counted as "false positive"" (and not "redundant"). """ return len(self.get_false_positive_units(redundant_score)) @@ -554,7 +585,7 @@ def get_redundant_units(self, redundant_score=None): Parameters ---------- - redundant_score=None : float, default: None + redundant_score : float, default: None The agreement score above which tested units are counted as "redundant" (and not "false positive" ). """ @@ -577,6 +608,12 @@ def get_redundant_units(self, redundant_score=None): def count_redundant_units(self, redundant_score=None): """ See get_redundant_units(). + + Parameters + ---------- + redundant_score : float, default: None + The agreement score below which tested units + are counted as "false positive"" (and not "redundant"). """ return len(self.get_redundant_units(redundant_score=redundant_score)) @@ -609,6 +646,12 @@ def get_overmerged_units(self, overmerged_score=None): def count_overmerged_units(self, overmerged_score=None): """ See get_overmerged_units(). + + Parameters + ---------- + overmerged_score : float, default: None + Tested units with 2 or more agreement scores above "overmerged_score" + are counted as "overmerged". """ return len(self.get_overmerged_units(overmerged_score=overmerged_score)) @@ -676,11 +719,11 @@ def count_units_categories( GT num_units: {num_gt} TESTED num_units: {num_tested} num_well_detected: {num_well_detected} -num_redundant: {num_redundant} -num_overmerged: {num_overmerged} """ -_template_summary_part2 = """num_false_positive_units {num_false_positive_units} +_template_summary_part2 = """num_redundant: {num_redundant} +num_overmerged: {num_overmerged} +num_false_positive_units {num_false_positive_units} num_bad: {num_bad} """ @@ -704,6 +747,10 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): List of units from sorting_analyzer_1 to compare. unit_ids2 : list, default: None List of units from sorting_analyzer_2 to compare. + name1 : str, default: "sess1" + Name of first session. + name2 : str, default: "sess2" + Name of second session. similarity_method : "cosine" | "l1" | "l2", default: "cosine" Method for the similarity matrix. support : "dense" | "union" | "intersection", default: "union" @@ -712,6 +759,11 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison): Number of shifts to use to shift templates to maximize similarity. verbose : bool, default: False If True, output is verbose. + chance_score : float, default: 0.3 + Minimum agreement score to for a possible match + match_score : float, default: 0.7 + Minimum agreement score to match units + Returns ------- diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 674f1ac463..ead7007920 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -166,4 +166,8 @@ # Important not for compatibility!! # This wil be uncommented after 0.100 -from .waveforms_extractor_backwards_compatibility import extract_waveforms, load_waveforms +from .waveforms_extractor_backwards_compatibility import ( + extract_waveforms, + load_waveforms, + load_sorting_analyzer_or_waveforms, +) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index ff1dc5dafa..bc5de63d07 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -675,13 +675,14 @@ class ComputeNoiseLevels(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - **params: dict with additional parameters for the `spikeinterface.get_noise_levels()` function + **kwargs : dict + Additional parameters for the `spikeinterface.get_noise_levels()` function Returns ------- - noise_levels: np.array + noise_levels : np.array The noise level vector """ diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 72c0a2c2fe..1fa218851b 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -17,6 +17,7 @@ from .globals import get_global_tmp_folder, is_set_global_tmp_folder from .core_tools import ( check_json, + clean_zarr_folder_name, is_dict_extractor, SIJsonEncoder, make_paths_relative, @@ -164,7 +165,7 @@ def id_to_index(self, id) -> int: def annotate(self, **new_annotations) -> None: self._annotations.update(new_annotations) - def set_annotation(self, annotation_key, value: Any, overwrite=False) -> None: + def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> None: """This function adds an entry to the annotations dictionary. Parameters @@ -192,7 +193,7 @@ def get_preferred_mp_context(self): """ return self._preferred_mp_context - def get_annotation(self, key, copy: bool = True) -> Any: + def get_annotation(self, key: str, copy: bool = True) -> Any: """ Get a annotation. Return a copy by default @@ -205,7 +206,13 @@ def get_annotation(self, key, copy: bool = True) -> Any: def get_annotation_keys(self) -> List: return list(self._annotations.keys()) - def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, missing_value: Any = None) -> None: + def set_property( + self, + key, + values: list | np.ndarray | tuple, + ids: list | np.ndarray | tuple | None = None, + missing_value: Any = None, + ) -> None: """ Set property vector for main ids. @@ -223,6 +230,7 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi Array of values for the property ids : list/np.array, default: None List of subset of ids to set the values, default: None + if None which is the default all the ids are set or changed missing_value : object, default: None In case the property is set on a subset of values ("ids" not None), it specifies the how the missing values should be filled. @@ -240,23 +248,26 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi dtype_kind = dtype.kind if ids is None: - assert values.shape[0] == size + assert ( + values.shape[0] == size + ), f"Values must have the same size as the main ids: {size} but got array of size {values.shape[0]}" self._properties[key] = values else: ids = np.array(ids) - assert np.unique(ids).size == ids.size, "'ids' are not unique!" + unique_ids = np.unique(ids) + if unique_ids.size != ids.size: + non_unique_ids = [id for id in ids if np.count_nonzero(ids == id) > 1] + raise ValueError(f"IDs are not unique: {non_unique_ids}") + # Not clear where this branch is used, perhaps on aggregation of extractors? if ids.size < size: if key not in self._properties: - # create the property with nan or empty - shape = (size,) + values.shape[1:] if missing_value is None: if dtype_kind not in self.default_missing_property_values.keys(): - raise Exception( - "For values dtypes other than float, string, object or unicode, the missing value " - "cannot be automatically inferred. Please specify it with the 'missing_value' " - "argument." + raise ValueError( + f"Can't infer a natural missing value for dtype {dtype_kind}. " + "Please provide it with the `missing_value` argument" ) else: missing_value = self.default_missing_property_values[dtype_kind] @@ -268,15 +279,18 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi "as the values." ) + # create the property with nan or empty + shape = (size,) + values.shape[1:] empty_values = np.zeros(shape, dtype=dtype) empty_values[:] = missing_value self._properties[key] = empty_values if ids.size == 0: return else: - assert dtype_kind == self._properties[key].dtype.kind, ( - "Mismatch between existing property dtype " "values dtype." - ) + existing_property = self._properties[key].dtype + assert ( + dtype_kind == existing_property.kind + ), f"Mismatch between existing property dtype {existing_property.kind} and provided values dtype {dtype_kind}." indices = self.ids_to_indices(ids) self._properties[key][indices] = values @@ -285,7 +299,7 @@ def set_property(self, key, values: Sequence, ids: Optional[Sequence] = None, mi self._properties[key] = np.zeros_like(values, dtype=values.dtype) self._properties[key][indices] = values - def get_property(self, key, ids: Optional[Iterable] = None) -> np.ndarray: + def get_property(self, key: str, ids: Optional[Iterable] = None) -> np.ndarray: values = self._properties.get(key, None) if ids is not None and values is not None: inds = self.ids_to_indices(ids) @@ -1048,9 +1062,7 @@ def save_to_zarr( print(f"Use zarr_path={zarr_path}") else: if storage_options is None: - folder = Path(folder) - if folder.suffix != ".zarr": - folder = folder.parent / f"{folder.stem}.zarr" + folder = clean_zarr_folder_name(folder) if folder.is_dir() and overwrite: shutil.rmtree(folder) zarr_path = folder diff --git a/src/spikeinterface/core/baseevent.py b/src/spikeinterface/core/baseevent.py index 2cf92f8cea..9e1963089b 100644 --- a/src/spikeinterface/core/baseevent.py +++ b/src/spikeinterface/core/baseevent.py @@ -137,7 +137,7 @@ class BaseEventSegment(BaseSegment): def __init__(self): BaseSegment.__init__(self) - def get_event_times(self, channel_id: int | str, start_time: float, end_time: float): + def get_event_times(self, channel_id: int | str, start_time: float, end_time: float) -> np.ndarray: """Returns event timestamps of a channel in seconds Parameters ---------- diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 03db6bd9af..82f2ae1890 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -96,14 +96,18 @@ def list_to_string(lst, max_size=6): def _repr_header(self): num_segments = self.get_num_segments() num_channels = self.get_num_channels() - sf_hz = self.get_sampling_frequency() - sf_khz = sf_hz / 1000 dtype = self.get_dtype() total_samples = self.get_total_samples() total_duration = self.get_total_duration() total_memory_size = self.get_total_memory_size() - sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" + + sf_hz = self.get_sampling_frequency() + if not sf_hz.is_integer(): + sampling_frequency_repr = f"{sf_hz:f} Hz" + else: + # Khz for high sampling rate and Hz for LFP + sampling_frequency_repr = f"{(sf_hz/1000.0):0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" txt = ( f"{self.name}: " @@ -301,7 +305,7 @@ def get_traces( order: "C" | "F" | None = None, return_scaled: bool = False, cast_unsigned: bool = False, - ): + ) -> np.ndarray: """Returns traces from recording. Parameters @@ -422,7 +426,7 @@ def get_time_info(self, segment_index=None) -> dict: return time_kwargs - def get_times(self, segment_index=None): + def get_times(self, segment_index=None) -> np.ndarray: """Get time vector for a recording segment. If the segment has a time_vector, then it is returned. Otherwise @@ -491,6 +495,20 @@ def set_times(self, times, segment_index=None, with_warning=True): "Use this carefully!" ) + def reset_times(self): + """ + Reset time information in-memory for all segments that have a time vector. + If the timestamps come from a file, the files won't be modified. but only the in-memory + attributes of the recording objects are deleted. Also `t_start` is set to None and the + segment's sampling frequency is set to the recording's sampling frequency. + """ + for segment_index in range(self.get_num_segments()): + if self.has_time_vector(segment_index): + rs = self._recording_segments[segment_index] + rs.time_vector = None + rs.t_start = None + rs.sampling_frequency = self.sampling_frequency + def sample_index_to_time(self, sample_ind, segment_index=None): """ Transform sample index into time in seconds @@ -810,12 +828,10 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): BaseSegment.__init__(self) - def get_times(self): + def get_times(self) -> np.ndarray: if self.time_vector is not None: - if isinstance(self.time_vector, np.ndarray): - return self.time_vector - else: - return np.array(self.time_vector) + self.time_vector = np.asarray(self.time_vector) + return self.time_vector else: time_vector = np.arange(self.get_num_samples(), dtype="float64") time_vector /= self.sampling_frequency diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index aad7613d01..b3a857d158 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -153,6 +153,13 @@ def check_json(dictionary: dict) -> dict: return json.loads(json_string) +def clean_zarr_folder_name(folder): + folder = Path(folder) + if folder.suffix != ".zarr": + folder = folder.parent / f"{folder.stem}.zarr" + return folder + + def add_suffix(file_path, possible_suffix): file_path = Path(file_path) if isinstance(possible_suffix, str): @@ -194,7 +201,7 @@ def is_dict_extractor(d: dict) -> bool: extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"]) -def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]: +def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element, None, None]: """ Iterator for recursive traversal of a dictionary. This function explores the dictionary recursively and yields the path to each value along with the value itself. @@ -250,8 +257,7 @@ def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_va Returns ------- - dict - The modified dictionary + None """ current = extractor_dict @@ -432,7 +438,7 @@ def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict: return output_dict -def make_paths_absolute(input_dict, base_folder): +def make_paths_absolute(input_dict, base_folder) -> dict: """ Recursively transform a dict describing an BaseExtractor to make every path absolute given a base_folder. @@ -625,7 +631,7 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2)) -def retrieve_importing_provenance(a_class): +def retrieve_importing_provenance(a_class) -> dict: """ Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), the top-level module, and the module version. @@ -684,3 +690,20 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float: memory = mem_info.total - mem_info.available return memory + + +def is_path_remote(path: str | Path) -> bool: + """ + Returns True if the path is a remote path (e.g., s3:// or gcs://). + + Parameters + ---------- + path : str or Path + The path to check. + + Returns + ------- + bool + Whether the path is a remote path. + """ + return "s3://" in str(path) or "gcs://" in str(path) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ff75789aab..6d2d1cbb55 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Union, Optional, List, Literal +from typing import Literal from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -27,13 +27,13 @@ def _ensure_seed(seed): def generate_recording( - num_channels: Optional[int] = 2, - sampling_frequency: Optional[float] = 30000.0, - durations: Optional[List[float]] = [5.0, 2.5], - set_probe: Optional[bool] = True, - ndim: Optional[int] = 2, - seed: Optional[int] = None, -) -> BaseRecording: + num_channels: int = 2, + sampling_frequency: float = 30000.0, + durations: list[float] = [5.0, 2.5], + set_probe: bool | None = True, + ndim: int | None = 2, + seed: int | None = None, +) -> NumpySorting: """ Generate a lazy recording object. Useful for testing API and algos. @@ -44,13 +44,14 @@ def generate_recording( The number of channels in the recording. sampling_frequency : float, default: 30000. (in Hz) The sampling frequency of the recording, default: 30000. - durations: List[float], default: [5.0, 2.5] - The duration in seconds of each segment in the recording, default: [5.0, 2.5]. - Note that the number of segments is determined by the length of this list. - set_probe: bool, default: True + durations : list[float], default: [5.0, 2.5] + The duration in seconds of each segment in the recording. + The number of segments is determined by the length of this list. + set_probe : bool, default: True + If true, attaches probe to the returned `Recording` ndim : int, default: 2 The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probe. - seed : Optional[int] + seed : int | None, default: None A seed for the np.ramdom.default_rng function Returns @@ -105,7 +106,7 @@ def generate_sorting( num_units : int, default: 5 Number of units. sampling_frequency : float, default: 30000.0 - The sampling frequency. + The sampling frequency of the recording in Hz. durations : list, default: [10.325, 3.5] Duration of each segment in s. firing_rates : float, default: 3.0 @@ -188,7 +189,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): ---------- sorting : BaseSorting The sorting object. - sync_event_ratio : float + sync_event_ratio : float, default: 0.3 The ratio of added synchronous spikes with respect to the total number of spikes. E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra spikes are synchronous (same sample_index), but on different units (not duplicates). @@ -236,7 +237,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): def generate_sorting_to_inject( sorting: BaseSorting, - num_samples: List[int], + num_samples: list[int], max_injected_per_unit: int = 1000, injected_rate: float = 0.05, refractory_period_ms: float = 1.5, @@ -250,16 +251,16 @@ def generate_sorting_to_inject( ---------- sorting : BaseSorting The sorting object. - num_samples: list of size num_segments. + num_samples : list[int] of size num_segments. The number of samples in all the segments of the sorting, to generate spike times covering entire the entire duration of the segments. - max_injected_per_unit: int, default 1000 + max_injected_per_unit : int, default: 1000 The maximal number of spikes injected per units. - injected_rate: float, default 0.05 + injected_rate : float, default: 0.05 The rate at which spikes are injected. - refractory_period_ms: float, default 1.5 + refractory_period_ms : float, default: 1.5 The refractory period that should not be violated while injecting new spikes. - seed: int, default None + seed : int, default: None The random seed. Returns @@ -313,13 +314,13 @@ class TransformSorting(BaseSorting): ---------- sorting : BaseSorting The sorting object. - added_spikes_existing_units : np.array (spike_vector) + added_spikes_existing_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for existing units. - added_spikes_new_units: np.array (spike_vector) + added_spikes_new_units : np.array (spike_vector) | None, default: None The spikes that should be added to the sorting object, for new units. - new_units_ids: list + new_units_ids : list[str, int] | None, default: None The unit_ids that should be added if spikes for new units are added. - refractory_period_ms : float, default None + refractory_period_ms : float | None, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -333,10 +334,10 @@ class TransformSorting(BaseSorting): def __init__( self, sorting: BaseSorting, - added_spikes_existing_units=None, - added_spikes_new_units=None, - new_unit_ids: Optional[List[Union[str, int]]] = None, - refractory_period_ms: Optional[float] = None, + added_spikes_existing_units: np.array | None = None, + added_spikes_new_units: np.array | None = None, + new_unit_ids: list[str | int] | None = None, + refractory_period_ms: float | None = None, ): sampling_frequency = sorting.get_sampling_frequency() unit_ids = list(sorting.get_unit_ids()) @@ -428,11 +429,11 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting. - sorting2: BaseSorting + sorting2 : BaseSorting The second sorting. - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -484,7 +485,7 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe @staticmethod def add_from_unit_dict( - sorting1: BaseSorting, units_dict_list: dict, refractory_period_ms=None + sorting1: BaseSorting, units_dict_list: list[dict] | dict, refractory_period_ms=None ) -> "TransformSorting": """ Construct TransformSorting by adding one sorting with a @@ -494,11 +495,11 @@ def add_from_unit_dict( Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting - dict_list: list of dict + dict_list : list[dict] | dict A list of dict with unit_ids as keys and spike times as values. - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -519,16 +520,18 @@ def from_times_labels( Parameters ---------- - sorting1: BaseSorting + sorting1 : BaseSorting The first sorting - times_list: list of array (or array) + times_list : list[np.array] | np.array An array of spike times (in frames). - labels_list: list of array (or array) + labels_list : list[np.array] | np.array An array of spike labels corresponding to the given times. - unit_ids: list or None, default: None + sampling_frequency : float, default: 30000.0 + The sampling frequency of the recording in Hz. + unit_ids : list | None, default: None The explicit list of unit_ids that should be extracted from labels_list If None, then it will be np.unique(labels_list). - refractory_period_ms : float, default None + refractory_period_ms : float, default: None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be discarded. @@ -592,7 +595,7 @@ def generate_snippets( nafter=44, num_channels=2, wf_folder=None, - sampling_frequency=30000.0, # in Hz + sampling_frequency=30000.0, durations=[10.325, 3.5], #  in s for 2 segments set_probe=True, ndim=2, @@ -614,13 +617,20 @@ def generate_snippets( wf_folder : str | Path | None, default: None Optional folder to save the waveform snippets. If None, snippets are in memory. sampling_frequency : float, default: 30000.0 - The sampling frequency of the snippets. + The sampling frequency of the snippets in Hz. ndim : int, default: 2 The number of dimensions of the probe. num_units : int, default: 5 The number of units. empty_units : list | None, default: None A list of units that will have no spikes. + durations : List[float], default: [10.325, 3.5] + The duration in seconds of each segment in the recording. + The number of segments is determined by the length of this list. + set_probe : bool, default: True + If true, attaches probe to the returned snippets object + **job_kwargs : dict, default: None + Job keyword arguments for `snippets_from_sorting` Returns ------- @@ -793,20 +803,20 @@ def synthesize_random_firings( Parameters ---------- - num_units : int + num_units : int, default: 20 Number of units. - sampling_frequency : float - Sampling rate. - duration : float + sampling_frequency : float, default: 30000.0 + Sampling rate in Hz. + duration : float, default: 60 Duration of the segment in seconds. - refractory_period_ms: float + refractory_period_ms : float Refractory period in ms. - firing_rates: float or list[float] + firing_rates : float or list[float] The firing rate of each unit (in Hz). If float, all units will have the same firing rate. - add_shift_shuffle: bool, default: False + add_shift_shuffle : bool, default: False Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat. - seed: int, default: None + seed : int, default: None Seed for the generator. Returns @@ -899,12 +909,14 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No ---------- sorting : Original sorting. - num : int + num : int, default: 4 Number of injected units. - max_shift : int + max_shift : int, default: 5 range of the shift in sample. - ratio: float + ratio : float | None, default: None Proportion of original spike in the injected units. + seed : int | None, default: None + Random seed for creating unit peak shifts. Returns ------- @@ -1060,23 +1072,23 @@ class NoiseGeneratorRecording(BaseRecording): The number of channels. sampling_frequency : float The sampling frequency of the recorder. - durations : List[float] + durations : list[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_levels: float or array, default: 1 + noise_levels : float | np.array, default: 1.0 Std of the white noise (if an array, defined by per channels) - cov_matrix: np.array, default None + cov_matrix : np.array | None, default: None The covariance matrix of the noise - dtype : Optional[Union[np.dtype, str]], default: "float32" + dtype : np.dtype | str | None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. - seed : Optional[int], default: None + seed : int | None, default: None The seed for np.random.default_rng. - strategy : "tile_pregenerated" or "on_the_fly" + strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated" The strategy of generating noise chunk: * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it very fast and cusume only one noise block. * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) - noise_block_size: int + noise_block_size : int, default: 30000 Size in sample of noise block. Notes @@ -1089,11 +1101,11 @@ def __init__( self, num_channels: int, sampling_frequency: float, - durations: List[float], - noise_levels: float = 1.0, - cov_matrix: Optional[np.array] = None, - dtype: Optional[Union[np.dtype, str]] = "float32", - seed: Optional[int] = None, + durations: list[float], + noise_levels: float | np.array = 1.0, + cov_matrix: np.array | None = None, + dtype: np.dtype | str | None = "float32", + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", noise_block_size: int = 30000, ): @@ -1150,7 +1162,6 @@ def __init__( "sampling_frequency": sampling_frequency, "noise_levels": noise_levels, "cov_matrix": cov_matrix, - "noise_levels": noise_levels, "dtype": dtype, "seed": seed, "strategy": strategy, @@ -1205,10 +1216,16 @@ def get_num_samples(self) -> int: def get_traces( self, - start_frame: Union[int, None] = None, - end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, ) -> np.ndarray: + + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size num_samples = end_frame - start_frame @@ -1261,8 +1278,7 @@ def get_traces( def generate_recording_by_size( full_traces_size_GiB: float, - num_channels: int = 384, - seed: Optional[int] = None, + seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: """ @@ -1279,11 +1295,14 @@ def generate_recording_by_size( ---------- full_traces_size_GiB : float The size in gigabytes (GiB) of the recording. - num_channels: int - Number of channels. - seed : int, default: None + seed : int | None, default: None The seed for np.random.default_rng. - + strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated" + The strategy of generating noise chunk: + * "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it + very fast and consume only one noise block. + * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index + no memory preallocation but a bit more computation (random) Returns ------- GeneratorRecording @@ -1519,25 +1538,25 @@ def generate_templates( Parameters ---------- - channel_locations: np.ndarray + channel_locations : np.ndarray Channel locations. - units_locations: np.ndarray + units_locations : np.ndarray Must be 3D. - sampling_frequency: float + sampling_frequency : float Sampling frequency. - ms_before: float + ms_before : float Cut out in ms before spike peak. - ms_after: float + ms_after : float Cut out in ms after spike peak. - seed: int or None + seed : int | None A seed for random. - dtype: numpy.dtype, default: "float32" + dtype : numpy.dtype, default: "float32" Templates dtype - upsample_factor: None or int + upsample_factor : int | None, default: None If not None then template are generated upsampled by this factor. Then a new dimention (axis=3) is added to the template with intermediate inter sample representation. This allow easy random jitter by choising a template this new dim - unit_params: dict of arrays or dict of scalar of dict of tuple + unit_params : dict[np.array] | dict[float] | dict[tuple] | None, default: None An optional dict containing parameters per units. Keys are parameter names: @@ -1554,6 +1573,14 @@ def generate_templates( * array of the same length of units * scalar, then an array is created * tuple, then this difine a range for random values. + mode : "ellipsoid" | "sphere", default: "ellipsoid" + Method used to calculate the distance between unit and channel location. + Ellipsoid injects some anisotropy dependent on unit shape, sphere is equivalent + to Euclidean distance. + + mode : "sphere" | "ellipsoid", default: "ellipsoid" + Mode for how to calculate distances + Returns ------- @@ -1674,31 +1701,33 @@ class InjectTemplatesRecording(BaseRecording): Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting Sorting object containing all the units and their spike train. - templates: np.ndarray[n_units, n_samples, n_channels] or np.ndarray[n_units, n_samples, n_oversampling] + templates : np.ndarray[n_units, n_samples, n_channels] | np.ndarray[n_units, n_samples, n_oversampling] Array containing the templates to inject for all the units. Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. - nbefore: list[int] | int | None, default: None + nbefore : list[int] | int | None, default: None The number of samples before the peak of the template to align the spike. If None, will default to the highest peak. - amplitude_factor: list[float] | float | None, default: None + amplitude_factor : list[float] | float | None, default: None The amplitude of each spike for each unit. Can be None (no scaling). Can be scalar all spikes have the same factor (certainly useless). Can be a vector with same shape of spike_vector of the sorting. - parent_recording: BaseRecording | None + parent_recording : BaseRecording | None, default: None The recording over which to add the templates. If None, will default to traces containing all 0. - num_samples: list[int] | int | None + num_samples : list[int] | int | None, default: None The number of samples in the recording per segment. You can use int for mono-segment objects. - upsample_vector: np.array or None, default: None. + upsample_vector : np.array | None, default: None. When templates is 4d we can simulate a jitter. Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. + check_borders : bool, default: False + Checks if the border of the templates are zero. Returns ------- @@ -1710,11 +1739,11 @@ def __init__( self, sorting: BaseSorting, templates: np.ndarray, - nbefore: Union[List[int], int, None] = None, - amplitude_factor: Union[List[List[float]], List[float], float, None] = None, - parent_recording: Union[BaseRecording, None] = None, - num_samples: Optional[List[int]] = None, - upsample_vector: Union[List[int], None] = None, + nbefore: list[int] | int | None = None, + amplitude_factor: list[float] | float | None = None, + parent_recording: BaseRecording | None = None, + num_samples: list[int] | int | None = None, + upsample_vector: np.array | None = None, check_borders: bool = False, ) -> None: templates = np.asarray(templates) @@ -1846,10 +1875,10 @@ def __init__( spike_vector: np.ndarray, templates: np.ndarray, nbefore: int, - amplitude_vector: Union[List[float], None], - upsample_vector: Union[List[float], None], - parent_recording_segment: Union[BaseRecordingSegment, None] = None, - num_samples: Union[int, None] = None, + amplitude_vector: list[float] | None, + upsample_vector: list[float] | None, + parent_recording_segment: BaseRecordingSegment | None = None, + num_samples: int | None = None, ) -> None: BaseRecordingSegment.__init__( self, @@ -1869,9 +1898,9 @@ def __init__( def get_traces( self, - start_frame: Union[int, None] = None, - end_frame: Union[int, None] = None, - channel_indices: Union[List, None] = None, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, ) -> np.ndarray: if channel_indices is None: n_channels = self.templates.shape[2] @@ -2042,55 +2071,55 @@ def generate_ground_truth_recording( Parameters ---------- - durations: list of float, default: [10.] + durations : list[float], default: [10.] Durations in seconds for all segments. - sampling_frequency: float, default: 25000 + sampling_frequency : float, default: 25000.0 Sampling frequency. - num_channels: int, default: 4 + num_channels : int, default: 4 Number of channels, not used when probe is given. - num_units: int, default: 10 + num_units : int, default: 10 Number of units, not used when sorting is given. - sorting: Sorting or None + sorting : Sorting | None An external sorting object. If not provide, one is genrated. - probe: Probe or None + probe : Probe | None An external Probe object. If not provided a probe is generated using generate_probe_kwargs. - generate_probe_kwargs: dict + generate_probe_kwargs : dict A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. - templates: np.array or None + templates : np.array | None The templates of units. If None they are generated. Shape can be: * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. - ms_before: float, default: 1.5 + ms_before : float, default: 1.5 Cut out in ms before spike peak. - ms_after: float, default: 3 + ms_after : float, default: 3.0 Cut out in ms after spike peak. - upsample_factor: None or int, default: None + upsample_factor : None | int, default: None A upsampling factor used only when templates are not provided. - upsample_vector: np.array or None + upsample_vector : np.array | None Optional the upsample_vector can given. This has the same shape as spike_vector - generate_sorting_kwargs: dict + generate_sorting_kwargs : dict When sorting is not provide, this dict is used to generated a Sorting. - noise_kwargs: dict + noise_kwargs : dict Dict used to generated the noise with NoiseGeneratorRecording. - generate_unit_locations_kwargs: dict + generate_unit_locations_kwargs : dict Dict used to generated template when template not provided. - generate_templates_kwargs: dict + generate_templates_kwargs : dict Dict used to generated template when template not provided. - dtype: np.dtype, default: "float32" + dtype : np.dtype, default: "float32" The dtype of the recording. - seed: int or None + seed : int | None Seed for random initialization. If None a diffrent Recording is generated at every call. Note: even with None a generated recording keep internaly a seed to regenerate the same signal after dump/load. Returns ------- - recording: Recording + recording : Recording The generated recording extractor. - sorting: Sorting + sorting : Sorting The generated sorting extractor. """ generate_templates_kwargs = generate_templates_kwargs or dict() diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index ea2f20d631..53ef736208 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -88,7 +88,7 @@ def get_unit_ids(self): """ return list(self._unit_map.keys()) - def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): + def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None) -> np.ndarray: """This function extracts spike frames from the specified unit. It will return spike frames from within three ranges: @@ -122,7 +122,7 @@ def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): unit_id=self._unit_map[unit_id], segment_index=0, start_frame=start_frame, end_frame=end_frame ) - def get_units_spike_train(self, unit_ids=None, start_frame=None, end_frame=None): + def get_units_spike_train(self, unit_ids=None, start_frame=None, end_frame=None) -> list[np.ndarray]: """This function extracts spike frames from the specified units. Parameters @@ -137,7 +137,7 @@ def get_units_spike_train(self, unit_ids=None, start_frame=None, end_frame=None) Returns ------- - spike_train: numpy.ndarray + spike_train: list[numpy.ndarray] An 2D array containing all the frames for each spike in the specified units given the range of start and end frames """ @@ -146,7 +146,7 @@ def get_units_spike_train(self, unit_ids=None, start_frame=None, end_frame=None) spike_trains = [self.get_unit_spike_train(uid, start_frame, end_frame) for uid in unit_ids] return spike_trains - def get_sampling_frequency(self): + def get_sampling_frequency(self) -> float: """ It returns the sampling frequency. diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index b4c07e77c9..0ec5449bae 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -69,7 +69,7 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsi def write_binary_recording( recording: "BaseRecording", file_paths: list[Path | str] | Path | str, - dtype: np.ndtype = None, + dtype: np.typing.DTypeLike = None, add_file_extension: bool = True, byte_offset: int = 0, auto_cast_uint: bool = True, @@ -640,7 +640,7 @@ def get_noise_levels( method: Literal["mad", "std"] = "mad", force_recompute: bool = False, **random_chunk_kwargs, -): +) -> np.ndarray: """ Estimate noise for each channel using MAD methods. You can use standard deviation with `method="std"` @@ -929,7 +929,9 @@ def get_rec_attributes(recording): return rec_attributes -def do_recording_attributes_match(recording1, recording2_attributes) -> bool: +def do_recording_attributes_match( + recording1: "BaseRecording", recording2_attributes: bool, check_dtype: bool = True +) -> tuple[bool, str]: """ Check if two recordings have the same attributes @@ -939,22 +941,43 @@ def do_recording_attributes_match(recording1, recording2_attributes) -> bool: The first recording object recording2_attributes : dict The recording attributes to test against + check_dtype : bool, default: True + If True, check if the recordings have the same dtype Returns ------- bool True if the recordings have the same attributes + str + A string with the exception message with the attributes that do not match """ recording1_attributes = get_rec_attributes(recording1) recording2_attributes = deepcopy(recording2_attributes) recording1_attributes.pop("properties") recording2_attributes.pop("properties") - return ( - np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]) - and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"] - and recording1_attributes["num_channels"] == recording2_attributes["num_channels"] - and recording1_attributes["num_samples"] == recording2_attributes["num_samples"] - and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"] - and recording1_attributes["dtype"] == recording2_attributes["dtype"] - ) + attributes_match = True + non_matching_attrs = [] + + if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]): + non_matching_attrs.append("channel_ids") + if not recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]: + non_matching_attrs.append("sampling_frequency") + if not recording1_attributes["num_channels"] == recording2_attributes["num_channels"]: + non_matching_attrs.append("num_channels") + if not recording1_attributes["num_samples"] == recording2_attributes["num_samples"]: + non_matching_attrs.append("num_samples") + # dtype is optional + if "dtype" in recording1_attributes and "dtype" in recording2_attributes: + if check_dtype: + if not recording1_attributes["dtype"] == recording2_attributes["dtype"]: + non_matching_attrs.append("dtype") + + if len(non_matching_attrs) > 0: + attributes_match = False + exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}" + else: + attributes_match = True + exception_str = "" + + return attributes_match, exception_str diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ac142405ab..817c453a97 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match -from .core_tools import check_json, retrieve_importing_provenance +from .core_tools import check_json, retrieve_importing_provenance, is_path_remote, clean_zarr_folder_name from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -44,7 +44,7 @@ def create_sorting_analyzer( return_scaled=True, overwrite=False, **sparsity_kwargs, -): +) -> "SortingAnalyzer": """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. @@ -111,6 +111,8 @@ def create_sorting_analyzer( sparsity off (or give external sparsity) like this. """ if format != "memory": + if format == "zarr": + folder = clean_zarr_folder_name(folder) if Path(folder).is_dir(): if not overwrite: raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.") @@ -143,7 +145,7 @@ def create_sorting_analyzer( return sorting_analyzer -def load_sorting_analyzer(folder, load_extensions=True, format="auto"): +def load_sorting_analyzer(folder, load_extensions=True, format="auto", storage_options=None) -> "SortingAnalyzer": """ Load a SortingAnalyzer object from disk. @@ -155,6 +157,9 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto"): Load all extensions or not. format : "auto" | "binary_folder" | "zarr" The format of the folder. + storage_options : dict | None, default: None + The storage options to specify credentials to remote zarr bucket. + For open buckets, it doesn't need to be specified. Returns ------- @@ -162,7 +167,7 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto"): The loaded SortingAnalyzer """ - return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format) + return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, storage_options=storage_options) class SortingAnalyzer: @@ -195,6 +200,7 @@ def __init__( format=None, sparsity=None, return_scaled=True, + storage_options=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -204,6 +210,7 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled + self.storage_options = storage_options # this is used to store temporary recording self._temporary_recording = None @@ -267,6 +274,8 @@ def create( sorting_analyzer = cls.load_from_binary_folder(folder, recording=recording) sorting_analyzer.folder = Path(folder) elif format == "zarr": + assert folder is not None, "For format='zarr' folder must be provided" + folder = clean_zarr_folder_name(folder) cls.create_zarr(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) sorting_analyzer = cls.load_from_zarr(folder, recording=recording) sorting_analyzer.folder = Path(folder) @@ -276,17 +285,15 @@ def create( return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto"): + def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None): """ Load folder or zarr. The recording can be given if the recording location has changed. Otherwise the recording is loaded when possible. """ - folder = Path(folder) - assert folder.is_dir(), "Waveform folder does not exists" if format == "auto": # make better assumption and check for auto guess format - if folder.suffix == ".zarr": + if Path(folder).suffix == ".zarr": format = "zarr" else: format = "binary_folder" @@ -294,12 +301,18 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): if format == "binary_folder": sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) elif format == "zarr": - sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_zarr( + folder, recording=recording, storage_options=storage_options + ) - sorting_analyzer.folder = folder + if is_path_remote(str(folder)): + sorting_analyzer.folder = folder + # in this case we only load extensions when needed + else: + sorting_analyzer.folder = Path(folder) - if load_extensions: - sorting_analyzer.load_all_saved_extension() + if load_extensions: + sorting_analyzer.load_all_saved_extension() return sorting_analyzer @@ -470,7 +483,9 @@ def load_from_binary_folder(cls, folder, recording=None): def _get_zarr_root(self, mode="r+"): import zarr - zarr_root = zarr.open(self.folder, mode=mode) + if is_path_remote(str(self.folder)): + mode = "r" + zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -479,10 +494,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at import zarr import numcodecs - folder = Path(folder) - # force zarr sufix - if folder.suffix != ".zarr": - folder = folder.parent / f"{folder.stem}.zarr" + folder = clean_zarr_folder_name(folder) if folder.is_dir(): raise ValueError(f"Folder already exists {folder}") @@ -552,25 +564,22 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info = zarr_root.create_group("extensions") @classmethod - def load_from_zarr(cls, folder, recording=None): + def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - folder = Path(folder) - assert folder.is_dir(), f"This folder does not exist {folder}" - - zarr_root = zarr.open(folder, mode="r") + zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options) # load internal sorting in memory - # TODO propagate storage_options sorting = NumpySorting.from_sorting( - ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True, copy_spike_vector=True + ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), + with_metadata=True, + copy_spike_vector=True, ) # load recording if possible if recording is None: rec_dict = zarr_root["recording"][0] try: - recording = load_extractor(rec_dict, base_folder=folder) except: recording = None @@ -608,7 +617,7 @@ def load_from_zarr(cls, folder, recording=None): return sorting_analyzer - def set_temporary_recording(self, recording: BaseRecording): + def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True): """ Sets a temporary recording object. This function can be useful to temporarily set a "cached" recording object that is not saved in the SortingAnalyzer object to speed up @@ -620,12 +629,17 @@ def set_temporary_recording(self, recording: BaseRecording): ---------- recording : BaseRecording The recording object to set as temporary recording. + check_dtype : bool, default: True + If True, check that the dtype of the temporary recording is the same as the original recording. """ # check that recording is compatible - assert do_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match." - assert np.array_equal( - recording.get_channel_locations(), self.get_channel_locations() - ), "Recording channel locations do not match." + attributes_match, exception_str = do_recording_attributes_match( + recording, self.rec_attributes, check_dtype=check_dtype + ) + if not attributes_match: + raise ValueError(exception_str) + if not np.array_equal(recording.get_channel_locations(), self.get_channel_locations()): + raise ValueError("Recording channel locations do not match.") if self._recording is not None: warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.") self._temporary_recording = recording @@ -763,9 +777,7 @@ def _save_or_select_or_merge( elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - folder = Path(folder) - if folder.suffix != ".zarr": - folder = folder.parent / f"{folder.stem}.zarr" + folder = clean_zarr_folder_name(folder) SortingAnalyzer.create_zarr( folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes ) @@ -824,6 +836,8 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": format : "memory" | "binary_folder" | "zarr", default: "memory" The new backend format to use """ + if format == "zarr": + folder = clean_zarr_folder_name(folder) return self._save_or_select_or_merge(format=format, folder=folder) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": @@ -849,6 +863,8 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz The newly create sorting_analyzer with the selected units """ # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! + if format == "zarr": + folder = clean_zarr_folder_name(folder) return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "SortingAnalyzer": @@ -875,6 +891,8 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin """ # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] + if format == "zarr": + folder = clean_zarr_folder_name(folder) return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) def merge_units( @@ -933,6 +951,9 @@ def merge_units( The newly create `SortingAnalyzer` with the selected units """ + if format == "zarr": + folder = clean_zarr_folder_name(folder) + assert merging_mode in ["soft", "hard"], "Merging mode should be either soft or hard" if len(merge_unit_groups) == 0: @@ -1011,6 +1032,9 @@ def has_temporary_recording(self) -> bool: def is_sparse(self) -> bool: return self.sparsity is not None + def is_filtered(self) -> bool: + return self.rec_attributes["is_filtered"] + def get_sorting_provenance(self): """ Get the original sorting if possible otherwise return None @@ -1099,7 +1123,7 @@ def get_num_units(self) -> int: return self.sorting.get_num_units() ## extensions zone - def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs): + def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs) -> "AnalyzerExtension | None": """ Compute one extension or several extensiosn. Internally calls compute_one_extension() or compute_several_extensions() depending on the input type. @@ -1166,7 +1190,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar else: raise ValueError("SortingAnalyzer.compute() need str, dict or list") - def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs): + def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs) -> "AnalyzerExtension": """ Compute one extension. @@ -1209,11 +1233,7 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar print(f"Deleting {child}") self.delete_extension(child) - if extension_class.need_job_kwargs: - params, job_kwargs = split_job_kwargs(kwargs) - else: - params = kwargs - job_kwargs = {} + params, job_kwargs = split_job_kwargs(kwargs) # check dependencies if extension_class.need_recording: diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 7f55646b63..24373bd04d 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -94,6 +94,41 @@ def test_name_and_repr(): assert "Hz" in rec_str +def test_setting_properties(): + + num_channels = 5 + recording = generate_recording(num_channels=5, durations=[1.0]) + channel_ids = ["a", "b", "c", "d", "e"] + recording = recording.rename_channels(new_channel_ids=channel_ids) + + complete_values = ["value"] * num_channels + recording.set_property(key="a_property", values=complete_values) + + property_in_recording = recording.get_property("a_property") + expected_array = np.array(complete_values) + assert np.array_equal(property_in_recording, expected_array) + + # Set property with missing values + incomplete_values = ["value"] * (num_channels - 1) + recording.set_property(key="incomplete_property", ids=channel_ids[:-1], values=incomplete_values) + + property_in_recording = recording.get_property("incomplete_property") + expected_array = np.array(incomplete_values + [""]) # Spikeinterface defines missing values as empty strings + assert np.array_equal(property_in_recording, expected_array) + + # # Passs a missing value + # recording.set_property( + # key="missing_property", + # ids=channel_ids[:-1], + # values=incomplete_values, + # missing_value="missing", + # ) + + # property_in_recording = recording.get_property("missing_property") + # expected_array = np.array(incomplete_values + ["missing"]) + # assert np.array_equal(property_in_recording, expected_array) + + if __name__ == "__main__": test_check_if_memory_serializable() test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 682881af8a..9c354510ac 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -289,6 +289,15 @@ def test_BaseRecording(create_cache_folder): rec3 = load_extractor(folder) assert np.allclose(times1, rec3.get_times(1)) + # reset times + rec.reset_times() + for segm in range(num_seg): + time_info = rec.get_time_info(segment_index=segm) + assert not rec.has_time_vector(segment_index=segm) + assert time_info["t_start"] is None + assert time_info["time_vector"] is None + assert time_info["sampling_frequency"] == rec.sampling_frequency + # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) locations_3d = rec_3d.get_property("location") diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index d83e4d76fc..23a1574f2a 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -17,6 +17,8 @@ get_channel_distances, get_noise_levels, order_channels_by_depth, + do_recording_attributes_match, + get_rec_attributes, ) @@ -300,6 +302,35 @@ def test_order_channels_by_depth(): assert np.array_equal(order_2d[::-1], order_2d_fliped) +def test_do_recording_attributes_match(): + recording = NoiseGeneratorRecording( + num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" + ) + rec_attributes = get_rec_attributes(recording) + do_match, _ = do_recording_attributes_match(recording, rec_attributes) + assert do_match + + rec_attributes = get_rec_attributes(recording) + rec_attributes["sampling_frequency"] = 1.0 + do_match, exc = do_recording_attributes_match(recording, rec_attributes) + assert not do_match + assert "sampling_frequency" in exc + + # check dtype options + rec_attributes = get_rec_attributes(recording) + rec_attributes["dtype"] = "int16" + do_match, exc = do_recording_attributes_match(recording, rec_attributes) + assert not do_match + assert "dtype" in exc + do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_dtype=False) + assert do_match + + # check missing dtype + rec_attributes.pop("dtype") + do_match, exc = do_recording_attributes_match(recording, rec_attributes) + assert do_match + + if __name__ == "__main__": # Create a temporary folder using the standard library import tempfile diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index d89eb7fac0..689073d6bf 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -141,7 +141,7 @@ def test_SortingAnalyzer_tmp_recording(dataset): recording_sliced = recording.channel_slice(recording.channel_ids[:-1]) # wrong channels - with pytest.raises(AssertionError): + with pytest.raises(ValueError): sorting_analyzer.set_temporary_recording(recording_sliced) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 1b570091be..a129316ee7 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -49,19 +49,14 @@ def _get_time_vector_recording(self, raw_recording): spaced timeseries data. Return the original recording, recoridng with time vectors added and list including the added time vectors. """ - times_recording = copy.deepcopy(raw_recording) + times_recording = raw_recording.clone() all_time_vectors = [] for segment_index in range(raw_recording.get_num_segments()): t_start = segment_index + 1 * 100 + t_stop = t_start + raw_recording.get_duration(segment_index) + segment_index + 1 - some_small_increasing_numbers = np.arange(times_recording.get_num_samples(segment_index)) * ( - 1 / times_recording.get_sampling_frequency() - ) - - offsets = np.cumsum(some_small_increasing_numbers) - time_vector = t_start + times_recording.get_times(segment_index) + offsets - + time_vector = np.linspace(t_start, t_stop, raw_recording.get_num_samples(segment_index)) all_time_vectors.append(time_vector) times_recording.set_times(times=time_vector, segment_index=segment_index) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 98380e955f..3affd7f0ec 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -791,7 +791,7 @@ def estimate_templates_with_accumulator( return_std: bool = False, verbose: bool = False, **job_kwargs, -): +) -> np.ndarray: """ This is a fast implementation to compute template averages and standard deviations. This is useful to estimate sparsity without the need to allocate large waveform buffers. diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index da1f5a71f5..a50a56bf85 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -343,12 +343,37 @@ def get_template( return templates[0] +def load_sorting_analyzer_or_waveforms(folder, sorting=None): + """ + Load a SortingAnalyzer from either a newly saved SortingAnalyzer folder or an old WaveformExtractor folder. + + Parameters + ---------- + folder: str | Path + The folder to the sorting analyzer or waveform extractor + sorting: BaseSorting | None, default: None + The sorting object to instantiate with the SortingAnalyzer (only used for old WaveformExtractor) + + Returns + ------- + sorting_analyzer: SortingAnalyzer + The returned SortingAnalyzer. + """ + folder = Path(folder) + if folder.suffix == ".zarr": + return load_sorting_analyzer(folder) + elif (folder / "spikeinterface_info.json").exists(): + return load_sorting_analyzer(folder) + else: + return load_waveforms(folder, sorting=sorting, output="SortingAnalyzer") + + def load_waveforms( folder, with_recording: bool = True, sorting: Optional[BaseSorting] = None, output="MockWaveformExtractor", -): +) -> MockWaveformExtractor | SortingAnalyzer: """ This read an old WaveformsExtactor folder (folder or zarr) and convert it into a SortingAnalyzer or MockWaveformExtractor. diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 1b9637e097..17f1ac08b3 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -66,7 +66,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) time_kwargs = {} time_vector = self._root.get(f"times_seg{segment_index}", None) if time_vector is not None: - time_kwargs["time_vector"] = time_vector[:] + time_kwargs["time_vector"] = time_vector else: if t_starts is None: t_start = None diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 920d6713ad..19336e5943 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -98,6 +98,7 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" + If `preset` is None, you can specify the steps manually with the `steps` parameter. resolve_graph : bool, default: False If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. @@ -145,6 +146,8 @@ def get_potential_auto_merge( Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). Returns ------- @@ -361,6 +364,9 @@ def get_potential_auto_merge( ind1, ind2 = np.nonzero(pair_mask) potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) + # some methods return identities ie (1,1) which we can cleanup first. + potential_merges = [(ids[0], ids[1]) for ids in potential_merges if ids[0] != ids[1]] + if resolve_graph: potential_merges = resolve_merging_graph(sorting, potential_merges) @@ -512,7 +518,7 @@ def smooth_correlogram(correlograms, bins, sigma_smooth_ms=0.6): return correlograms_smoothed -def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float): +def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float) -> int: """ Computes an adaptive window to correlogram (basically corresponds to the first peak). Based on a minimum threshold and minimum of second derivative. @@ -754,7 +760,7 @@ def compute_presence_distance(sorting, pair_mask, num_samples=None, **presence_d Returns ------- - potential_merges : list + potential_merges : NDArray The list of potential merges """ diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 88190a9bab..5f85538b08 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -87,18 +87,22 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo Returns ------- - curation_dict: dict + curation_dict : dict A curation dictionary """ assert destination_format == "1" - + if "mergeGroups" not in sortingview_dict.keys(): + sortingview_dict["mergeGroups"] = [] merge_groups = sortingview_dict["mergeGroups"] merged_units = sum(merge_groups, []) - if len(merged_units) > 0: - unit_id_type = int if isinstance(merged_units[0], int) else str + + first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) + if str.isdigit(first_unit_id): + unit_id_type = int else: unit_id_type = str + all_units = [] all_labels = [] manual_labels = [] @@ -138,7 +142,7 @@ def curation_label_to_vectors(curation_dict): Returns ------- - labels: dict of numpy vector + labels : dict of numpy vector """ unit_ids = list(curation_dict["unit_ids"]) @@ -289,7 +293,7 @@ def apply_curation( The Sorting object to apply merges. curation_dict : dict The curation dict. - censor_ms: float | None, default: None + censor_ms : float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of as the desired refractory period. If `censor_ms=None`, no spikes are discarded. new_id_strategy : "append" | "take_first", default: "append" @@ -297,17 +301,17 @@ def apply_curation( * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges - merging_mode : "soft" | "hard", default: "soft" + merging_mode : "soft" | "hard", default: "soft" How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately performed, reloading waveforms if needed sparsity_overlap : float, default 0.75 - The percentage of overlap that units should share in order to accept merges. If this criteria is not - achieved, soft merging will not be possible and an error will be raised. This is for use with a SortingAnalyzer input. - - verbose: - - **job_kwargs + The percentage of overlap that units should share in order to accept merges. If this criteria is not + achieved, soft merging will not be possible and an error will be raised. This is for use with a SortingAnalyzer input. + verbose : bool, default: False + If True, output is verbose + **job_kwargs : dict + Job keyword arguments for `merge_units` Returns ------- diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index 408a666613..3402638a16 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -106,18 +106,18 @@ def find_duplicated_spikes( Parameters ---------- - spike_train: np.ndarray + spike_train : np.ndarray The spike train on which to look for duplicated spikes. - censored_period: int + censored_period : int The censored period for duplicates (in sample time). - method: "keep_first" |"keep_last" | "keep_first_iterative" | "keep_last_iterative" |random", default: "random" + method : "keep_first" |"keep_last" | "keep_first_iterative" | "keep_last_iterative" |random", default: "random" Method used to remove the duplicated spikes. - seed: int | None + seed : int | None The seed to use if method="random". Returns ------- - indices_of_duplicates: np.ndarray + indices_of_duplicates : np.ndarray The indices of spikes considered to be duplicates. """ diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index 702bb587f7..b4afeab547 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -18,7 +18,7 @@ class CurationSorting: Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting The sorting object properties_policy : "keep" | "remove", default: "keep" Policy used to propagate properties after split and merge operation. If "keep" the properties will be @@ -26,6 +26,7 @@ class CurationSorting: an empty value for all the properties make_graph : bool True to keep a Networkx graph instance with the curation history + Returns ------- sorting : Sorting diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 11f26ea778..df5bb7446c 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -13,7 +13,7 @@ class MergeUnitsSorting(BaseSorting): Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting The sorting object units_to_merge : list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index d1d6b7f3cb..0d70e264a9 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -85,7 +85,7 @@ def get_unit_spike_train( return spike_train[min_spike:max_spike] -def remove_excess_spikes(sorting, recording): +def remove_excess_spikes(sorting: BaseSorting, recording: BaseRecording): """ Remove excess spikes from the spike trains. Excess spikes are the ones exceeding a recording number of samples, for each segment. diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index 874552f767..09c0b2f270 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -1,5 +1,6 @@ from __future__ import annotations import numpy as np +from spikeinterface import BaseSorting from spikeinterface import SortingAnalyzer @@ -20,7 +21,7 @@ def remove_redundant_units( remove_strategy="minimum_shift", peak_sign="neg", extra_outputs=False, -): +) -> BaseSorting: """ Removes redundant or duplicate units by comparing the sorting output with itself. @@ -58,6 +59,10 @@ def remove_redundant_units( Used when remove_strategy="highest_amplitude" extra_outputs : bool, default: False If True, will return the redundant pairs. + unit_peak_shifts : dict + Dictionary mapping the unit_id to the unit's shift (in number of samples). + A positive shift means the spike train is shifted back in time, while + a negative shift means the spike train is shifted forward. Returns ------- diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index 33c14dfe5a..0804f637a5 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -13,9 +13,9 @@ class SplitUnitSorting(BaseSorting): Parameters ---------- - sorting: BaseSorting + sorting : BaseSorting The sorting object - parent_unit_id : int + split_unit_id : int Unit id of the unit to split indices_list : list or np.array A list of index arrays selecting the spikes to split in each segment. @@ -28,6 +28,7 @@ class SplitUnitSorting(BaseSorting): Policy used to propagate properties. If "keep" the properties will be passed to the new units (if the units_to_merge have the same value). If "remove" the new units will have an empty value for all the properties of the new unit + Returns ------- sorting : Sorting diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json b/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json new file mode 100644 index 0000000000..2a350340f3 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-no-merge.json @@ -0,0 +1 @@ +{"labelsByUnit":{"2":["mua"],"3":["mua"],"4":["mua"],"5":["accept"],"6":["accept"],"7":["accept"],"8":["artifact"],"9":["artifact"]}} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index bb152e7f71..945aca7937 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -243,11 +243,23 @@ def test_label_inheritance_str(): assert np.all(sorting_include_accept.get_property("accept")) +def test_json_no_merge_curation(): + """ + Test curation with no merges using a JSON file. + """ + sorting = generate_sorting(num_units=10) + + json_file = parent_folder / "sv-sorting-curation-no-merge.json" + sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file) + + if __name__ == "__main__": # generate_sortingview_curation_dataset() # test_sha1_curation() + test_gh_curation() test_json_curation() test_false_positive_curation() test_label_inheritance_int() test_label_inheritance_str() + test_json_no_merge_curation() diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 7b3c7daab0..06041da231 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -167,7 +167,7 @@ def export_to_phy( f.write(f"dtype = '{dtype_str}'\n") f.write(f"offset = 0\n") f.write(f"sample_rate = {fs}\n") - f.write(f"hp_filtered = {sorting_analyzer.recording.is_filtered()}") + f.write(f"hp_filtered = {sorting_analyzer.is_filtered()}") # export spike_times/spike_templates/spike_clusters # here spike_labels is a remapping to unit_index diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index b3f671ebf3..cf47b9819c 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -18,7 +18,7 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): folder_path : str or Path-like The folder path to the AlphaOmega recordings. lsx_files : list of strings or None, default: None - A list of listings files that refers to mpx files to load. + A list of files that refers to mpx files to load. stream_id : {"RAW", "LFP", "SPK", "ACC", "AI", "UD"}, default: "RAW" If there are several streams, specify the stream id you want to load. stream_name : str, default: None @@ -28,6 +28,12 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_alphaomega + >>> recording = read_alphaomega(folder_path="alphaomega_folder") + """ NeoRawIOClass = "AlphaOmegaRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index adfdccddd9..9de39bef2e 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -22,6 +22,11 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_axona + >>> recording = read_axona(file_path=r'my_data.set') """ NeoRawIOClass = "AxonaRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index a42a2d75a5..992d1a8941 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -28,6 +28,11 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_ced + >>> recording = read_ced(file_path=r'my_data.smr') """ NeoRawIOClass = "CedRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index f0a1894f25..261472ede9 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -34,7 +34,13 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): In Intan the ids provided by NeoRawIO are the hardware channel ids while the names are custom names given by the user - + Examples + -------- + >>> from spikeinterface.extractors import read_intan + # intan amplifier data is stored in stream_id = '0' + >>> recording = read_intan(file_path=r'my_data.rhd', stream_id='0') + # intan has multi-file formats as well, but in this case our path should point to the header file 'info.rhd' + >>> recording = read_intan(file_path=r'info.rhd', stream_id='0') """ NeoRawIOClass = "IntanRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 58110cf7aa..04e41433e1 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -32,7 +32,7 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the - names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. rec_name : str, default: None When the file contains several recordings you need to specify the one you want to extract. (rec_name='rec0000'). diff --git a/src/spikeinterface/extractors/neoextractors/neo_utils.py b/src/spikeinterface/extractors/neoextractors/neo_utils.py index fa7916e774..ec6aae06c9 100644 --- a/src/spikeinterface/extractors/neoextractors/neo_utils.py +++ b/src/spikeinterface/extractors/neoextractors/neo_utils.py @@ -28,7 +28,7 @@ def get_neo_streams(extractor_name, *args, **kwargs): return neo_extractor.get_streams(*args, **kwargs) -def get_neo_num_blocks(extractor_name, *args, **kwargs): +def get_neo_num_blocks(extractor_name, *args, **kwargs) -> int: """Returns the number of NEO blocks. For multi-block datasets, the `block_index` argument can be used to select which bloack to read with the `read_**extractor_name**()` function. diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 0adddc2439..412027bc06 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -25,11 +25,16 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. use_names_as_ids : bool, default: True Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the - names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. Example for wideband signals: names: ["WB01", "WB02", "WB03", "WB04"] ids: ["0" , "1", "2", "3"] + + Examples + -------- + >>> from spikeinterface.extractors import read_plexon + >>> recording = read_plexon(file_path=r'my_data.plx') """ NeoRawIOClass = "PlexonRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 4434d02cc1..2f360ed864 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -28,6 +28,11 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ids: ["source3.1" , "source3.2", "source3.3", "source3.4"] all_annotations : bool, default: False Load exhaustively all annotations from neo. + + Examples + -------- + >>> from spikeinterface.extractors import read_plexon2 + >>> recording = read_plexon2(file_path=r'my_data.pl2') """ NeoRawIOClass = "Plexon2RawIO" diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index 89c457a573..e91a81398b 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -29,6 +29,11 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_spikegadgets + >>> recording = read_spikegadgets(file_path=r'my_data.rec') """ NeoRawIOClass = "SpikeGadgetsRawIO" diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index cfe20bbfa6..874a65c045 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -41,6 +41,13 @@ class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor): use_names_as_ids : bool, default: False Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Examples + -------- + >>> from spikeinterface.extractors import read_spikeglx + >>> recording = read_spikeglx(folder_path=r'path_to_folder_with_data', load_sync_channel=False) + # we can load the sync channel, but then the probe is not loaded + >>> recording = read_spikeglx(folder_path=r'pat_to_folder_with_data', load_sync_channel=True) """ NeoRawIOClass = "SpikeGLXRawIO" diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 7ffbc166de..bc143ff33a 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -153,13 +153,14 @@ def __init__( self.extra_requirements.append("pandas") for prop_name in cluster_info.columns: - if prop_name in ["chan_grp", "ch_group"]: + if prop_name in ["chan_grp", "ch_group", "channel_group"]: self.set_property(key="group", values=cluster_info[prop_name]) elif prop_name == "cluster_id": self.set_property(key="original_cluster_id", values=cluster_info[prop_name]) elif prop_name == "group": # rename group property to 'quality' - self.set_property(key="quality", values=cluster_info[prop_name]) + values = cluster_info[prop_name].values.astype("str") + self.set_property(key="quality", values=values) else: if load_all_cluster_properties: # pandas loads strings with empty values as objects with NaNs diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index c7c2dfacae..c79627bb59 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -19,12 +19,15 @@ def setUpClass(cls): from one.api import ONE cls.eid = EID - cls.one = ONE( - base_url="https://openalyx.internationalbrainlab.org", - password="international", - silent=True, - cache_dir=None, - ) + try: + cls.one = ONE( + base_url="https://openalyx.internationalbrainlab.org", + password="international", + silent=True, + cache_dir=None, + ) + except: + pytest.skip("Skipping test due to server being down.") try: cls.recording = read_ibl_recording(eid=cls.eid, stream_name="probe00.ap", one=cls.one) except requests.exceptions.HTTPError as e: @@ -109,12 +112,15 @@ def setUpClass(cls): from one.api import ONE cls.eid = "e2b845a1-e313-4a08-bc61-a5f662ed295e" - cls.one = ONE( - base_url="https://openalyx.internationalbrainlab.org", - password="international", - silent=True, - cache_dir=None, - ) + try: + cls.one = ONE( + base_url="https://openalyx.internationalbrainlab.org", + password="international", + silent=True, + cache_dir=None, + ) + except: + pytest.skip("Skipping test due to server being down.") cls.recording = read_ibl_recording(eid=cls.eid, stream_name="probe00.ap", load_sync_channel=True, one=cls.one) cls.small_scaled_trace = cls.recording.get_traces(start_frame=5, end_frame=26, return_scaled=True) cls.small_unscaled_trace = cls.recording.get_traces( @@ -182,12 +188,15 @@ def test_ibl_sorting_extractor(self): """ from one.api import ONE - one = ONE( - base_url="https://openalyx.internationalbrainlab.org", - password="international", - silent=True, - cache_dir=None, - ) + try: + one = ONE( + base_url="https://openalyx.internationalbrainlab.org", + password="international", + silent=True, + cache_dir=None, + ) + except: + pytest.skip("Skipping test due to server being down.") sorting = read_ibl_sorting(pid=PID, one=one) assert len(sorting.unit_ids) == 733 sorting_good = read_ibl_sorting(pid=PID, good_clusters_only=True) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 0e4f1985c6..9d28340352 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -40,9 +40,13 @@ def interpolate_templates(templates_array, source_locations, dest_locations, int source_locations = np.asarray(source_locations) dest_locations = np.asarray(dest_locations) if dest_locations.ndim == 2: - new_shape = templates_array.shape + new_shape = (*templates_array.shape[:2], len(dest_locations)) elif dest_locations.ndim == 3: - new_shape = (dest_locations.shape[0],) + templates_array.shape + new_shape = ( + dest_locations.shape[0], + *templates_array.shape[:2], + dest_locations.shape[1], + ) else: raise ValueError(f"Incorrect dimensions for dest_locations: {dest_locations.ndim}. Dimensions can be 2 or 3. ") @@ -116,6 +120,16 @@ class DriftingTemplates(Templates): * move every templates on-the-fly, this lead to one interpolation per spike * precompute some displacements for all templates and use a discreate interpolation, for instance by step of 1um This is the same strategy used by MEArec. + + Parameters + ---------- + templates_array_moved : np.array + Shape is (num_displacement, num_templates, num_samples, num_channels) + displacements : np.array + Displacement vector + shape : (num_displacement, 2) + **static_kwargs : dict + Keyword arguments for `Templates` """ def __init__(self, templates_array_moved=None, displacements=None, **static_kwargs): @@ -306,6 +320,8 @@ class InjectDriftingTemplatesRecording(BaseRecording): If None, no amplitude scaling is applied. If scalar all spikes have the same factor (certainly useless). If vector, it must have the same size as the spike vector. + mode : str, default: "precompute" + Mode for how to compute templates. Returns ------- diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index b439c57c52..6ff8adadd2 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -194,6 +194,8 @@ def generate_displacement_vector( motion_list : list of dict List of dicts containing individual motion vector parameters. len(motion_list) == displacement_vectors.shape[2] + seed : None | seed, default: None + Random seed for `make_one_displacement_vector` Returns ------- @@ -348,9 +350,6 @@ def generate_drifting_recording( This can be helpfull for motion benchmark. """ - - rng = np.random.default_rng(seed=seed) - # probe if generate_probe_kwargs is None: generate_probe_kwargs = _toy_probes[probe_name] diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 0c82e496c0..747389a6d7 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -42,6 +42,8 @@ def estimate_templates_from_recording( Parameters ---------- + recording : BaseRecording + The recording to get temaples from. ms_before : float The time before peaks of templates. ms_after : float @@ -181,6 +183,8 @@ def scale_template_to_range( The minimum amplitude of the output templates after scaling. max_amplitude : float The maximum amplitude of the output templates after scaling. + amplitude_function : "ptp" | "min" | "max", default: "ptp" + The function to use to compute the amplitude of the templates. Can be "ptp", "min" or "max". Returns ------- @@ -356,10 +360,6 @@ def generate_hybrid_recording( are_templates_scaled : bool, default: True If True, the templates are assumed to be in uV, otherwise in the same unit as the recording. In case the recording has scaling, the templates are "unscaled" before injection. - ms_before : float, default: 1.5 - Cut out in ms before spike peak. - ms_after : float, default: 3 - Cut out in ms after spike peak. unit_locations : np.array, default: None The locations at which the templates should be injected. If not provided, generated (see generate_unit_location_kwargs). diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index 17d2bdf521..6d094adf11 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -71,6 +71,8 @@ def query_templates_from_database(template_df: "pandas.DataFrame", verbose: bool ---------- template_df : pd.DataFrame Dataframe containing the template information, obtained by slicing/querying the output of fetch_templates_info. + verbose : bool, default: False + if True, output is verbose Returns ------- diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 0d57aec21e..298953c94a 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -456,7 +456,7 @@ def find_collisions(spikes, spikes_within_margin, delta_collision_samples, spars Returns ------- - collision_spikes_dict: np.array + collision_spikes_dict: dict A dictionary with collisions. The key is the index of the spike with collision, the value is an array of overlapping spikes, including the spike itself at position 0. """ diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 3c65f2075c..ba12a5c462 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -20,8 +20,8 @@ class ComputeCorrelograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer - A SortingAnalyzer object + sorting_analyzer_or_sorting : SortingAnalyzer | Sorting + A SortingAnalyzer or Sorting object window_ms : float, default: 50.0 The window around the spike to compute the correlation in ms. For example, if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms. @@ -137,7 +137,7 @@ def compute_correlograms( compute_correlograms.__doc__ = compute_correlograms_sorting_analyzer.__doc__ -def _make_bins(sorting, window_ms, bin_ms): +def _make_bins(sorting, window_ms, bin_ms) -> tuple[np.ndarray, int, int]: """ Create the bins for the correlogram, in samples. diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index fa919e11e2..542f829f21 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -17,7 +17,7 @@ class ComputeISIHistograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object window_ms : float, default: 50 The window in ms diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 99c60a5043..f1f89403c7 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -25,21 +25,21 @@ class ComputePrincipalComponents(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - n_components: int, default: 5 + n_components : int, default: 5 Number of components fo PCA - mode: "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" + mode : "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" The PCA mode: - "by_channel_local": a local PCA is fitted for each channel (projection by channel) - "by_channel_global": a global PCA is fitted for all channels (projection by channel) - "concatenated": channels are concatenated and a global PCA is fitted - sparsity: ChannelSparsity or None, default: None + sparsity : ChannelSparsity or None, default: None The sparsity to apply to waveforms. If sorting_analyzer is already sparse, the default sparsity will be used - whiten: bool, default: True + whiten : bool, default: True If True, waveforms are pre-whitened - dtype: dtype, default: "float32" + dtype : dtype, default: "float32" Dtype of the pc scores Examples diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index e82a9e61e4..2efac0e0d0 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -22,13 +22,13 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - spike_retriver_kwargs: dict + spike_retriver_kwargs : dict A dictionary to control the behavior for getting the maximum channel for each spike This dictionary contains: @@ -42,10 +42,10 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): In case channel_from_template=False, this is the peak sign. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use - method_kwargs : dict, default: dict() - Other kwargs depending on the method. - outputs : "concatenated" | "by_unit", default: "concatenated" - The output format + **method_kwargs : dict, default: {} + Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. + outputs : "numpy" | "by_unit", default: "numpy" + The output format, either concatenated as numpy array or separated on a per unit basis Returns ------- @@ -148,7 +148,7 @@ def _get_data(self, outputs="numpy"): amplitudes_by_units[segment_index][unit_id] = all_amplitudes[inds] return amplitudes_by_units else: - raise ValueError(f"Wrong .get_data(outputs={outputs})") + raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`") register_result_extension(ComputeSpikeAmplitudes) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 53e55b4d1f..6995fc04da 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -17,13 +17,13 @@ class ComputeSpikeLocations(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - spike_retriver_kwargs: dict + spike_retriver_kwargs : dict A dictionary to control the behavior for getting the maximum channel for each spike This dictionary contains: diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index eef2a2f32c..45ba55dee4 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -8,11 +8,9 @@ import numpy as np import warnings -from typing import Optional from copy import deepcopy from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension -from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.template_tools import get_dense_templates_array @@ -50,7 +48,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer The SortingAnalyzer object metric_names : list or None, default: None List of metrics to compute (see si.postprocessing.get_template_metric_names()) @@ -58,13 +56,13 @@ class ComputeTemplateMetrics(AnalyzerExtension): Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. upsampling_factor : int, default: 10 The upsampling factor to upsample the templates - sparsity: ChannelSparsity or None, default: None + sparsity : ChannelSparsity or None, default: None If None, template metrics are computed on the extremum channel only. If sparsity is given, template metrics are computed on all sparse channels of each unit. For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. - include_multi_channel_metrics: bool, default: False + include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics - metrics_kwargs: dict + metrics_kwargs : dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 @@ -77,9 +75,9 @@ class ComputeTemplateMetrics(AnalyzerExtension): * spread_threshold: the threshold to compute the spread, default: 0.2 * spread_smooth_um: the smoothing in um to compute the spread, default: 20 * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + - If None, all channels all channels are considered + - If 0 or 1, only the "column" that includes the max channel is considered + - If > 1, only channels within range (+/-) um from the max channel horizontal position are used Returns ------- @@ -238,13 +236,17 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job for metric_name in metrics_single_channel: func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - sampling_frequency=sampling_frequency_up, - trough_idx=trough_idx, - peak_idx=peak_idx, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + sampling_frequency=sampling_frequency_up, + trough_idx=trough_idx, + peak_idx=peak_idx, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value # compute metrics multi_channel @@ -274,12 +276,16 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job sampling_frequency_up = sampling_frequency func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - channel_locations=channel_locations_sparse, - sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value return template_metrics @@ -337,7 +343,7 @@ def get_trough_and_peak_idx(template): ######################################################################################### # Single-channel metrics -def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: """ Return the peak to valley duration in seconds of input waveforms. @@ -363,7 +369,7 @@ def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, pea return ptv -def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs) -> float: """ Return the peak to trough ratio of input waveforms. @@ -389,7 +395,7 @@ def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=N return ptratio -def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: """ Return the half width of input waveforms in seconds. @@ -889,7 +895,7 @@ def exp_decay(x, decay, amp0, offset): return exp_decay_value -def get_spread(template, channel_locations, sampling_frequency, **kwargs): +def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> float: """ Compute the spread of the template amplitude over distance in units um/s. diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cb4cc323ad..0e70b1f494 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -13,25 +13,22 @@ class ComputeTemplateSimilarity(AnalyzerExtension): Similarity is defined as 1 - distance(T_1, T_2) for two templates T_1, T_2 - Parameters ---------- sorting_analyzer : SortingAnalyzer The SortingAnalyzer object - method : str, default: "cosine" - The method to compute the similarity. Can be in ["cosine", "l2", "l1"] + method : "cosine" | "l1" | "l2", default: "cosine" + The method to compute the similarity. + In case of "l1" or "l2", the formula used is: + - similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2)). + In case of cosine it is: + - similarity = 1 - sum(T_1.T_2)/(norm(T_1)norm(T_2)). max_lag_ms : float, default: 0 If specified, the best distance for all given lag within max_lag_ms is kept, for every template support : "dense" | "union" | "intersection", default: "union" Support that should be considered to compute the distances between the templates, given their sparsities. Can be either ["dense", "union", "intersection"] - In case of "l1" or "l2", the formula used is: - similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2)) - - In case of cosine this is: - similarity = 1 - sum(T_1.T_2)/(norm(T_1)norm(T_2)) - Returns ------- similarity: np.array @@ -153,7 +150,6 @@ def _get_data(self): def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): - import sklearn.metrics.pairwise if method == "cosine_similarity": @@ -226,15 +222,17 @@ def compute_similarity_with_templates_array( if method == "l1": norm_i = np.sum(np.abs(src)) norm_j = np.sum(np.abs(tgt)) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1") + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item() distances[count, i, j] /= norm_i + norm_j elif method == "l2": norm_i = np.linalg.norm(src, ord=2) norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2") + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item() distances[count, i, j] /= norm_i + norm_j else: - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine") + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances( + src, tgt, metric="cosine" + ).item() if same_array: distances[count, j, i] = distances[count, i, j] diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 818f0a8062..4029fc88c7 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -24,16 +24,16 @@ class ComputeUnitLocations(AnalyzerExtension): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - method: "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The method to use for localization - method_kwargs: dict, default: {} - Other kwargs depending on the method + **method_kwargs : dict, default: {} + Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. Returns ------- - unit_locations: np.array + unit_locations : np.array unit location with shape (num_unit, 2) or (num_unit, 3) or (num_unit, 3) (with alpha) """ @@ -94,7 +94,7 @@ def _run(self, verbose=False): method_kwargs.pop("method") if method not in _unit_location_methods: - raise ValueError(f"Wrong ethod for unit_locations : it should be in {list(_unit_location_methods.keys())}") + raise ValueError(f"Wrong method for unit_locations : it should be in {list(_unit_location_methods.keys())}") func = _unit_location_methods[method] self.data["unit_locations"] = func(self.sorting_analyzer, **method_kwargs) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 54c5ab2b2d..a67d163d3d 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -24,10 +24,12 @@ class FilterRecording(BasePreprocessor): """ - Generic filter class based on: - - * scipy.signal.iirfilter - * scipy.signal.filtfilt or scipy.signal.sosfilt + A generic filter class based on: + For filter coefficient generation: + * scipy.signal.iirfilter + For filter application: + * scipy.signal.filtfilt or scipy.signal.sosfiltfilt when direction = "forward-backward" + * scipy.signal.lfilter or scipy.signal.sosfilt when direction = "forward" or "backward" BandpassFilterRecording is built on top of it. @@ -56,6 +58,11 @@ class FilterRecording(BasePreprocessor): - numerator/denominator : ("ba") ftype : str, default: "butter" Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + direction : "forward" | "backward" | "forward-backward", default: "forward-backward" + Direction of filtering: + - "forward" - filter is applied to the timeseries in one direction, creating phase shifts + - "backward" - the timeseries is reversed, the filter is applied and filtered timeseries reversed again. Creates phase shifts in the opposite direction to "forward" + - "forward-backward" - Applies the filter in the forward and backward direction, resulting in zero-phase filtering. Note this doubles the effective filter order. Returns ------- @@ -75,6 +82,7 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, + direction="forward-backward", ): import scipy.signal @@ -106,7 +114,13 @@ def __init__( for parent_segment in recording._recording_segments: self.add_recording_segment( FilterRecordingSegment( - parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding + parent_segment, + filter_coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=add_reflect_padding, + direction=direction, ) ) @@ -121,14 +135,25 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, + direction=direction, ) class FilterRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False): + def __init__( + self, + parent_recording_segment, + coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=False, + direction="forward-backward", + ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff self.filter_mode = filter_mode + self.direction = direction self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype @@ -150,11 +175,24 @@ def get_traces(self, start_frame, end_frame, channel_indices): import scipy.signal - if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + if self.direction == "forward-backward": + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + else: + if self.direction == "backward": + traces_chunk = np.flip(traces_chunk, axis=0) + + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0) + + if self.direction == "backward": + filtered_traces = np.flip(filtered_traces, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -289,6 +327,73 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") + +def causal_filter( + recording, + direction="forward", + band=[300.0, 6000.0], + btype="bandpass", + filter_order=5, + ftype="butter", + filter_mode="sos", + margin_ms=5.0, + add_reflect_padding=False, + coeff=None, + dtype=None, +): + """ + Generic causal filter built on top of the filter function. + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + direction : "forward" | "backward", default: "forward" + Direction of causal filter. The "backward" option flips the traces in time before applying the filter + and then flips them back. + band : float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + btype : "bandpass" | "highpass", default: "bandpass" + Type of the filter + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + coeff : array | None, default: None + Filter coefficients in the filter_mode form. + dtype : dtype or None, default: None + The dtype of the returned traces. If None, the dtype of the parent recording is used + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. + filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + + Returns + ------- + filter_recording : FilterRecording + The causal-filtered recording extractor object + """ + assert direction in ["forward", "backward"], "Direction must be either 'forward' or 'backward'" + return filter( + recording=recording, + direction=direction, + band=band, + btype=btype, + filter_order=filter_order, + ftype=ftype, + filter_mode=filter_mode, + margin_ms=margin_ms, + add_reflect_padding=add_reflect_padding, + coeff=coeff, + dtype=dtype, + ) + + bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 8d1f9bc9f3..ddb981a944 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -216,6 +216,11 @@ def get_motion_presets(): def get_motion_parameters_preset(preset): """ Get the parameters tree for a given preset for motion correction. + + Parameters + ---------- + preset : str, default: None + The preset name. See available presets using `spikeinterface.preprocessing.get_motion_presets()`. """ preset_params = copy.deepcopy(motion_options_preset[preset]) all_default_params = _get_default_motion_params() diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 149c6eb458..bdf5f2219c 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -12,6 +12,7 @@ notch_filter, HighpassFilterRecording, highpass_filter, + causal_filter, ) from .filter_gaussian import GaussianFilterRecording, gaussian_filter from .normalize_scale import ( diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 68790b3273..9df60af3db 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -4,7 +4,140 @@ from spikeinterface.core import generate_recording from spikeinterface import NumpyRecording, set_global_tmp_folder -from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter +from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter, causal_filter + + +class TestCausalFilter: + """ + The only thing that is not tested (JZ, as of 23/07/2024) is the + propagation of margin kwargs, these are general filter params + and can be tested in an upcoming PR. + """ + + @pytest.fixture(scope="session") + def recording_and_data(self): + recording = generate_recording(durations=[1]) + raw_data = recording.get_traces() + + return (recording, raw_data) + + def test_causal_filter_main_kwargs(self, recording_and_data): + """ + Perform a test that expected output is returned under change + of all key filter-related kwargs. First run the filter in + the forward direction with options and compare it + to the expected output from scipy. + + Next, change every filter-related kwarg and set in the backwards + direction. Again check it matches expected scipy output. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + # First, check in the forward direction with + # the default set of kwargs + options = self._get_filter_options() + + sos = self._run_iirfilter(options, recording) + + test_data = sosfilt(sos, raw_data, axis=0) + test_data.astype(recording.dtype) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + # Then, change all kwargs to ensure they are propagated + # and check the backwards version. + options["band"] = [671] + options["btype"] = "highpass" + options["filter_order"] = 8 + options["ftype"] = "bessel" + options["filter_mode"] = "ba" + options["dtype"] = np.float16 + + b, a = self._run_iirfilter(options, recording) + + flip_raw = np.flip(raw_data, axis=0) + test_data = lfilter(b, a, flip_raw, axis=0) + test_data = np.flip(test_data, axis=0) + test_data = test_data.astype(options["dtype"]) + + filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + def test_causal_filter_custom_coeff(self, recording_and_data): + """ + A different path is taken when custom coeff is selected. + Therefore, explicitly test the expected outputs are obtained + when passing custom coeff, under the "ba" and "sos" conditions. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + options = self._get_filter_options() + options["filter_mode"] = "ba" + options["coeff"] = (np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5, 0.6])) + + # Check the custom coeff are propagated in both modes. + # First, in "ba" mode + test_data = lfilter(options["coeff"][0], options["coeff"][1], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + # Next, in "sos" mode + options["filter_mode"] = "sos" + options["coeff"] = np.ones((2, 6)) + + test_data = sosfilt(options["coeff"], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + def test_causal_kwarg_error_raised(self, recording_and_data): + """ + Test that passing the "forward-backward" direction results in + an error. It is is critical this error is raised, + otherwise the filter will no longer be causal. + """ + recording, raw_data = recording_and_data + + with pytest.raises(BaseException) as e: + filt_data = causal_filter(recording, direction="forward-backward") + + def _run_iirfilter(self, options, recording): + """ + Convenience function to convert Si kwarg + names to Scipy. + """ + from scipy.signal import iirfilter + + return iirfilter( + N=options["filter_order"], + Wn=options["band"], + btype=options["btype"], + ftype=options["ftype"], + output=options["filter_mode"], + fs=recording.get_sampling_frequency(), + ) + + def _get_filter_options(self): + return { + "band": [300.0, 6000.0], + "btype": "bandpass", + "filter_order": 5, + "ftype": "butter", + "filter_mode": "sos", + "coeff": None, + } def test_filter(): @@ -28,6 +161,8 @@ def test_filter(): # other filtering types rec3 = filter(rec, band=500.0, btype="highpass", filter_mode="ba", filter_order=2) rec4 = notch_filter(rec, freq=3000, q=30, margin_ms=5.0) + rec5 = causal_filter(rec, direction="forward") + rec6 = causal_filter(rec, direction="backward") # filter from coefficients from scipy.signal import iirfilter diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 7465d58737..2de31ad750 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -69,7 +69,7 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): return num_spikes -def compute_firing_rates(sorting_analyzer, unit_ids=None, **kwargs): +def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. @@ -98,7 +98,7 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, **kwargs): return firing_rates -def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None, **kwargs): +def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -117,7 +117,7 @@ def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio Returns ------- - presence_ratio : dict of flaots + presence_ratio : dict of floats The presence ratio for each unit ID. Notes @@ -529,7 +529,7 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): Returns ------- - synchrony_counts : dict + synchrony_counts : np.ndarray The synchrony counts for the synchrony sizes. References @@ -620,7 +620,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) -def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): +def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -1437,6 +1437,8 @@ def compute_sd_ratio( In this case, noise refers to the global voltage trace on the same channel as the best channel of the unit. (ideally (not implemented yet), the noise would be computed outside of spikes from the unit itself). + TODO: Take jitter into account. + Parameters ---------- sorting_analyzer : SortingAnalyzer @@ -1450,9 +1452,8 @@ def compute_sd_ratio( and will make a rough estimation of what that impact is (and remove it). unit_ids : list or None, default: None The list of unit ids to compute this metric. If None, all units are used. - **kwargs: + **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. - TODO: Take jitter into account. Returns ------- diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 1c5a491bf8..7c099a2f74 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -56,7 +56,7 @@ def compute_pc_metrics( seed=None, n_jobs=1, progress_bar=False, -): +) -> dict: """ Calculate principal component derived metrics. @@ -295,7 +295,7 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id): return isolation_distance, l_ratio -def lda_metrics(all_pcs, all_labels, this_unit_id): +def lda_metrics(all_pcs, all_labels, this_unit_id) -> float: """ Calculate d-prime based on Linear Discriminant Analysis. diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 0c7cf25237..cdf6151e95 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -164,7 +164,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: if not sorting_analyzer.has_extension("principal_components"): - raise ValueError("waveform_principal_component must be provied") + raise ValueError( + "To compute principal components base metrics, the principal components " + "extension must be computed first." + ) pc_metrics = compute_pc_metrics( sorting_analyzer, unit_ids=non_empty_unit_ids, diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index 8e03090eaf..f9611586c9 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -55,16 +55,16 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): """ Parameters ---------- - mode: "docker" | "singularity" + mode : "docker" | "singularity" The container mode - container_image: str + container_image : str container image name and tag - volumes: dict + volumes : dict dict of volumes to bind - py_user_base: str + py_user_base : str Python user base folder to set as PYTHONUSERBASE env var in Singularity mode Prevents from overwriting user's packages when running pip install - extra_kwargs: dict + extra_kwargs : dict Extra kwargs to start container """ assert mode in ("docker", "singularity") @@ -99,7 +99,7 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): singularity_image = sif_file else: - docker_image = self._get_docker_image(container_image) + docker_image = Client.load("docker://" + container_image) if docker_image and len(docker_image.tags) > 0: tag = docker_image.tags[0] print(f"Building singularity image from local docker image: {tag}") @@ -180,28 +180,28 @@ def install_package_in_container( Parameters ---------- - container_client: ContainerClient + container_client : ContainerClient The container client - package_name: str + package_name : str The package name - installation_mode: str + installation_mode : str The installation mode - extra: str + extra : str Extra pip install arguments, e.g. [full] - version: str + version : str The package version to install - tag: str + tag : str The github tag to install - github_url: str + github_url : str The github url to install (needed for github mode) - container_folder_source: str + container_folder_source : str The container folder source (needed for folder mode) - verbose: bool + verbose : bool If True, print output of pip install command Returns ------- - res_output: str + res_output : str The output of the pip install command """ assert installation_mode in ("pypi", "github", "folder") diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index a84c05c240..94d66e7f86 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -19,6 +19,7 @@ class HerdingspikesSorter(BaseSorter): "chunk_size": None, "rescale": True, "rescale_value": -1280.0, + "lowpass": True, "common_reference": "median", "spike_duration": 1.0, "amp_avg_duration": 0.4, @@ -53,6 +54,7 @@ class HerdingspikesSorter(BaseSorter): "out_file": "Path and filename to store detection and clustering results. (`str`, `HS2_detected`)", "verbose": "Print progress information. (`bool`, `True`)", "chunk_size": " Number of samples per chunk during detection. If `None`, a suitable value will be estimated. (`int`, `None`)", + "lowpass": "Enable internal low-pass filtering (simple two-step average). (`bool`, `True`)", "common_reference": "Method for common reference filtering, can be `average` or `median` (`str`, `median`)", "rescale": "Automatically re-scale the data. (`bool`, `True`)", "rescale_value": "Factor by which data is re-scaled. (`float`, `-1280.0`)", @@ -122,20 +124,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): hs_version = version.parse(hs.__version__) - if hs_version >= version.parse("0.4.001"): + if hs_version >= version.parse("0.4.1"): lightning_api = True else: lightning_api = False assert ( lightning_api - ), "HerdingSpikes version <0.4.001 is no longer supported. run:\n>>> pip install --upgrade herdingspikes" + ), "HerdingSpikes version <0.4.1 is no longer supported. To upgrade, run:\n>>> pip install --upgrade herdingspikes" recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) sorted_file = str(sorter_output_folder / "HS2_sorted.hdf5") params["out_file"] = str(sorter_output_folder / "HS2_detected") p = params + p.update({"verbose": verbose}) det = hs.HSDetectionLightning(recording, p) det.DetectFromRaw() diff --git a/src/spikeinterface/sorters/external/kilosort.py b/src/spikeinterface/sorters/external/kilosort.py index 56adb3b632..ad20408a06 100644 --- a/src/spikeinterface/sorters/external/kilosort.py +++ b/src/spikeinterface/sorters/external/kilosort.py @@ -131,20 +131,20 @@ def _check_params(cls, recording, output_folder, params): return p @classmethod - def _get_specific_options(cls, ops, params): + def _get_specific_options(cls, ops, params) -> dict: """ Adds specific options for Kilosort in the ops dict and returns the final dict Parameters ---------- - ops: dict + ops : dict options data - params: dict + params : dict Custom parameters dictionary for kilosort3 Returns ---------- - ops: dict + ops : dict Final ops data """ diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index 0425ad5e53..643769b6f9 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -142,20 +142,20 @@ def _check_params(cls, recording, output_folder, params): return p @classmethod - def _get_specific_options(cls, ops, params): + def _get_specific_options(cls, ops, params) -> dict: """ Adds specific options for Kilosort2 in the ops dict and returns the final dict Parameters ---------- - ops: dict + ops : dict options data - params: dict + params : dict Custom parameters dictionary for kilosort3 Returns ---------- - ops: dict + ops : dict Final ops data """ diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index b3d1718d59..df8f4e6873 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -158,20 +158,20 @@ def _check_params(cls, recording, output_folder, params): return p @classmethod - def _get_specific_options(cls, ops, params): + def _get_specific_options(cls, ops, params) -> dict: """ Adds specific options for Kilosort2_5 in the ops dict and returns the final dict Parameters ---------- - ops: dict + ops : dict options data - params: dict + params : dict Custom parameters dictionary for kilosort3 Returns ---------- - ops: dict + ops : dict Final ops data """ # frequency for high pass filtering (300) diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index f560fd7e1e..3681b036a2 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -154,20 +154,20 @@ def _check_params(cls, recording, output_folder, params): return p @classmethod - def _get_specific_options(cls, ops, params): + def _get_specific_options(cls, ops, params) -> dict: """ Adds specific options for Kilosort3 in the ops dict and returns the final dict Parameters ---------- - ops: dict + ops : dict options data - params: dict + params : dict Custom parameters dictionary for kilosort3 Returns ---------- - ops: dict + ops : dict Final ops data """ # frequency for high pass filtering (150) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..e73ac2cb6c 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -4,8 +4,12 @@ from typing import Union from packaging import version -from ..basesorter import BaseSorter + +from ...core import write_binary_recording +from ..basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase +from ..basesorter import get_job_kwargs +from importlib.metadata import version as importlib_version PathType = Union[str, Path] @@ -16,6 +20,7 @@ class Kilosort4Sorter(BaseSorter): sorter_name: str = "kilosort4" requires_locations = True gpu_capability = "nvidia-optional" + requires_binary_data = False _default_params = { "batch_size": 60000, @@ -30,6 +35,7 @@ class Kilosort4Sorter(BaseSorter): "artifact_threshold": None, "nskip": 25, "whitening_range": 32, + "highpass_cutoff": 300, "binning_depth": 5, "sig_interp": 20, "drift_smoothing": [0.5, 0.5, 0.5], @@ -50,13 +56,18 @@ class Kilosort4Sorter(BaseSorter): "cluster_downsampling": 20, "cluster_pcs": 64, "x_centers": None, - "duplicate_spike_bins": 7, + "duplicate_spike_ms": 0.25, + "scaleproc": None, + "save_preprocessed_copy": False, + "torch_device": "auto", + "bad_channels": None, + "clear_cache": False, + "save_extra_vars": False, "do_correction": True, "keep_good_only": False, - "save_extra_kwargs": False, "skip_kilosort_preprocessing": False, - "scaleproc": None, - "torch_device": "auto", + "use_binary_file": None, + "delete_recording_dat": True, } _params_description = { @@ -72,6 +83,7 @@ class Kilosort4Sorter(BaseSorter): "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.", "nskip": "Batch stride for computing whitening matrix. Default value: 25.", "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.", + "highpass_cutoff": "High-pass filter cutoff frequency in Hz. Default value: 300.", "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.", "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.", "drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.", @@ -93,12 +105,19 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", - "keep_good_only": "If True only 'good' units are returned", - "do_correction": "If True, drift correction is performed", - "save_extra_kwargs": "If True, additional kwargs are saved to the output", - "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", + "save_extra_vars": "If True, additional kwargs are saved to the output", "scaleproc": "int16 scaling of whitened data, if None set to 200.", + "save_preprocessed_copy": "Save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", + "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", + "clear_cache": "If True, force pytorch to free up memory reserved for its cache in between memory-intensive operations. Note that setting `clear_cache=True` is NOT recommended unless you encounter GPU out-of-memory errors, since this can result in slower sorting.", + "do_correction": "If True, drift correction is performed. Default is True. (spikeinterface parameter)", + "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing. (spikeinterface parameter)", + "keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)", + "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " + "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " + "Default is None. (spikeinterface parameter)", + "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. (spikeinterface parameter)", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -108,7 +127,7 @@ class Kilosort4Sorter(BaseSorter): For more information see https://github.com/MouseLand/Kilosort""" installation_mesg = """\nTo use Kilosort4 run:\n - >>> pip install kilosort==4.0 + >>> pip install kilosort --upgrade More information on Kilosort4 at: https://github.com/MouseLand/Kilosort @@ -129,9 +148,27 @@ def is_installed(cls): @classmethod def get_sorter_version(cls): - import kilosort as ks + """kilosort.__version__ <4.0.10 is always '4'""" + return importlib_version("kilosort") + + @classmethod + def initialize_folder(cls, recording, output_folder, verbose, remove_existing_folder): + if not cls.is_installed(): + raise Exception( + f"The sorter {cls.sorter_name} is not installed. Please install it with:\n{cls.installation_mesg}" + ) + cls.check_sorter_version() + return super(Kilosort4Sorter, cls).initialize_folder(recording, output_folder, verbose, remove_existing_folder) - return ks.__version__ + @classmethod + def check_sorter_version(cls): + kilosort_version = version.parse(cls.get_sorter_version()) + if kilosort_version < version.parse("4.0.16"): + raise Exception( + f"""SpikeInterface only supports kilosort versions 4.0.16 and above. You are running version {kilosort_version}. To install the latest version, run: + >>> pip install kilosort --upgrade + """ + ) @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @@ -141,6 +178,17 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): probe_filename = sorter_output_folder / "probe.prb" write_prb(probe_filename, pg) + if params["use_binary_file"]: + if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # local copy needed + binary_file_path = sorter_output_folder / "recording.dat" + write_binary_recording( + recording=recording, + file_paths=[binary_file_path], + **get_job_kwargs(params, verbose), + ) + params["filename"] = str(binary_file_path) + @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): from kilosort.run_kilosort import ( @@ -153,7 +201,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): save_sorting, get_run_parameters, ) - from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered + from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered, save_preprocessing from kilosort.parameters import DEFAULT_SETTINGS import time @@ -165,6 +213,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): logging.basicConfig(level=logging.INFO) + if version.parse(cls.get_sorter_version()) < version.parse("4.0.5"): + raise RuntimeError( + "Kilosort versions before 4.0.5 are not supported" + "in SpikeInterface. " + "Please upgrade Kilosort version." + ) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" @@ -176,16 +231,42 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # load probe recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) - probe = load_probe(probe_filename) + probe = load_probe(probe_path=probe_filename) probe_name = "" - filename = "" - # this internally concatenates the recording - file_object = RecordingExtractorAsArray(recording) + if params["use_binary_file"] is None: + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # the recording is not binary compatible and no binary copy has been written. + # in this case, we use the RecordingExtractorAsArray object + filename = "" + file_object = RecordingExtractorAsArray(recording_extractor=recording) + elif params["use_binary_file"]: + # here we force the use of a binary file + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # a local copy has been written + filename = str(sorter_output_folder / "recording.dat") + file_object = None + else: + # here we force the use of the RecordingExtractorAsArray object + filename = "" + file_object = RecordingExtractorAsArray(recording_extractor=recording) + + data_dtype = recording.get_dtype() do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] - save_extra_vars = params["save_extra_kwargs"] + save_extra_vars = params["save_extra_vars"] + save_preprocessed_copy = params["save_preprocessed_copy"] progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} settings_ks["n_chan_bin"] = recording.get_num_channels() @@ -205,31 +286,46 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder - filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) - if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( - get_run_parameters(ops) - ) - else: - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( - get_run_parameters(ops) - ) + bad_channels = params["bad_channels"] + clear_cache = params["clear_cache"] + + filename, data_dir, results_dir, probe = set_files( + settings=settings, + filename=filename, + probe=probe, + probe_name=probe_name, + data_dir=data_dir, + results_dir=results_dir, + bad_channels=bad_channels, + ) + + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=data_dtype, + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + save_preprocessed_copy=save_preprocessed_copy, + ) + + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: - ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) + ops = compute_preprocessing(ops=ops, device=device, tic0=tic0, file_object=file_object) else: print("Skipping kilosort preprocessing.") bfile = BinaryFiltered( - ops["filename"], - n_chan_bin, - fs, - NT, - nt, - twav_min, - chan_map, + filename=ops["filename"], + n_chan_bin=n_chan_bin, + fs=fs, + NT=NT, + nt=nt, + nt0min=twav_min, + chan_map=chan_map, hp_filter=None, device=device, do_CAR=do_CAR, @@ -243,29 +339,67 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None) ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) ops["Nbatches"] = bfile.n_batches + # bfile.close() # TODO: KS do this after preprocessing? np.random.seed(1) torch.cuda.manual_seed_all(1) torch.random.manual_seed(1) - # if not params["skip_kilosort_preprocessing"]: + if not params["do_correction"]: print("Skipping drift correction.") ops["nblocks"] = 0 # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, + device=device, + tic0=tic0, + progress_bar=progress_bar, + file_object=file_object, + clear_cache=clear_cache, ) + if save_preprocessed_copy: + save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) + # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) - clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes( + ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar, clear_cache=clear_cache + ) + + clu, Wall = cluster_spikes( + st=st, + tF=tF, + ops=ops, + device=device, + bfile=bfile, + tic0=tic0, + progress_bar=progress_bar, + clear_cache=clear_cache, + ) + if params["skip_kilosort_preprocessing"]: ops["preprocessing"] = dict( hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + _ = save_sorting( + ops=ops, + results_dir=results_dir, + st=st, + clu=clu, + tF=tF, + Wall=Wall, + imin=bfile.imin, + tic0=tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy, + ) + + if params["delete_recording_dat"]: + # only delete dat file if it was created by the wrapper + if (sorter_output_folder / "recording.dat").is_file(): + (sorter_output_folder / "recording.dat").unlink() @classmethod def _get_result_from_folder(cls, sorter_output_folder): diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index d04128b50c..95d8d3badc 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -79,11 +79,11 @@ def _generate_ops_file(cls, recording, params, sorter_output_folder, binary_file Parameters ---------- - recording: BaseRecording + recording : BaseRecording The recording to generate the channel map file - params: dict + params : dict Custom parameters dictionary for kilosort - sorter_output_folder: pathlib.Path + sorter_output_folder : pathlib.Path Path object to save `ops.mat` """ ops = {} @@ -146,7 +146,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): padding_start = 0 padding_end = pad padded_recording = TracePaddedRecording( - parent_recording=recording, + recording=recording, padding_start=padding_start, padding_end=padding_end, ) diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index c7127226b0..7ed5b29556 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -250,6 +250,8 @@ def run_sorter_by_property( Controls sorter verboseness docker_image : None or str, default: None If str run the sorter inside a container (docker) using the docker package + singularity_image : None or str, default: None + If str run the sorter inside a container (singularity) using the docker package **sorter_params : keyword args Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 80608f8973..d28af7b99c 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -96,23 +96,6 @@ If True, the output Sorting is returned as a Sorting delete_container_files : bool, default: True If True, the container temporary files are deleted after the sorting is done - extra_requirements : list, default: None - List of extra requirements to install in the container - installation_mode : "auto" | "pypi" | "github" | "folder" | "dev" | "no-install", default: "auto" - How spikeinterface is installed in the container: - * "auto" : if host installation is a pip release then use "github" with tag - if host installation is DEV_MODE=True then use "dev" - * "pypi" : use pypi with pip install spikeinterface - * "github" : use github with `pip install git+https` - * "folder" : mount a folder in container and install from this one. - So the version in the container is a different spikeinterface version from host, useful for - cross checks - * "dev" : same as "folder", but the folder is the spikeinterface.__file__ to ensure same version as host - * "no-install" : do not install spikeinterface in the container because it is already installed - spikeinterface_version : str, default: None - The spikeinterface version to install in the container. If None, the current version is used - spikeinterface_folder_source : Path or None, default: None - In case of installation_mode="folder", the spikeinterface folder source to use to install in the container output_folder : None, default: None Do not use. Deprecated output function to be removed in 0.103. **sorter_params : keyword args @@ -272,7 +255,9 @@ def run_sorter_local( # only classmethod call not instance (stateless at instance level but state is in folder) folder = SorterClass.initialize_folder(recording, folder, verbose, remove_existing_folder) SorterClass.set_params_to_folder(recording, folder, sorter_params, verbose) + # This writes parameters and recording to binary and could ideally happen in the host SorterClass.setup_recording(recording, folder, verbose=verbose) + # This NEEDS to happen in the docker because of dependencies SorterClass.run_from_folder(folder, raise_error, verbose) if with_output: sorting = SorterClass.get_result_from_folder(folder, register_recording=True, sorting_info=True) @@ -691,7 +676,9 @@ def read_sorter_folder(folder, register_recording=True, sorting_info=True, raise register_recording : bool, default: True Attach recording (when json or pickle) to the sorting sorting_info : bool, default: True - Attach sorting info to the sorting. + Attach sorting info to the sorting + raise_error : bool, detault: True + Raise an error if the spike sorting failed """ folder = Path(folder) log_file = folder / "spikeinterface_log.json" diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index 105a536617..e1a6816133 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -76,7 +76,7 @@ def print_sorter_versions(): print(txt) -def get_default_sorter_params(sorter_name_or_class): +def get_default_sorter_params(sorter_name_or_class) -> dict: """Returns default parameters for the specified sorter. Parameters @@ -100,7 +100,7 @@ def get_default_sorter_params(sorter_name_or_class): return SorterClass.default_params() -def get_sorter_params_description(sorter_name_or_class): +def get_sorter_params_description(sorter_name_or_class) -> dict: """Returns a description of the parameters for the specified sorter. Parameters @@ -124,7 +124,7 @@ def get_sorter_params_description(sorter_name_or_class): return SorterClass.params_description() -def get_sorter_description(sorter_name_or_class): +def get_sorter_description(sorter_name_or_class) -> dict: """Returns a brief description for the specified sorter. Parameters diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 7381875557..99881f2f34 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -12,15 +12,15 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object - peaks: numpy.array + peaks : numpy.array The peak vector - method: str + method : str Which method to use ("stupid" | "XXXX") - method_kwargs: dict, default: dict() + method_kwargs : dict, default: dict() Keyword arguments for the chosen method - extra_outputs: bool, default: False + extra_outputs : bool, default: False If True then debug is also return {} diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 183fdab04c..ad7391a297 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -20,7 +20,9 @@ from .main import BaseTemplateMatchingEngine -def compress_templates(templates_array, approx_rank, remove_mean=True, return_new_templates=True): +def compress_templates( + templates_array, approx_rank, remove_mean=True, return_new_templates=True +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]: """Compress templates using singular value decomposition. Parameters diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 9476a0df03..6e5267cb70 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -9,25 +9,27 @@ def find_spikes_from_templates( recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs -): +) -> np.ndarray | tuple[np.ndarray, dict]: """Find spike from a recording from given templates. Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object - method: "naive" | "tridesclous" | "circus" | "circus-omp" | "wobble" + method : "naive" | "tridesclous" | "circus" | "circus-omp" | "wobble", default: "naive" Which method to use for template matching - method_kwargs: dict, optional + method_kwargs : dict, optional Keyword arguments for the chosen method - extra_outputs: bool + extra_outputs : bool If True then method_kwargs is also returned - job_kwargs: dict + **job_kwargs : dict Parameters for ChunkRecordingExecutor + verbose : Bool, default: False + If True, output is verbose Returns ------- - spikes: ndarray + spikes : ndarray Spikes found from templates. method_kwargs: Optionaly returns for debug purpose. diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 3b692c3bf0..99de6fcd4e 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -522,7 +522,9 @@ def main_function(cls, traces, method_kwargs): # TODO: Replace this method with equivalent from spikeinterface @classmethod - def find_peaks(cls, objective, objective_normalized, spike_trains, params, template_data, template_meta): + def find_peaks( + cls, objective, objective_normalized, spike_trains, params, template_data, template_meta + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Find new peaks in the objective and update spike train accordingly. Parameters @@ -603,7 +605,7 @@ def find_peaks(cls, objective, objective_normalized, spike_trains, params, templ @classmethod def subtract_spike_train( cls, spike_train, scalings, template_data, objective, objective_normalized, params, template_meta, sparsity - ): + ) -> tuple[np.ndarray, np.ndarray]: """Subtract spike train of templates from the objective directly. Parameters @@ -662,7 +664,7 @@ def calculate_high_res_shift( template_data, params, template_meta, - ): + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Determines optimal shifts when super-resolution, scaled templates are used. Parameters @@ -746,7 +748,9 @@ def calculate_high_res_shift( return template_shift, time_shift, non_refractory_indices, scalings @classmethod - def enforce_refractory(cls, spike_train, objective, objective_normalized, params, template_meta): + def enforce_refractory( + cls, spike_train, objective, objective_normalized, params, template_meta + ) -> tuple[np.ndarray, np.ndarray]: """Enforcing the refractory period for each unit by setting the objective to -infinity. Parameters @@ -815,7 +819,7 @@ def compute_template_norm(visible_channels, templates): return norm_squared -def compress_templates(templates, approx_rank): +def compress_templates(templates, approx_rank) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Compress templates using singular value decomposition. Parameters @@ -936,7 +940,7 @@ def convolve_templates(compressed_templates, jitter_factor, approx_rank, jittere return pairwise_convolution -def compute_objective(traces, template_data, approx_rank): +def compute_objective(traces, template_data, approx_rank) -> np.ndarray: """Compute objective by convolving templates with voltage traces. Parameters @@ -973,7 +977,9 @@ def compute_objective(traces, template_data, approx_rank): return objective -def compute_scale_amplitudes(high_resolution_conv, norm_peaks, scale_min, scale_max, amplitude_variance): +def compute_scale_amplitudes( + high_resolution_conv, norm_peaks, scale_min, scale_max, amplitude_variance +) -> tuple[np.ndarray, np.ndarray]: """Compute optimal amplitude scaling and the high-resolution objective resulting from scaled spikes. Without hard clipping, the objective can be obtained via diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index ded4c257ab..e2b6b1a2bc 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -1141,7 +1141,7 @@ def normxcorr1d( normalized=True, padding="same", conv_engine="torch", -): +) -> "torch.Tensor | np.ndarray": """ normxcorr1d: Normalized cross-correlation, optionally weighted diff --git a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py index 2fc1a281a9..87dca64496 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py @@ -11,16 +11,16 @@ def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=3 Parameters ---------- - motion: numpy array 2d + motion : numpy array 2d Motion estimate in um. - temporal_bins: numpy.array 1d + temporal_bins : numpy.array 1d temporal bins (bin center) - bin_duration_s: float + bin_duration_s : float bin duration in second - speed_threshold: float (units um/s) + speed_threshold : float (units um/s) Maximum speed treshold between 2 bins allowed. Expressed in um/s - sigma_smooth_s: None or float + sigma_smooth_s : None or float Optional smooting gaussian kernel. Returns diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index 0d425c98da..8a4daeb808 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -30,10 +30,8 @@ def estimate_motion( verbose=False, margin_um=None, **method_kwargs, -): +) -> Motion | tuple[Motion, dict]: """ - - Estimate motion with several possible methods. Most of methods except dredge_lfp needs peaks and after their localization. @@ -43,21 +41,19 @@ def estimate_motion( Parameters ---------- - recording: BaseRecording + recording : BaseRecording The recording extractor - peaks: numpy array + peaks : numpy array Peak vector (complex dtype). Needed for decentralized and iterative_template methods. - peak_locations: numpy array + peak_locations : numpy array Complex dtype with "x", "y", "z" fields Needed for decentralized and iterative_template methods. - direction: "x" | "y" | "z", default: "y" + direction : "x" | "y" | "z", default: "y" Dimension on which the motion is estimated. "y" is depth along the probe. {method_doc} - **non-rigid section** - rigid : bool, default: False Compute rigid (one motion for the entire probe) or non rigid motion Rigid computation is equivalent to non-rigid with only one window with rectangular shape. @@ -76,14 +72,14 @@ def estimate_motion( See win_shape win_margin_um : None | float, default: None See win_shape - extra_outputs: bool, default: False + extra_outputs : bool, default: False If True then return an extra dict that contains variables to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) - progress_bar: bool, default: False + progress_bar : bool, default: False Display progress bar or not - verbose: bool, default: False + verbose : bool, default: False If True, output is verbose - + **method_kwargs : Returns ------- diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 11ce11e1aa..a5e6ded519 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -7,7 +7,7 @@ from spikeinterface.preprocessing.filter import fix_dtype -def correct_motion_on_peaks(peaks, peak_locations, motion, recording): +def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: """ Given the output of estimate_motion(), apply inverse motion on peak locations. @@ -69,23 +69,25 @@ def interpolate_motion_on_traces( Trace snippet (num_samples, num_channels) times : np.array Sample times in seconds for the frames of the traces snippet - channel_location: np.array 2d + channel_locations : np.array 2d Channel location with shape (n, 2) or (n, 3) - motion: Motion + motion : Motion The motion object. - segment_index: int or None + segment_index : int or None The segment index. - channel_inds: None or list + channel_inds : None or list If not None, interpolate only a subset of channels. interpolation_time_bin_centers_s : None or np.array Manually specify the time bins which the interpolation happens in for this segment. If None, these are the motion estimate's time bins. - spatial_interpolation_method: "idw" | "kriging", default: "kriging" + spatial_interpolation_method : "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing * kriging : kilosort2.5 like - spatial_interpolation_kwargs: - * specific option for the interpolation method + spatial_interpolation_kwargs : dict + specific option for the interpolation method + dtype : np.dtype, default: None + The dtype of the traces. If None, interhits from traces snippet Returns ------- @@ -237,11 +239,11 @@ class InterpolateMotionRecording(BasePreprocessor): Parameters ---------- - recording: Recording + recording : Recording The parent recording. - motion: Motion + motion : Motion The motion object - spatial_interpolation_method: "kriging" | "idw" | "nearest", default: "kriging" + spatial_interpolation_method : "kriging" | "idw" | "nearest", default: "kriging" The spatial interpolation method used to interpolate the channel locations. See `spikeinterface.preprocessing.get_spatial_interpolation_kernel()` for more details. Choice of the method: @@ -249,23 +251,24 @@ class InterpolateMotionRecording(BasePreprocessor): * "kriging" : the same one used in kilosort * "idw" : inverse distance weighted * "nearest" : use neareast channel - sigma_um: float, default: 20.0 + + sigma_um : float, default: 20.0 Used in the "kriging" formula - p: int, default: 1 + p : int, default: 1 Used in the "kriging" formula - num_closest: int, default: 3 + num_closest : int, default: 3 Number of closest channels used by "idw" method for interpolation. - border_mode: "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" + border_mode : "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" Control how channels are handled on border: * "remove_channels": remove channels on the border, the recording has less channels * "force_extrapolate": keep all channel and force extrapolation (can lead to strange signal) * "force_zeros": keep all channel but set zeros when outside (force_extrapolate=False) - interpolation_time_bin_centers_s: np.array or list of np.array, optional + interpolation_time_bin_centers_s : np.array or list of np.array, optional Spatially interpolate each frame according to the displacement estimate at its closest bin center in this array. If not supplied, this is set to the motion estimate's time bin centers. If it's supplied, the motion estimate is interpolated to these bin centers. If you have a multi-segment recording, pass a list of these, one per segment. - interpolation_time_bin_size_s: float, optional + interpolation_time_bin_size_s : float, optional Similar to the previous argument: interpolation_time_bin_centers_s will be constructed by bins spaced by interpolation_time_bin_size_s. This is ignored if interpolation_time_bin_centers_s is supplied. @@ -273,6 +276,8 @@ class InterpolateMotionRecording(BasePreprocessor): Interpolation needs to convert to a floating dtype. If dtype is supplied, that will be used. If the input recording is already floating and dtype=None, then its dtype is used by default. If the input recording is integer, then float32 is used by default. + **spatial_interpolation_kwargs : dict + Spatial interpolation kwargs for `interpolate_motion_on_traces`. Returns ------- @@ -386,6 +391,10 @@ def __init__( ) self.add_recording_segment(rec_segment) + # this object is currently not JSON-serializable because the Motion obejct cannot be reloaded properly + # see issue #3313 + self._serializability["json"] = False + self._kwargs = dict( recording=recording, motion=motion, diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 203bd2473b..635624cca8 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -17,6 +17,7 @@ class Motion: Motion estimate in um. List is the number of segment. For each semgent : + * shape (temporal bins, spatial bins) * motion.shape[0] = temporal_bins.shape[0] * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index b984853123..4fe90dd7bc 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -59,19 +59,19 @@ def detect_peaks( Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object. - pipeline_nodes: None or list[PipelineNode] + pipeline_nodes : None or list[PipelineNode] Optional additional PipelineNode need to computed just after detection time. This avoid reading the recording multiple times. - gather_mode: str + gather_mode : str How to gather the results: * "memory": results are returned as in-memory numpy arrays * "npy": results are stored to .npy files in `folder` - folder: str or Path + folder : str or Path If gather_mode is "npy", the folder where the files are created. - names: list + names : list List of strings with file stems associated with returns. {method_doc} diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index b578eb4478..ddc8add995 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -88,7 +88,7 @@ def get_localization_pipeline_nodes( return pipeline_nodes -def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs): +def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_after=0.5, **kwargs) -> np.ndarray: """Localize peak (spike) in 2D or 3D depending the method. When a probe is 2D then: @@ -98,10 +98,14 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ Parameters ---------- - recording: RecordingExtractor + recording : RecordingExtractor The recording extractor object. - peaks: array + peaks : array Peaks array, as returned by detect_peaks() in "compact_numpy" way. + ms_before : float + The number of milliseconds to include before the peak of the spike + ms_after : float + The number of milliseconds to include after the peak of the spike {method_doc} diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index facefac4c5..1501582336 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -20,14 +20,14 @@ def make_multi_method_doc(methods, ident=" "): doc = "" - doc += "method: " + ", ".join(f"'{method.name}'" for method in methods) + "\n" + doc += "method : " + ", ".join(f"'{method.name}'" for method in methods) + "\n" doc += ident + " Method to use.\n" for method in methods: doc += "\n" - doc += ident + f"arguments for method='{method.name}'" + doc += ident + ident + f"arguments for method='{method.name}'" for line in method.params_doc.splitlines(): - doc += ident + line + "\n" + doc += ident + ident + line + "\n" return doc @@ -70,18 +70,18 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks=1000, **job_kwargs): + from spikeinterface.sortingcomponents.peak_selection import select_peaks + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) - from spikeinterface.sortingcomponents.peak_selection import select_peaks - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=nb_peaks, margin=(nbefore, nafter)) waveforms = extract_waveform_at_max_channel( recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) with np.errstate(divide="ignore", invalid="ignore"): - prototype = np.median(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) + prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) return prototype diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 6f60e9ab9a..a113298851 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -43,7 +43,12 @@ class SortingSummaryWidget(BaseWidget): List of labels to be added to the curation table (sortingview backend) unit_table_properties : list or None, default: None - List of properties to be added to the unit table + List of properties to be added to the unit table. + These may be drawn from the sorting extractor, and, if available, + the quality_metrics and template_metrics extensions of the SortingAnalyzer. + See all properties available with sorting.get_property_keys(), and, if available, + analyzer.get_extension("quality_metrics").get_data().columns and + analyzer.get_extension("template_metrics").get_data().columns. (sortingview backend) """ @@ -151,7 +156,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # unit ids v_units_table = generate_unit_table_view( - dp.sorting_analyzer.sorting, dp.unit_table_properties, similarity_scores=similarity_scores + dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores ) if dp.curation: diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index e06c79ad2f..debcd52085 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -539,6 +539,21 @@ def test_plot_sorting_summary(self): backend=backend, **self.backend_kwargs[backend], ) + # add unit_properties + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + unit_table_properties=["firing_rate", "snr"], + backend=backend, + **self.backend_kwargs[backend], + ) + # adding a missing property should raise a warning + with self.assertWarns(UserWarning): + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + unit_table_properties=["missing_property"], + backend=backend, + **self.backend_kwargs[backend], + ) def test_plot_agreement_matrix(self): possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index e31f0e0444..0aae1777a9 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -10,7 +10,9 @@ def check_ipywidget_backend(): import matplotlib mpl_backend = matplotlib.get_backend() - assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" + assert ( + "ipympl" in mpl_backend or "widget" in mpl_backend + ), "To use the 'ipywidgets' backend, you have to set %matplotlib widget" class TimeSlider(W.HBox): diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index ab926a0104..7a9dc47826 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -2,7 +2,9 @@ import numpy as np +from ..core import SortingAnalyzer, BaseSorting from ..core.core_tools import check_json +from warnings import warn def make_serializable(*args): @@ -45,9 +47,49 @@ def handle_display_and_url(widget, view, **backend_kwargs): return url -def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): +def generate_unit_table_view( + sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting, + unit_properties: list[str] | None = None, + similarity_scores: npndarray | None = None, +): import sortingview.views as vv + if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): + analyzer = sorting_or_sorting_analyzer + sorting = analyzer.sorting + else: + sorting = sorting_or_sorting_analyzer + analyzer = None + + # Find available unit properties from all sources + sorting_props = list(sorting.get_property_keys()) + if analyzer is not None: + if analyzer.get_extension("quality_metrics") is not None: + qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns) + qm_data = analyzer.get_extension("quality_metrics").get_data() + else: + qm_props = [] + if analyzer.get_extension("template_metrics") is not None: + tm_props = list(analyzer.get_extension("template_metrics").get_data().columns) + tm_data = analyzer.get_extension("template_metrics").get_data() + else: + tm_props = [] + # Check for any overlaps and warn user if any + all_props = sorting_props + qm_props + tm_props + else: + all_props = sorting_props + qm_props = [] + tm_props = [] + qm_data = None + tm_data = None + + overlap_props = [prop for prop in all_props if all_props.count(prop) > 1] + if len(overlap_props) > 0: + warn( + f"Warning: Overlapping properties found in sorting, quality_metrics, and template_metrics: {overlap_props}" + ) + + # Get unit properties if unit_properties is None: ut_columns = [] ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] @@ -56,8 +98,21 @@ def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=No ut_rows = [] values = {} valid_unit_properties = [] + + # Create columns for each property for prop_name in unit_properties: - property_values = sorting.get_property(prop_name) + + # Get property values from correct location + if prop_name in sorting_props: + property_values = sorting.get_property(prop_name) + elif prop_name in qm_props: + property_values = qm_data[prop_name].values + elif prop_name in tm_props: + property_values = tm_data[prop_name].values + else: + warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") + continue + # make dtype available val0 = np.array(property_values[0]) if val0.dtype.kind in ("i", "u"): @@ -69,19 +124,29 @@ def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=No elif val0.dtype.kind == "b": dtype = "bool" else: - print(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") + warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") continue ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) valid_unit_properties.append(prop_name) + # Create rows for each unit for ui, unit in enumerate(sorting.unit_ids): for prop_name in valid_unit_properties: - property_values = sorting.get_property(prop_name) + + # Get property values from correct location + if prop_name in sorting_props: + property_values = sorting.get_property(prop_name) + elif prop_name in qm_props: + property_values = qm_data[prop_name].values + elif prop_name in tm_props: + property_values = tm_data[prop_name].values + + # Check for NaN values val0 = np.array(property_values[0]) if val0.dtype.kind == "f": if np.isnan(property_values[ui]): continue - values[prop_name] = property_values[ui] + values[prop_name] = np.format_float_positional(property_values[ui], precision=4, fractional=False) ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores)