From 4336f0d6d3ef5595a45242cab6985fa2e638fe3f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 29 Jul 2024 19:24:28 +0100 Subject: [PATCH 1/4] Expose 'save_preprocessed_copy' in KS4 wrapper. --- src/spikeinterface/sorters/external/kilosort4.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..a904866629 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -56,6 +56,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": False, "skip_kilosort_preprocessing": False, "scaleproc": None, + "save_preprocessed_copy": False, "torch_device": "auto", } @@ -98,6 +99,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", + "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory", "torch_device": "Select the torch device auto/cuda/cpu", } @@ -186,6 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] save_extra_vars = params["save_extra_kwargs"] + save_preprocessed_copy = (params["save_preprocessed_copy"],) progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} settings_ks["n_chan_bin"] = recording.get_num_channels() @@ -207,7 +210,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): results_dir = sorter_output_folder filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + ops = initialize_ops( + settings, + probe, + recording.get_dtype(), + do_CAR, + invert_sign, + device, + save_preprocesed_copy=save_preprocessed_copy, + ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) ) From 20b4d2fcf95171ea5304339c689170e52429d4fd Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 31 Jul 2024 15:04:57 +0100 Subject: [PATCH 2/4] Edit kilosort4.py to match the ks4 'run_sorter' function body. --- .../sorters/external/kilosort4.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a904866629..16918128a2 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -155,7 +155,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): save_sorting, get_run_parameters, ) - from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered + from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered, save_preprocessing from kilosort.parameters import DEFAULT_SETTINGS import time @@ -188,7 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR = params["do_CAR"] invert_sign = params["invert_sign"] save_extra_vars = params["save_extra_kwargs"] - save_preprocessed_copy = (params["save_preprocessed_copy"],) + save_preprocessed_copy = params["save_preprocessed_copy"] progress_bar = None settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS} settings_ks["n_chan_bin"] = recording.get_num_channels() @@ -268,6 +268,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object ) + if save_preprocessed_copy: + save_preprocessing(results_dir / "temp_wh.dat", ops, bfile) + # Sort spikes and save results st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar) clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar) @@ -276,7 +279,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels())) ) - _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + _ = save_sorting( + ops, + results_dir, + st, + clu, + tF, + Wall, + bfile.imin, + tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy, + ) + else: + _ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars) @classmethod def _get_result_from_folder(cls, sorter_output_folder): From c320d6c09761ea673a5c24e06ea55622997f4d9f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 31 Jul 2024 15:18:09 +0100 Subject: [PATCH 3/4] Add clarification on typo. --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 16918128a2..250c2865f9 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -217,7 +217,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): do_CAR, invert_sign, device, - save_preprocesed_copy=save_preprocessed_copy, + save_preprocesed_copy=save_preprocessed_copy, # this kwarg is correct (typo) ) n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( get_run_parameters(ops) From e51088ab0e5f56596a78fc4cfd4e9a6d50f71414 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 31 Jul 2024 15:20:32 +0100 Subject: [PATCH 4/4] Extend param description. --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 250c2865f9..6d83249653 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -99,7 +99,7 @@ class Kilosort4Sorter(BaseSorter): "save_extra_kwargs": "If True, additional kwargs are saved to the output", "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", - "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory", + "save_preprocessed_copy": "save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data", "torch_device": "Select the torch device auto/cuda/cpu", }