From 7bd22357cb1f5e27bc7d5e80f587daa00c28836a Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:07:50 +0100 Subject: [PATCH] Added bad channels for kilosort >= 4.0.14 --- src/spikeinterface/sorters/external/kilosort4.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..a78e14f29a 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -57,6 +57,7 @@ class Kilosort4Sorter(BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "torch_device": "auto", + "bad_channels": None, } _params_description = { @@ -99,6 +100,7 @@ class Kilosort4Sorter(BaseSorter): "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "torch_device": "Select the torch device auto/cuda/cpu", + "bad_channels": "A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -205,7 +207,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # NOTE: Also modifies settings in-place data_dir = "" results_dir = sorter_output_folder - filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) + bad_channels = params["bad_channels"] + + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.14"): + filename, data_dir, results_dir, probe = set_files( + settings, filename, probe, probe_name, data_dir, results_dir, bad_channels + ) + else: + filename, data_dir, results_dir, probe = set_files( + settings, filename, probe, probe_name, data_dir, results_dir + ) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = (