From 0ee0c43a08711802a14323b1bc13046328b99540 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 26 Aug 2024 17:01:56 +0100 Subject: [PATCH 01/19] Add bad channels and do version check --- .../sorters/external/kilosort4.py | 64 +++++++++++-------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8499cef11f..1a3ba59b54 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -59,6 +59,7 @@ class Kilosort4Sorter(BaseSorter): "scaleproc": None, "save_preprocessed_copy": False, "torch_device": "auto", + "bad_channels": None, } _params_description = { @@ -101,6 +102,7 @@ class Kilosort4Sorter(BaseSorter): "scaleproc": "int16 scaling of whitened data, if None set to 200.", "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", + "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -110,7 +112,7 @@ class Kilosort4Sorter(BaseSorter): For more information see https://github.com/MouseLand/Kilosort""" installation_mesg = """\nTo use Kilosort4 run:\n - >>> pip install kilosort==4.0 + >>> pip install kilosort --upgrade More information on Kilosort4 at: https://github.com/MouseLand/Kilosort @@ -134,6 +136,25 @@ def get_sorter_version(cls): """kilosort.__version__ <4.0.10 is always '4'""" return importlib_version("kilosort") + @classmethod + def initialize_folder(cls, recording, output_folder, verbose, remove_existing_folder): + if not cls.is_installed(): + raise Exception( + f"The sorter {cls.sorter_name} is not installed. Please install it with:\n{cls.installation_mesg}" + ) + cls.check_sorter_version() + return super(Kilosort4Sorter, cls).initialize_folder(recording, output_folder, verbose, remove_existing_folder) + + @classmethod + def check_sorter_version(cls): + kilosort_version = version.parse(cls.get_sorter_version()) + if kilosort_version < version.parse("4.0.16"): + raise Exception( + f"""SpikeInterface only supports kilosort versions 4.0.16 and above. You are running version {kilosort_version}. To install the latest version, run: + >>> pip install kilosort --upgrade + """ + ) + @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): from probeinterface import write_prb @@ -214,6 +235,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder + bad_channels = params["bad_channels"] filename, data_dir, results_dir, probe = set_files( settings=settings, @@ -222,36 +244,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe_name=probe_name, data_dir=data_dir, results_dir=results_dir, + bad_channels=bad_channels, ) - if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops( - settings=settings, - probe=probe, - data_dtype=recording.get_dtype(), - do_CAR=do_CAR, - invert_sign=invert_sign, - device=device, - save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) - ) - else: - ops = initialize_ops( - settings=settings, - probe=probe, - data_dtype=recording.get_dtype(), - do_CAR=do_CAR, - invert_sign=invert_sign, - device=device, - ) + ops = initialize_ops( + settings=settings, + probe=probe, + data_dtype=recording.get_dtype(), + do_CAR=do_CAR, + invert_sign=invert_sign, + device=device, + save_preprocessed_copy=save_preprocessed_copy, # this kwarg is correct (typo) + ) - if version.parse(cls.get_sorter_version()) >= version.parse("4.0.11"): - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( - get_run_parameters(ops) - ) - else: - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( - get_run_parameters(ops) - ) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: From bc290ff48820cd6b50c433effd89edd86a569383 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 27 Aug 2024 09:17:12 +0100 Subject: [PATCH 02/19] remove comment about preprocesed spelling --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 1a3ba59b54..7541b48201 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -254,7 +254,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR=do_CAR, invert_sign=invert_sign, device=device, - save_preprocessed_copy=save_preprocessed_copy, # this kwarg is correct (typo) + save_preprocessed_copy=save_preprocessed_copy, ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( From 0df714123f63a9b27a38f7a60f7af61287fceaba Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 27 Aug 2024 16:22:31 +0200 Subject: [PATCH 03/19] Add use_binary_file argument and logic to KS4 --- .../sorters/external/kilosort4.py | 39 ++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 7541b48201..4b7d0dbe6e 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -4,8 +4,11 @@ from typing import Union from packaging import version -from ..basesorter import BaseSorter + +from ...core import write_binary_recording +from ..basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase +from ..basesorter import get_job_kwargs from importlib.metadata import version as importlib_version PathType = Union[str, Path] @@ -17,6 +20,7 @@ class Kilosort4Sorter(BaseSorter): sorter_name: str = "kilosort4" requires_locations = True gpu_capability = "nvidia-optional" + requires_binary_data = False _default_params = { "batch_size": 60000, @@ -60,6 +64,7 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": False, "torch_device": "auto", "bad_channels": None, + "use_binary_file": False, } _params_description = { @@ -103,6 +108,8 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", + "use_binary_file": "If True, the Kilosort is run from a binary file. In this case, if the recording is not binary it is written to a binary file in the output folder" + "If False, the Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. Default is False.", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -163,6 +170,16 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): probe_filename = sorter_output_folder / "probe.prb" write_prb(probe_filename, pg) + if params["use_binary_file"] and not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # local copy needed + binary_file_path = sorter_output_folder / "recording.dat" + write_binary_recording( + recording=recording, + file_paths=[binary_file_path], + **get_job_kwargs(params, verbose), + ) + params["filename"] = str(binary_file_path) + @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): from kilosort.run_kilosort import ( @@ -207,10 +224,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) probe = load_probe(probe_path=probe_filename) probe_name = "" - filename = "" - # this internally concatenates the recording - file_object = RecordingExtractorAsArray(recording_extractor=recording) + if params["use_binary_file"]: + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # a local copy has been written + filename = str(sorter_output_folder / "recording.dat") + file_object = None + else: + # this internally concatenates the recording + filename = "" + file_object = RecordingExtractorAsArray(recording_extractor=recording) + data_dtype = recording.get_dtype() do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] @@ -250,7 +279,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops = initialize_ops( settings=settings, probe=probe, - data_dtype=recording.get_dtype(), + data_dtype=data_dtype, do_CAR=do_CAR, invert_sign=invert_sign, device=device, From 227b0e71d11c20f8bbbe0014456d80c50075fead Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 31 Aug 2024 10:12:12 +0200 Subject: [PATCH 04/19] Update KS4 versions --- .github/scripts/check_kilosort4_releases.py | 2 +- .github/workflows/test_kilosort4.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 92e7bf277f..5544224d8d 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -22,7 +22,7 @@ def get_pypi_versions(package_name): "At version 0.101.1, this should be updated to support newer" "kilosort verrsions." ) - versions = [ver for ver in versions if parse("4.0.12") >= parse(ver) >= parse("4.0.5")] + versions = [ver for ver in versions if parse(ver) >= parse("4.0.16")] return versions diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 390bec98be..6c58c76813 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -7,7 +7,7 @@ on: jobs: versions: - # Poll Pypi for all released KS4 versions >4.0.4, save to JSON + # Poll Pypi for all released KS4 versions >4.0.16, save to JSON # and store them in a matrix for the next job. runs-on: ubuntu-latest outputs: From c23d53032d406f6b386773338c751af71f4e1be4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Sep 2024 18:52:18 +0200 Subject: [PATCH 05/19] Update actions and always use binary if recording is binary --- .github/scripts/check_kilosort4_releases.py | 7 +- .github/scripts/test_kilosort4_ci.py | 189 ++++++++---------- .github/workflows/test_kilosort4.yml | 7 +- .../sorters/external/kilosort4.py | 54 +++-- 4 files changed, 112 insertions(+), 145 deletions(-) diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py index 5544224d8d..7a6368f3cf 100644 --- a/.github/scripts/check_kilosort4_releases.py +++ b/.github/scripts/check_kilosort4_releases.py @@ -16,12 +16,7 @@ def get_pypi_versions(package_name): response.raise_for_status() data = response.json() versions = list(sorted(data["releases"].keys())) - - assert parse(spikeinterface.__version__) < parse("0.101.1"), ( - "Kilosort 4.0.5-12 are supported in SpikeInterface < 0.101.1." - "At version 0.101.1, this should be updated to support newer" - "kilosort verrsions." - ) + # Filter out versions that are less than 4.0.16 versions = [ver for ver in versions if parse(ver) >= parse("4.0.16")] return versions diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index e0d1f2a504..c7853a2add 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -18,6 +18,7 @@ - Do some tests to check all KS4 parameters are tested against. """ + import copy from typing import Any import spikeinterface.full as si @@ -33,11 +34,16 @@ from packaging.version import parse from importlib.metadata import version from inspect import signature -from kilosort.run_kilosort import (set_files, initialize_ops, - compute_preprocessing, - compute_drift_correction, detect_spikes, - cluster_spikes, save_sorting, - get_run_parameters, ) +from kilosort.run_kilosort import ( + set_files, + initialize_ops, + compute_preprocessing, + compute_drift_correction, + detect_spikes, + cluster_spikes, + save_sorting, + get_run_parameters, +) from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered from kilosort.parameters import DEFAULT_SETTINGS from kilosort import preprocessing as ks_preprocessing @@ -49,8 +55,7 @@ PARAMS_TO_TEST = [ # Not tested # ("torch_device", "auto") - - # Stable across KS version 4.0.01 - 4.0.12 + # Stable across KS version 4.0.16 - 4.0.X (?) ("change_nothing", None), ("nblocks", 0), ("do_CAR", False), @@ -83,38 +88,21 @@ ("acg_threshold", 1e12), ("cluster_downsampling", 2), ("duplicate_spike_bins", 5), + ("drift_smoothing", [250, 250, 250]), + ("bad_channels", None), + ("save_preprocessed_copy", False), ] -if parse(version("kilosort")) >= parse("4.0.11"): - PARAMS_TO_TEST.extend( - [ - ("shift", 1e9), - ("scale", -1e9), - ] - ) -if parse(version("kilosort")) == parse("4.0.9"): - # bug in 4.0.9 for "nblocks=0" - PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] != "nblocks"] -if parse(version("kilosort")) >= parse("4.0.8"): - PARAMS_TO_TEST.extend( - [ - ("drift_smoothing", [250, 250, 250]), - ] - ) -if parse(version("kilosort")) <= parse("4.0.6"): - # AFAIK this parameter was always unused in KS (that's why it was removed) - PARAMS_TO_TEST.extend( - [ - ("cluster_pcs", 1e9), - ] - ) -if parse(version("kilosort")) <= parse("4.0.3"): - PARAMS_TO_TEST = [param for param in PARAMS_TO_TEST if param[0] not in ["x_centers", "max_channel_distance"]] +# if parse(version("kilosort")) >= parse("4.0.X"): +# PARAMS_TO_TEST.extend( +# [ +# ("new_param", new_values), +# ] +# ) class TestKilosort4Long: - # Fixtures ###### @pytest.fixture(scope="session") def recording_and_paths(self, tmp_path_factory): @@ -200,7 +188,6 @@ def test_params_to_test(self): otherwise there is no point to the test. """ for parameter in PARAMS_TO_TEST: - param_key, param_value = parameter if param_key == "change_nothing": @@ -218,7 +205,6 @@ def test_default_settings_all_represented(self): tested_keys = [entry[0] for entry in PARAMS_TO_TEST] for param_key in DEFAULT_SETTINGS: - if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: if parse(version("kilosort")) == parse("4.0.9") and param_key == "nblocks": continue @@ -241,16 +227,18 @@ def test_spikeinterface_defaults_against_kilsort(self): # Testing Arguments ### def test_set_files_arguments(self): - self._check_arguments( - set_files, - ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"] - ) + self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"]) def test_initialize_ops_arguments(self): - expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] - - if parse(version("kilosort")) >= parse("4.0.12"): - expected_arguments.append("save_preprocesed_copy") + expected_arguments = [ + "settings", + "probe", + "data_dtype", + "do_CAR", + "invert_sign", + "device", + "save_preprocessed_copy", + ] self._check_arguments( initialize_ops, @@ -258,28 +246,16 @@ def test_initialize_ops_arguments(self): ) def test_compute_preprocessing_arguments(self): - self._check_arguments( - compute_preprocessing, - ["ops", "device", "tic0", "file_object"] - ) + self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"]) def test_compute_drift_location_arguments(self): - self._check_arguments( - compute_drift_correction, - ["ops", "device", "tic0", "progress_bar", "file_object"] - ) + self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object"]) def test_detect_spikes_arguments(self): - self._check_arguments( - detect_spikes, - ["ops", "device", "bfile", "tic0", "progress_bar"] - ) + self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar"]) def test_cluster_spikes_arguments(self): - self._check_arguments( - cluster_spikes, - ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"] - ) + self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"]) def test_save_sorting_arguments(self): expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] @@ -287,50 +263,47 @@ def test_save_sorting_arguments(self): if parse(version("kilosort")) > parse("4.0.11"): expected_arguments.append("save_preprocessed_copy") - self._check_arguments( - save_sorting, - expected_arguments - ) + self._check_arguments(save_sorting, expected_arguments) def test_get_run_parameters(self): - self._check_arguments( - get_run_parameters, - ["ops"] - ) + self._check_arguments(get_run_parameters, ["ops"]) def test_load_probe_parameters(self): - self._check_arguments( - load_probe, - ["probe_path"] - ) + self._check_arguments(load_probe, ["probe_path"]) def test_recording_extractor_as_array_arguments(self): - self._check_arguments( - RecordingExtractorAsArray, - ["recording_extractor"] - ) + self._check_arguments(RecordingExtractorAsArray, ["recording_extractor"]) def test_binary_filtered_arguments(self): expected_arguments = [ - "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", - "chan_map", "hp_filter", "whiten_mat", "dshift", - "device", "do_CAR", "artifact_threshold", "invert_sign", - "dtype", "tmin", "tmax", "file_object" + "filename", + "n_chan_bin", + "fs", + "NT", + "nt", + "nt0min", + "chan_map", + "hp_filter", + "whiten_mat", + "dshift", + "device", + "do_CAR", + "artifact_threshold", + "invert_sign", + "dtype", + "tmin", + "tmax", + "shift", + "scale", + "file_object", ] - if parse(version("kilosort")) >= parse("4.0.11"): - expected_arguments.pop(-1) - expected_arguments.extend(["shift", "scale", "file_object"]) - - self._check_arguments( - BinaryFiltered, - expected_arguments - ) + self._check_arguments(BinaryFiltered, expected_arguments) def _check_arguments(self, object_, expected_arguments): """ Check that the argument signature of `object_` is as expected - (i..e has not changed across kilosort versions). + (i.e. has not changed across kilosort versions). """ sig = signature(object_) obj_arguments = list(sig.parameters.keys()) @@ -352,7 +325,9 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa kilosort_output_dir = tmp_path / "kilosort_output_dir" spikeinterface_output_dir = tmp_path / "spikeinterface_output_dir" - settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) + settings, run_kilosort_kwargs, ks_format_probe = self._get_kilosort_native_settings( + recording, paths, param_key, param_value + ) kilosort.run_kilosort( settings=settings, @@ -434,15 +409,18 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") - @pytest.mark.parametrize("param_to_test", [ - ("change_nothing", None), - ("do_CAR", False), - ("batch_size", 42743), - ("Th_learned", 14), - ("dmin", 15), - ("max_channel_distance", 5), - ("n_pcs", 3), - ]) + @pytest.mark.parametrize( + "param_to_test", + [ + ("change_nothing", None), + ("do_CAR", False), + ("batch_size", 42743), + ("Th_learned", 14), + ("dmin", 15), + ("max_channel_distance", 5), + ("n_pcs", 3), + ], + ) def test_kilosort4_skip_preprocessing_correction(self, tmp_path, monkeypatch, param_to_test): """ Test that skipping KS4 preprocessing works as expected. Run @@ -498,8 +476,7 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): pass return X - monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", - monkeypatch_filter_function) + monkeypatch.setattr("kilosort.io.BinaryFiltered.filter", monkeypatch_filter_function) ks_settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, param_key, param_value) ks_settings["nblocks"] = 0 @@ -552,15 +529,11 @@ def _check_test_parameters_are_changing_the_output(self, results, default_result return if param_key == "change_nothing": - assert all( - default_results["ks"]["st"] == results["ks"]["st"] - ) and all( + assert all(default_results["ks"]["st"] == results["ks"]["st"]) and all( default_results["ks"]["clus"] == results["ks"]["clus"] ), f"{param_key} changed somehow!." else: - assert not ( - default_results["ks"]["st"].size == results["ks"]["st"].size - ) or not all( + assert not (default_results["ks"]["st"].size == results["ks"]["st"].size) or not all( default_results["ks"]["clus"] == results["ks"]["clus"] ), f"{param_key} results did not change with parameter change." @@ -598,7 +571,7 @@ def _get_spikeinterface_settings(self, param_key, param_value): Generate settings kwargs for running KS4 in SpikeInterface. See `_get_kilosort_native_settings()` for some details. """ - settings = {} # copy.deepcopy(DEFAULT_SETTINGS) + settings = {} # copy.deepcopy(DEFAULT_SETTINGS) if param_key != "change_nothing": settings.update({param_key: param_value}) @@ -606,7 +579,7 @@ def _get_spikeinterface_settings(self, param_key, param_value): if param_key == "binning_depth": settings.update({"nblocks": 5}) - # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: # settings.pop(name) return settings diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 6c58c76813..b8930c8ccc 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -14,16 +14,17 @@ jobs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: 3.12 - name: Install dependencies run: | pip install requests packaging + pip install . - name: Fetch package versions from PyPI run: | @@ -47,7 +48,7 @@ jobs: ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v5 diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 4b7d0dbe6e..d4f0a26b3c 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -108,8 +108,8 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", - "use_binary_file": "If True, the Kilosort is run from a binary file. In this case, if the recording is not binary it is written to a binary file in the output folder" - "If False, the Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. Default is False.", + "use_binary_file": "If True and the recording is not binary compatible, then Kilosort is written to a binary file in the output folder. If False, the Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. " + "If the recording is binary compatible, then the sorter will always use the binary file. Default is False.", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -225,20 +225,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe = load_probe(probe_path=probe_filename) probe_name = "" - if params["use_binary_file"]: - if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): - # no copy - binary_description = recording.get_binary_description() - filename = str(binary_description["file_paths"][0]) - file_object = None - else: - # a local copy has been written - filename = str(sorter_output_folder / "recording.dat") - file_object = None + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + elif params["use_binary_file"]: + # a local copy has been written + filename = str(sorter_output_folder / "recording.dat") + file_object = None else: - # this internally concatenates the recording + # the recording is not binary compatible and no binary copy has been written. + # in this case, we use the RecordingExtractorAsArray object filename = "" file_object = RecordingExtractorAsArray(recording_extractor=recording) + data_dtype = recording.get_dtype() do_CAR = params["do_CAR"] @@ -346,21 +347,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - _ = save_sorting( - ops=ops, - results_dir=results_dir, - st=st, - clu=clu, - tF=tF, - Wall=Wall, - imin=bfile.imin, - tic0=tic0, - save_extra_vars=save_extra_vars, - save_preprocessed_copy=save_preprocessed_copy, - ) - else: - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + _ = save_sorting( + ops=ops, + results_dir=results_dir, + st=st, + clu=clu, + tF=tF, + Wall=Wall, + imin=bfile.imin, + tic0=tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy, + ) @classmethod def _get_result_from_folder(cls, sorter_output_folder): From f9dfa04190eaba160d5beec1c74300aaa2989fcc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 07:35:14 +0200 Subject: [PATCH 06/19] Add highpass_cutoff and fix KS tests --- .github/scripts/test_kilosort4_ci.py | 29 ++++++++++++------- .../sorters/external/kilosort4.py | 4 ++- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index c7853a2add..96a037876f 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -21,19 +21,21 @@ import copy from typing import Any -import spikeinterface.full as si import numpy as np import torch import kilosort from kilosort.io import load_probe import pandas as pd -from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter import pytest -from probeinterface.io import write_prb -from kilosort.parameters import DEFAULT_SETTINGS from packaging.version import parse from importlib.metadata import version from inspect import signature + +import spikeinterface.full as si +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter +from probeinterface.io import write_prb + +from kilosort.parameters import DEFAULT_SETTINGS from kilosort.run_kilosort import ( set_files, initialize_ops, @@ -66,6 +68,7 @@ ("nt", 93), ("nskip", 1), ("whitening_range", 16), + ("highpass_cutoff", 200), ("sig_interp", 5), ("nt0min", 25), ("dmin", 15), @@ -87,10 +90,11 @@ ("ccg_threshold", 1e12), ("acg_threshold", 1e12), ("cluster_downsampling", 2), - ("duplicate_spike_bins", 5), + ("duplicate_spike_ms", 0.3), ("drift_smoothing", [250, 250, 250]), - ("bad_channels", None), ("save_preprocessed_copy", False), + ("shift", 0), + ("scale", 1), ] @@ -194,7 +198,10 @@ def test_params_to_test(self): continue if param_key not in RUN_KILOSORT_ARGS: - assert DEFAULT_SETTINGS[param_key] != param_value, f"{param_key} values should be different in test." + assert DEFAULT_SETTINGS[param_key] != param_value, ( + f"{param_key} values should be different in test: " + f"{param_value} vs. {DEFAULT_SETTINGS[param_key]}" + ) def test_default_settings_all_represented(self): """ @@ -227,7 +234,7 @@ def test_spikeinterface_defaults_against_kilsort(self): # Testing Arguments ### def test_set_files_arguments(self): - self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"]) + self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]) def test_initialize_ops_arguments(self): expected_arguments = [ @@ -249,13 +256,13 @@ def test_compute_preprocessing_arguments(self): self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"]) def test_compute_drift_location_arguments(self): - self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object"]) + self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]) def test_detect_spikes_arguments(self): - self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar"]) + self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) def test_cluster_spikes_arguments(self): - self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"]) + self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) def test_save_sorting_arguments(self): expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index d4f0a26b3c..b0ba054e2d 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -35,6 +35,7 @@ class Kilosort4Sorter(BaseSorter): "artifact_threshold": None, "nskip": 25, "whitening_range": 32, + "highpass_cutoff": 300, "binning_depth": 5, "sig_interp": 20, "drift_smoothing": [0.5, 0.5, 0.5], @@ -55,7 +56,7 @@ class Kilosort4Sorter(BaseSorter): "cluster_downsampling": 20, "cluster_pcs": 64, "x_centers": None, - "duplicate_spike_bins": 7, + "duplicate_spike_ms": 0.25, "do_correction": True, "keep_good_only": False, "save_extra_kwargs": False, @@ -80,6 +81,7 @@ class Kilosort4Sorter(BaseSorter): "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.", "nskip": "Batch stride for computing whitening matrix. Default value: 25.", "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.", + "highpass_cutoff": "High-pass filter cutoff frequency in Hz. Default value: 300.", "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.", "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.", "drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.", From 1964f86b45e9bb96750b65bd84774f0d60595d5b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 08:00:00 +0200 Subject: [PATCH 07/19] test ks4 on ks4 changes --- .github/workflows/test_kilosort4.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index b8930c8ccc..5a8259726e 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -4,6 +4,9 @@ on: workflow_dispatch: schedule: - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC + push: + paths: + - '**/kilosort4.py' jobs: versions: From 8e9995d843ee7b0a43cbd877a6fb01a0c1fb7cb8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 08:02:44 +0200 Subject: [PATCH 08/19] Move testing scripts into scripts folder --- .github/run_tests.sh | 2 +- .github/{ => scripts}/build_job_summary.py | 0 .github/{ => scripts}/determine_testing_environment.py | 0 .github/{ => scripts}/import_test.py | 0 .github/workflows/all-tests.yml | 2 +- .github/workflows/core-test.yml | 2 +- .github/workflows/full-test-with-codecov.yml | 2 +- .github/workflows/test_imports.yml | 4 ++-- 8 files changed, 6 insertions(+), 6 deletions(-) rename .github/{ => scripts}/build_job_summary.py (100%) rename .github/{ => scripts}/determine_testing_environment.py (100%) rename .github/{ => scripts}/import_test.py (100%) diff --git a/.github/run_tests.sh b/.github/run_tests.sh index 558e0b64d3..02eb6ab8a1 100644 --- a/.github/run_tests.sh +++ b/.github/run_tests.sh @@ -10,5 +10,5 @@ fi pytest -m "$MARKER" -vv -ra --durations=0 --durations-min=0.001 | tee report.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of ${MARKER}" >> $GITHUB_STEP_SUMMARY -python $GITHUB_WORKSPACE/.github/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY +python $GITHUB_WORKSPACE/.github/scripts/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY rm report.txt diff --git a/.github/build_job_summary.py b/.github/scripts/build_job_summary.py similarity index 100% rename from .github/build_job_summary.py rename to .github/scripts/build_job_summary.py diff --git a/.github/determine_testing_environment.py b/.github/scripts/determine_testing_environment.py similarity index 100% rename from .github/determine_testing_environment.py rename to .github/scripts/determine_testing_environment.py diff --git a/.github/import_test.py b/.github/scripts/import_test.py similarity index 100% rename from .github/import_test.py rename to .github/scripts/import_test.py diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 8317d7bec4..5b583934ef 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -50,7 +50,7 @@ jobs: shell: bash run: | changed_files="${{ steps.changed-files.outputs.all_changed_files }}" - python .github/determine_testing_environment.py $changed_files + python .github/scripts/determine_testing_environment.py $changed_files - name: Display testing environment shell: bash diff --git a/.github/workflows/core-test.yml b/.github/workflows/core-test.yml index a513d48f3b..1dbf0f5109 100644 --- a/.github/workflows/core-test.yml +++ b/.github/workflows/core-test.yml @@ -39,7 +39,7 @@ jobs: pip install tabulate echo "# Timing profile of core tests in ${{matrix.os}}" >> $GITHUB_STEP_SUMMARY # Outputs markdown summary to standard output - python ./.github/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/build_job_summary.py report.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY rm report.txt shell: bash # Necessary for pipeline to work on windows diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index ab4a083ae1..6a222f5e25 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -47,7 +47,7 @@ jobs: source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY - python ./.github/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY rm report_full.txt - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_imports.yml b/.github/workflows/test_imports.yml index d39fc37242..a2631f6eb7 100644 --- a/.github/workflows/test_imports.yml +++ b/.github/workflows/test_imports.yml @@ -34,7 +34,7 @@ jobs: echo "## OS: ${{ matrix.os }}" >> $GITHUB_STEP_SUMMARY echo "---" >> $GITHUB_STEP_SUMMARY echo "### Import times when only installing only core dependencies " >> $GITHUB_STEP_SUMMARY - python ./.github/import_test.py >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/import_test.py >> $GITHUB_STEP_SUMMARY shell: bash # Necessary for pipeline to work on windows - name: Install in full mode run: | @@ -44,5 +44,5 @@ jobs: # Add a header to separate the two profiles echo "---" >> $GITHUB_STEP_SUMMARY echo "### Import times when installing full dependencies in " >> $GITHUB_STEP_SUMMARY - python ./.github/import_test.py >> $GITHUB_STEP_SUMMARY + python ./.github/scripts/import_test.py >> $GITHUB_STEP_SUMMARY shell: bash # Necessary for pipeline to work on windows From 219bee4ed768b5de1543e5e797d56973e8ac2664 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 08:47:33 +0200 Subject: [PATCH 09/19] change trigger --- .github/workflows/test_kilosort4.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 5a8259726e..42e6140917 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: schedule: - cron: "0 12 * * 0" # Weekly on Sunday at noon UTC - push: + pull_request: paths: - '**/kilosort4.py' From e26e143c0c2ec6af79e509e1d09c6601717b93cf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 08:51:38 +0200 Subject: [PATCH 10/19] Remove last conditions on prior ks versions --- .github/scripts/test_kilosort4_ci.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 96a037876f..009a2c447c 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -213,8 +213,6 @@ def test_default_settings_all_represented(self): for param_key in DEFAULT_SETTINGS: if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: - if parse(version("kilosort")) == parse("4.0.9") and param_key == "nblocks": - continue assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." def test_spikeinterface_defaults_against_kilsort(self): @@ -267,8 +265,7 @@ def test_cluster_spikes_arguments(self): def test_save_sorting_arguments(self): expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] - if parse(version("kilosort")) > parse("4.0.11"): - expected_arguments.append("save_preprocessed_copy") + expected_arguments.append("save_preprocessed_copy") self._check_arguments(save_sorting, expected_arguments) @@ -369,14 +366,9 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa assert ops[param_key] == param_value # Finally, check out test parameters actually change the output of - # KS4, ensuring our tests are actually doing something. This is not - # done prior to 4.0.4 because a number of parameters seem to stop - # having an effect. This is probably due to small changes in their - # behaviour, and the test file chosen here. - if parse(version("kilosort")) > parse("4.0.4"): - self._check_test_parameters_are_changing_the_output(results, default_results, param_key) - - @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") + # KS4, ensuring our tests are actually doing something. + self._check_test_parameters_are_changing_the_output(results, default_results, param_key) + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): """ Test the SpikeInterface wrappers `do_correction` argument. We set @@ -415,7 +407,6 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): assert np.array_equal(results["ks"]["st"], results["si"]["st"]) assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) - @pytest.mark.skipif(parse(version("kilosort")) == parse("4.0.9"), reason="nblock=0 fails on KS4=4.0.9") @pytest.mark.parametrize( "param_to_test", [ From 9c338dd9688455e7595433849399e7ee313301b0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 09:35:19 +0200 Subject: [PATCH 11/19] Fix KS parameters in tests --- .github/scripts/test_kilosort4_ci.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 009a2c447c..0593534010 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -90,11 +90,9 @@ ("ccg_threshold", 1e12), ("acg_threshold", 1e12), ("cluster_downsampling", 2), - ("duplicate_spike_ms", 0.3), ("drift_smoothing", [250, 250, 250]), - ("save_preprocessed_copy", False), - ("shift", 0), - ("scale", 1), + # Not tested beacuse with ground truth data it doesn't change the results + # ("duplicate_spike_ms", 0.3), ] @@ -210,6 +208,8 @@ def test_default_settings_all_represented(self): on the KS side. """ tested_keys = [entry[0] for entry in PARAMS_TO_TEST] + additional_non_tested_keys = ["shift", "scale", "save_preprocessed_copy", "duplicate_spike_ms"] + tested_keys += additional_non_tested_keys for param_key in DEFAULT_SETTINGS: if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: @@ -407,6 +407,11 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): assert np.array_equal(results["ks"]["st"], results["si"]["st"]) assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + + def test_kilosort4_use_binary_file(self, recording_and_paths, tmp_path): + # TODO + pass + @pytest.mark.parametrize( "param_to_test", [ From 87fbe55a7118b8e741ad9b801cba50f11cdc379a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 10:47:01 +0200 Subject: [PATCH 12/19] More cleanup of KS4 tests --- .github/scripts/test_kilosort4_ci.py | 265 +++++++++++++-------------- 1 file changed, 128 insertions(+), 137 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 0593534010..35946a5a56 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -10,31 +10,27 @@ changes when skipping KS4 preprocessing is true, because this takes a slightly different path through the kilosort4.py wrapper logic. This also checks that changing the parameter changes the test output from default - on our test case (otherwise, the test could not detect a failure). This is possible - for nearly all parameters, see `_check_test_parameters_are_changing_the_output()`. + on our test case (otherwise, the test could not detect a failure). - Test that kilosort functions called from `kilosort4.py` wrapper have the expected input signatures - Do some tests to check all KS4 parameters are tested against. """ - +import pytest import copy from typing import Any +from inspect import signature + import numpy as np import torch -import kilosort -from kilosort.io import load_probe -import pandas as pd -import pytest -from packaging.version import parse -from importlib.metadata import version -from inspect import signature import spikeinterface.full as si +from spikeinterface.core.testing import check_sortings_equal from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter from probeinterface.io import write_prb +import kilosort from kilosort.parameters import DEFAULT_SETTINGS from kilosort.run_kilosort import ( set_files, @@ -47,59 +43,62 @@ get_run_parameters, ) from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered -from kilosort.parameters import DEFAULT_SETTINGS -from kilosort import preprocessing as ks_preprocessing + RUN_KILOSORT_ARGS = ["do_CAR", "invert_sign", "save_preprocessed_copy"] # "device", "progress_bar", "save_extra_vars" are not tested. "save_extra_vars" could be. # Setup Params to test #### -PARAMS_TO_TEST = [ - # Not tested - # ("torch_device", "auto") - # Stable across KS version 4.0.16 - 4.0.X (?) - ("change_nothing", None), - ("nblocks", 0), - ("do_CAR", False), - ("batch_size", 42743), - ("Th_universal", 12), - ("Th_learned", 14), - ("invert_sign", True), - ("nt", 93), - ("nskip", 1), - ("whitening_range", 16), - ("highpass_cutoff", 200), - ("sig_interp", 5), - ("nt0min", 25), - ("dmin", 15), - ("dminx", 16), - ("min_template_size", 15), - ("template_sizes", 10), - ("nearest_chans", 8), - ("nearest_templates", 35), - ("max_channel_distance", 5), - ("templates_from_data", False), - ("n_templates", 10), - ("n_pcs", 3), - ("Th_single_ch", 4), - ("x_centers", 5), - ("binning_depth", 1), - # Note: These don't change the results from - # default when applied to the test case. - ("artifact_threshold", 200), - ("ccg_threshold", 1e12), - ("acg_threshold", 1e12), - ("cluster_downsampling", 2), - ("drift_smoothing", [250, 250, 250]), - # Not tested beacuse with ground truth data it doesn't change the results - # ("duplicate_spike_ms", 0.3), +PARAMS_TO_TEST_DICT = { + "nblocks": 0, + "do_CAR": False, + "batch_size": 42743, + "Th_universal": 12, + "Th_learned": 14, + "invert_sign": True, + "nt": 93, + "nskip": 1, + "whitening_range": 16, + "highpass_cutoff": 200, + "sig_interp": 5, + "nt0min": 25, + "dmin": 15, + "dminx": 16, + "min_template_size": 15, + "template_sizes": 10, + "nearest_chans": 8, + "nearest_templates": 35, + "max_channel_distance": 5, + "templates_from_data": False, + "n_templates": 10, + "n_pcs": 3, + "Th_single_ch": 4, + "x_centers": 5, + "binning_depth": 1, + "drift_smoothing": [250, 250, 250], + "artifact_threshold": 200, + "ccg_threshold": 1e12, + "acg_threshold": 1e12, + "cluster_downsampling": 2, + "duplicate_spike_ms": 0.3, +} + +PARAMS_TO_TEST = list(PARAMS_TO_TEST_DICT.keys()) + +PARAMETERS_NOT_AFFECTING_RESULTS = [ + "artifact_threshold", + "ccg_threshold", + "acg_threshold", + "cluster_downsampling", + "cluster_pcs", + "duplicate_spike_ms" # this is because gorund-truth spikes don't have violations ] - +# THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST # if parse(version("kilosort")) >= parse("4.0.X"): -# PARAMS_TO_TEST.extend( +# PARAMS_TO_TEST_DICT.update( # [ -# ("new_param", new_values), +# {"new_param": new_value}, # ] # ) @@ -122,7 +121,7 @@ def recording_and_paths(self, tmp_path_factory): return (recording, paths) @pytest.fixture(scope="session") - def default_results(self, recording_and_paths): + def default_kilosort_sorting(self, recording_and_paths): """ Because we check each parameter at a time and check the KS4 and SpikeInterface versions match, if changing the parameter @@ -133,7 +132,7 @@ def default_results(self, recording_and_paths): """ recording, paths = recording_and_paths - settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, "change_nothing", None) + settings, _, ks_format_probe = self._get_kilosort_native_settings(recording, paths, None, None) defaults_ks_output_dir = paths["session_scope_tmp_path"] / "default_ks_output" @@ -144,9 +143,8 @@ def default_results(self, recording_and_paths): results_dir=defaults_ks_output_dir, ) - default_results = self._get_sorting_output(defaults_ks_output_dir) + return si.read_kilosort(defaults_ks_output_dir) - return default_results def _get_ground_truth_recording(self): """ @@ -185,16 +183,11 @@ def _save_ground_truth_recording(self, recording, tmp_path): # Tests ###### def test_params_to_test(self): """ - Test that all values in PARAMS_TO_TEST are + Test that all values in PARAMS_TO_TEST_DICT are different to the default values used in Kilosort, otherwise there is no point to the test. """ - for parameter in PARAMS_TO_TEST: - param_key, param_value = parameter - - if param_key == "change_nothing": - continue - + for param_key, param_value in PARAMS_TO_TEST_DICT.items(): if param_key not in RUN_KILOSORT_ARGS: assert DEFAULT_SETTINGS[param_key] != param_value, ( f"{param_key} values should be different in test: " @@ -207,8 +200,8 @@ def test_default_settings_all_represented(self): PARAMS_TO_TEST, otherwise we are missing settings added on the KS side. """ - tested_keys = [entry[0] for entry in PARAMS_TO_TEST] - additional_non_tested_keys = ["shift", "scale", "save_preprocessed_copy", "duplicate_spike_ms"] + tested_keys = PARAMS_TO_TEST + additional_non_tested_keys = ["shift", "scale", "save_preprocessed_copy"] tested_keys += additional_non_tested_keys for param_key in DEFAULT_SETTINGS: @@ -315,7 +308,7 @@ def _check_arguments(self, object_, expected_arguments): # Full Test #### @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) - def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, parameter): + def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp_path, parameter): """ Given a recording, paths to raw data, and a parameter to change, run KS4 natively and within the SpikeInterface wrapper with the @@ -323,7 +316,8 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa check the outputs are the same. """ recording, paths = recording_and_paths - param_key, param_value = parameter + param_key = parameter + param_value = PARAMS_TO_TEST_DICT[param_key] # Setup parameters for KS4 and run it natively kilosort_output_dir = tmp_path / "kilosort_output_dir" @@ -340,11 +334,12 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa results_dir=kilosort_output_dir, **run_kilosort_kwargs, ) + sorting_ks4 = si.read_kilosort(kilosort_output_dir) # Setup Parameters for SI and KS4 through SI spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) - si.run_sorter( + sorting_si = si.run_sorter( "kilosort4", recording, remove_existing_folder=True, @@ -353,21 +348,19 @@ def test_kilosort4_main(self, recording_and_paths, default_results, tmp_path, pa ) # Get the results and check they match - results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) - - assert np.array_equal(results["ks"]["st"], results["si"]["st"]), f"{param_key} spike times different" - assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]), f"{param_key} cluster assignment different" + check_sortings_equal(sorting_ks4, sorting_si) # Check the ops file in KS4 output is as expected. This is saved on the # SI side so not an extremely robust addition, but it can't hurt. - if param_key != "change_nothing": - ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) - ops = ops.tolist() # strangely this makes a dict - assert ops[param_key] == param_value + ops = np.load(spikeinterface_output_dir / "sorter_output" / "ops.npy", allow_pickle=True) + ops = ops.tolist() # strangely this makes a dict + assert ops[param_key] == param_value # Finally, check out test parameters actually change the output of - # KS4, ensuring our tests are actually doing something. - self._check_test_parameters_are_changing_the_output(results, default_results, param_key) + # KS4, ensuring our tests are actually doing something (exxcept for some params). + if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS: + with pytest.raises(AssertionError): + check_sortings_equal(default_kilosort_sorting, sorting_si) def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): """ @@ -391,9 +384,10 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): results_dir=kilosort_output_dir, do_CAR=True, ) + sorting_ks = si.read_kilosort(kilosort_output_dir) spikeinterface_settings = self._get_spikeinterface_settings("nblocks", 1) - si.run_sorter( + sorting_si = si.run_sorter( "kilosort4", recording, remove_existing_folder=True, @@ -401,21 +395,46 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): do_correction=False, **spikeinterface_settings, ) + check_sortings_equal(sorting_ks, sorting_si) - results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) - assert np.array_equal(results["ks"]["st"], results["si"]["st"]) - assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + def test_use_binary_file(self, tmp_path): + """ + Test that the SpikeInterface wrapper can run KS4 using a binary file as input or directly + from the recording. + """ + recording = self._get_ground_truth_recording() + recording_bin = recording.save() + # run with SI wrapper + sorting_ks4 = si.run_sorter( + "kilosort4", + recording, + folder = tmp_path / "spikeinterface_output_dir_wrapper", + use_binary_file=False, + remove_existing_folder=True, + ) + sorting_ks4_bin = si.run_sorter( + "kilosort4", + recording_bin, + folder = tmp_path / "spikeinterface_output_dir_bin", + use_binary_file=False, + remove_existing_folder=True, + ) + sorting_ks4_non_bin = si.run_sorter( + "kilosort4", + recording, + folder = tmp_path / "spikeinterface_output_dir_non_bin", + use_binary_file=True, + remove_existing_folder=True, + ) - def test_kilosort4_use_binary_file(self, recording_and_paths, tmp_path): - # TODO - pass + check_sortings_equal(sorting_ks4, sorting_ks4_bin) + check_sortings_equal(sorting_ks4, sorting_ks4_non_bin) @pytest.mark.parametrize( "param_to_test", [ - ("change_nothing", None), ("do_CAR", False), ("batch_size", 42743), ("Th_learned", 14), @@ -496,6 +515,7 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): ) monkeypatch.undo() + si.read_kilosort(kilosort_output_dir) # Now, run kilosort through spikeinterface with the same options. spikeinterface_settings = self._get_spikeinterface_settings(param_key, param_value) @@ -517,29 +537,17 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): # memory file. Because in this test recordings are preprocessed, there are # some filter edge effects that depend on the chunking in `get_traces()`. # These are all extremely close (usually just 1 spike, 1 idx different). - results = self._get_sorting_output(kilosort_output_dir, spikeinterface_output_dir) + results = {} + results["ks"] = {} + results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") + results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") + results["si"] = {} + results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") + results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) - # Helpers ###### - def _check_test_parameters_are_changing_the_output(self, results, default_results, param_key): - """ - If nothing is changed, default vs. results outputs are identical. - Otherwise, check they are not the same. Can't figure out how to get - the skipped three parameters below to change the results on this - small test file. - """ - if param_key in ["acg_threshold", "ccg_threshold", "artifact_threshold", "cluster_downsampling", "cluster_pcs"]: - return - - if param_key == "change_nothing": - assert all(default_results["ks"]["st"] == results["ks"]["st"]) and all( - default_results["ks"]["clus"] == results["ks"]["clus"] - ), f"{param_key} changed somehow!." - else: - assert not (default_results["ks"]["st"].size == results["ks"]["st"].size) or not all( - default_results["ks"]["clus"] == results["ks"]["clus"] - ), f"{param_key} results did not change with parameter change." + ##### Helpers ###### def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): """ Function to generate the settings and function inputs to run kilosort. @@ -554,16 +562,18 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value "n_chan_bin": recording.get_num_channels(), "fs": recording.get_sampling_frequency(), } + run_kilosort_kwargs = {} - if param_key == "binning_depth": - settings.update({"nblocks": 5}) + if param_key is not None: + if param_key == "binning_depth": + settings.update({"nblocks": 5}) - if param_key in RUN_KILOSORT_ARGS: - run_kilosort_kwargs = {param_key: param_value} - else: - if param_key != "change_nothing": - settings.update({param_key: param_value}) - run_kilosort_kwargs = {} + if param_key in RUN_KILOSORT_ARGS: + run_kilosort_kwargs = {param_key: param_value} + else: + if param_key != "change_nothing": + settings.update({param_key: param_value}) + run_kilosort_kwargs = {} ks_format_probe = load_probe(paths["probe_path"]) @@ -576,31 +586,12 @@ def _get_spikeinterface_settings(self, param_key, param_value): """ settings = {} # copy.deepcopy(DEFAULT_SETTINGS) - if param_key != "change_nothing": - settings.update({param_key: param_value}) - if param_key == "binning_depth": settings.update({"nblocks": 5}) + settings.update({param_key: param_value}) + # for name in ["n_chan_bin", "fs", "tmin", "tmax"]: # settings.pop(name) return settings - - def _get_sorting_output(self, kilosort_output_dir=None, spikeinterface_output_dir=None) -> dict[str, Any]: - """ - Load the results of sorting into a dict for easy comparison. - """ - results = { - "si": {}, - "ks": {}, - } - if kilosort_output_dir: - results["ks"]["st"] = np.load(kilosort_output_dir / "spike_times.npy") - results["ks"]["clus"] = np.load(kilosort_output_dir / "spike_clusters.npy") - - if spikeinterface_output_dir: - results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") - results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") - - return results From 10b7e1adc68c51fb784c3529f3638b3ab0d9de3d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 12:20:20 +0200 Subject: [PATCH 13/19] Remove last change_nothing --- .github/scripts/test_kilosort4_ci.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 35946a5a56..dbd8135b9a 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -17,6 +17,7 @@ - Do some tests to check all KS4 parameters are tested against. """ + import pytest import copy from typing import Any @@ -91,7 +92,7 @@ "acg_threshold", "cluster_downsampling", "cluster_pcs", - "duplicate_spike_ms" # this is because gorund-truth spikes don't have violations + "duplicate_spike_ms", # this is because gorund-truth spikes don't have violations ] # THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST @@ -145,7 +146,6 @@ def default_kilosort_sorting(self, recording_and_paths): return si.read_kilosort(defaults_ks_output_dir) - def _get_ground_truth_recording(self): """ A ground truth recording chosen to be as small as possible (for speed). @@ -225,7 +225,9 @@ def test_spikeinterface_defaults_against_kilsort(self): # Testing Arguments ### def test_set_files_arguments(self): - self._check_arguments(set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"]) + self._check_arguments( + set_files, ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir", "bad_channels"] + ) def test_initialize_ops_arguments(self): expected_arguments = [ @@ -247,13 +249,17 @@ def test_compute_preprocessing_arguments(self): self._check_arguments(compute_preprocessing, ["ops", "device", "tic0", "file_object"]) def test_compute_drift_location_arguments(self): - self._check_arguments(compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"]) + self._check_arguments( + compute_drift_correction, ["ops", "device", "tic0", "progress_bar", "file_object", "clear_cache"] + ) def test_detect_spikes_arguments(self): self._check_arguments(detect_spikes, ["ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) def test_cluster_spikes_arguments(self): - self._check_arguments(cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"]) + self._check_arguments( + cluster_spikes, ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar", "clear_cache"] + ) def test_save_sorting_arguments(self): expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] @@ -397,7 +403,6 @@ def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): ) check_sortings_equal(sorting_ks, sorting_si) - def test_use_binary_file(self, tmp_path): """ Test that the SpikeInterface wrapper can run KS4 using a binary file as input or directly @@ -410,21 +415,21 @@ def test_use_binary_file(self, tmp_path): sorting_ks4 = si.run_sorter( "kilosort4", recording, - folder = tmp_path / "spikeinterface_output_dir_wrapper", + folder=tmp_path / "spikeinterface_output_dir_wrapper", use_binary_file=False, remove_existing_folder=True, ) sorting_ks4_bin = si.run_sorter( "kilosort4", recording_bin, - folder = tmp_path / "spikeinterface_output_dir_bin", + folder=tmp_path / "spikeinterface_output_dir_bin", use_binary_file=False, remove_existing_folder=True, ) sorting_ks4_non_bin = si.run_sorter( "kilosort4", recording, - folder = tmp_path / "spikeinterface_output_dir_non_bin", + folder=tmp_path / "spikeinterface_output_dir_non_bin", use_binary_file=True, remove_existing_folder=True, ) @@ -546,7 +551,6 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) - ##### Helpers ###### def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): """ @@ -571,8 +575,7 @@ def _get_kilosort_native_settings(self, recording, paths, param_key, param_value if param_key in RUN_KILOSORT_ARGS: run_kilosort_kwargs = {param_key: param_value} else: - if param_key != "change_nothing": - settings.update({param_key: param_value}) + settings.update({param_key: param_value}) run_kilosort_kwargs = {} ks_format_probe = load_probe(paths["probe_path"]) From 007b64de84ce0ad432d980064e9226d9cc83df39 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Sep 2024 17:02:55 +0200 Subject: [PATCH 14/19] Allow use_binary_file=None (default) and add delete_recording_dat param --- .github/scripts/test_kilosort4_ci.py | 31 +++++++-- .../sorters/external/kilosort4.py | 65 ++++++++++++------- 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index dbd8135b9a..61c10fd8e8 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -415,27 +415,46 @@ def test_use_binary_file(self, tmp_path): sorting_ks4 = si.run_sorter( "kilosort4", recording, - folder=tmp_path / "spikeinterface_output_dir_wrapper", - use_binary_file=False, + folder=tmp_path / "ks4_output_si_wrapper_default", + use_binary_file=None, remove_existing_folder=True, ) sorting_ks4_bin = si.run_sorter( "kilosort4", recording_bin, - folder=tmp_path / "spikeinterface_output_dir_bin", + folder=tmp_path / "ks4_output_bin_default", + use_binary_file=None, + remove_existing_folder=True, + ) + sorting_ks4_force_binary = si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_output_force_bin", + use_binary_file=True, + remove_existing_folder=True, + ) + assert not (tmp_path / "ks4_output_force_bin" / "sorter_output" / "recording.dat").exists() + sorting_ks4_force_non_binary = si.run_sorter( + "kilosort4", + recording_bin, + folder=tmp_path / "ks4_output_force_wrapper", use_binary_file=False, remove_existing_folder=True, ) - sorting_ks4_non_bin = si.run_sorter( + # test deleting recording.dat + sorting_ks4_force_binary_keep = si.run_sorter( "kilosort4", recording, - folder=tmp_path / "spikeinterface_output_dir_non_bin", + folder=tmp_path / "ks4_output_force_bin_keep", use_binary_file=True, + delete_recording_dat=False, remove_existing_folder=True, ) + assert (tmp_path / "ks4_output_force_bin_keep" / "sorter_output" / "recording.dat").exists() check_sortings_equal(sorting_ks4, sorting_ks4_bin) - check_sortings_equal(sorting_ks4, sorting_ks4_non_bin) + check_sortings_equal(sorting_ks4, sorting_ks4_force_binary) + check_sortings_equal(sorting_ks4, sorting_ks4_force_non_binary) @pytest.mark.parametrize( "param_to_test", diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index b0ba054e2d..8a15642af4 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -65,7 +65,8 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": False, "torch_device": "auto", "bad_channels": None, - "use_binary_file": False, + "use_binary_file": None, + "delete_recording_dat": True, } _params_description = { @@ -110,8 +111,10 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", - "use_binary_file": "If True and the recording is not binary compatible, then Kilosort is written to a binary file in the output folder. If False, the Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. " - "If the recording is binary compatible, then the sorter will always use the binary file. Default is False.", + "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binaru compatible, it is written to a binary file in the output folder. " + "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " + "Default is None.", + "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True.", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -172,15 +175,16 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): probe_filename = sorter_output_folder / "probe.prb" write_prb(probe_filename, pg) - if params["use_binary_file"] and not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): - # local copy needed - binary_file_path = sorter_output_folder / "recording.dat" - write_binary_recording( - recording=recording, - file_paths=[binary_file_path], - **get_job_kwargs(params, verbose), - ) - params["filename"] = str(binary_file_path) + if params["use_binary_file"]: + if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # local copy needed + binary_file_path = sorter_output_folder / "recording.dat" + write_binary_recording( + recording=recording, + file_paths=[binary_file_path], + **get_job_kwargs(params, verbose), + ) + params["filename"] = str(binary_file_path) @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): @@ -227,18 +231,30 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe = load_probe(probe_path=probe_filename) probe_name = "" - if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): - # no copy - binary_description = recording.get_binary_description() - filename = str(binary_description["file_paths"][0]) - file_object = None + if params["use_binary_file"] is None: + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # the recording is not binary compatible and no binary copy has been written. + # in this case, we use the RecordingExtractorAsArray object + filename = "" + file_object = RecordingExtractorAsArray(recording_extractor=recording) elif params["use_binary_file"]: - # a local copy has been written - filename = str(sorter_output_folder / "recording.dat") - file_object = None + # here we force the use of a binary file + if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + # no copy + binary_description = recording.get_binary_description() + filename = str(binary_description["file_paths"][0]) + file_object = None + else: + # a local copy has been written + filename = str(sorter_output_folder / "recording.dat") + file_object = None else: - # the recording is not binary compatible and no binary copy has been written. - # in this case, we use the RecordingExtractorAsArray object + # here we force the use of the RecordingExtractorAsArray object filename = "" file_object = RecordingExtractorAsArray(recording_extractor=recording) @@ -362,6 +378,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): save_preprocessed_copy=save_preprocessed_copy, ) + if params["delete_recording_dat"]: + # only delete dat file if it was created by the wrapper + if (sorter_output_folder / "recording.dat").is_file(): + (sorter_output_folder / "recording.dat").unlink() + @classmethod def _get_result_from_folder(cls, sorter_output_folder): return KilosortBase._get_result_from_folder(sorter_output_folder) From f399f6ecc84ba0da3e33f5aefe99fa0248a7a578 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Sep 2024 09:48:40 +0200 Subject: [PATCH 15/19] Update .github/scripts/test_kilosort4_ci.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- .github/scripts/test_kilosort4_ci.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 61c10fd8e8..df4cb64216 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -92,7 +92,7 @@ "acg_threshold", "cluster_downsampling", "cluster_pcs", - "duplicate_spike_ms", # this is because gorund-truth spikes don't have violations + "duplicate_spike_ms", # this is because ground-truth spikes don't have violations ] # THIS IS A PLACEHOLDER FOR FUTURE PARAMS TO TEST From 464c6e3d59531cf8dd4126038c4ff23f9a33b69a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Sep 2024 14:19:22 +0200 Subject: [PATCH 16/19] Update src/spikeinterface/sorters/external/kilosort4.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8a15642af4..183f26d86c 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -111,7 +111,7 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", - "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binaru compatible, it is written to a binary file in the output folder. " + "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " "Default is None.", "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True.", From 0ed4876dfedb175ca08c90d94c8a1d3e215d0586 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Sep 2024 14:25:57 +0200 Subject: [PATCH 17/19] Extend check on clus --- .github/scripts/test_kilosort4_ci.py | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index df4cb64216..1da2f2ba92 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -569,6 +569,7 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): results["si"]["st"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_times.npy") results["si"]["clus"] = np.load(spikeinterface_output_dir / "sorter_output" / "spike_clusters.npy") assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) + assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) ##### Helpers ###### def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): From 8fbf100dfc0be3032a85f02a0cd857a42edea53a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:50:29 +0200 Subject: [PATCH 18/19] Expose clear_Cache argument in KS4 --- .github/scripts/test_kilosort4_ci.py | 24 +++++++++++++++++++ .../sorters/external/kilosort4.py | 23 +++++++++++++++--- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 1da2f2ba92..6eeb71f1dd 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -368,6 +368,30 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp with pytest.raises(AssertionError): check_sortings_equal(default_kilosort_sorting, sorting_si) + def test_clear_cache(self,recording_and_paths, tmp_path): + """ + Test clear_cache parameter in kilosort4.run_kilosort + """ + recording, paths = recording_and_paths + + spikeinterface_output_dir = tmp_path / "spikeinterface_output_clear" + sorting_si_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=True + ) + spikeinterface_output_dir = tmp_path / "spikeinterface_output_no_clear" + sorting_si_no_clear = si.run_sorter( + "kilosort4", + recording, + remove_existing_folder=True, + folder=spikeinterface_output_dir, + clear_cache=False + ) + check_sortings_equal(sorting_si_clear, sorting_si_no_clear) + def test_kilosort4_no_correction(self, recording_and_paths, tmp_path): """ Test the SpikeInterface wrappers `do_correction` argument. We set diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 183f26d86c..4a8c9d1782 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -65,6 +65,7 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": False, "torch_device": "auto", "bad_channels": None, + "clear_cache": False, "use_binary_file": None, "delete_recording_dat": True, } @@ -111,6 +112,7 @@ class Kilosort4Sorter(BaseSorter): "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", + "clear_cache": "If True, force pytorch to free up memory reserved for its cache in between memory-intensive operations. Note that setting `clear_cache=True` is NOT recommended unless you encounter GPU out-of-memory errors, since this can result in slower sorting.", "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " "Default is None.", @@ -284,6 +286,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): data_dir = "" results_dir = sorter_output_folder bad_channels = params["bad_channels"] + clear_cache = params["clear_cache"] filename, data_dir, results_dir, probe = set_files( settings=settings, @@ -347,17 +350,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # this function applies both preprocessing and drift correction ops, bfile, st0 = compute_drift_correction( - ops=ops, device=device, tic0=tic0, progress_bar=progress_bar, file_object=file_object + ops=ops, + device=device, + tic0=tic0, + progress_bar=progress_bar, + file_object=file_object, + clear_cache=clear_cache, ) if save_preprocessed_copy: save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) # Sort spikes and save results - st, tF, _, _ = detect_spikes(ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar) + st, tF, _, _ = detect_spikes( + ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar, clear_cache=clear_cache + ) clu, Wall = cluster_spikes( - st=st, tF=tF, ops=ops, device=device, bfile=bfile, tic0=tic0, progress_bar=progress_bar + st=st, + tF=tF, + ops=ops, + device=device, + bfile=bfile, + tic0=tic0, + progress_bar=progress_bar, + clear_cache=clear_cache, ) if params["skip_kilosort_preprocessing"]: From fd61bb6cd0baa65efd5fde22af95da9c80c9d8cf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 14:58:27 +0200 Subject: [PATCH 19/19] Explicitly add (spikeinterface parameter) to KS4 param description --- .../sorters/external/kilosort4.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 4a8c9d1782..e73ac2cb6c 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -57,15 +57,15 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": 64, "x_centers": None, "duplicate_spike_ms": 0.25, - "do_correction": True, - "keep_good_only": False, - "save_extra_kwargs": False, - "skip_kilosort_preprocessing": False, "scaleproc": None, "save_preprocessed_copy": False, "torch_device": "auto", "bad_channels": None, "clear_cache": False, + "save_extra_vars": False, + "do_correction": True, + "keep_good_only": False, + "skip_kilosort_preprocessing": False, "use_binary_file": None, "delete_recording_dat": True, } @@ -105,18 +105,19 @@ class Kilosort4Sorter(BaseSorter): "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.", "x_centers": "Number of x-positions to use when determining center points for template groupings. If None, this will be determined automatically by finding peaks in channel density. For 2D array type probes, we recommend specifying this so that centers are placed every few hundred microns.", "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 7.", - "do_correction": "If True, drift correction is performed", - "save_extra_kwargs": "If True, additional kwargs are saved to the output", - "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", + "save_extra_vars": "If True, additional kwargs are saved to the output", "scaleproc": "int16 scaling of whitened data, if None set to 200.", - "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", + "save_preprocessed_copy": "Save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", "clear_cache": "If True, force pytorch to free up memory reserved for its cache in between memory-intensive operations. Note that setting `clear_cache=True` is NOT recommended unless you encounter GPU out-of-memory errors, since this can result in slower sorting.", + "do_correction": "If True, drift correction is performed. Default is True. (spikeinterface parameter)", + "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing. (spikeinterface parameter)", + "keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)", "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " - "Default is None.", - "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True.", + "Default is None. (spikeinterface parameter)", + "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. (spikeinterface parameter)", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -264,7 +265,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] - save_extra_vars = params["save_extra_kwargs"] + save_extra_vars = params["save_extra_vars"] save_preprocessed_copy = params["save_preprocessed_copy"] progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS}