diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0c92eab0e4..2853a4fc55 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1094,6 +1094,8 @@ def get_traces( ) -> np.ndarray: start_frame = 0 if start_frame is None else max(start_frame, 0) end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) + start_frame = int(start_frame) + end_frame = int(end_frame) start_frame_within_block = start_frame % self.noise_block_size end_frame_within_block = end_frame % self.noise_block_size @@ -1652,6 +1654,8 @@ def get_traces( ) -> np.ndarray: start_frame = 0 if start_frame is None else start_frame end_frame = self.num_samples if end_frame is None else end_frame + start_frame = int(start_frame) + end_frame = int(end_frame) if channel_indices is None: n_channels = self.templates.shape[2] @@ -1688,6 +1692,8 @@ def get_traces( end_traces = start_traces + template.shape[0] if start_traces >= end_frame - start_frame or end_traces <= 0: continue + start_traces = int(start_traces) + end_traces = int(end_traces) start_template = 0 end_template = template.shape[0] diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 47846f10ce..a7f40a9558 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Union +from packaging import version from ..basesorter import BaseSorter from .kilosortbase import KilosortBase @@ -24,11 +25,14 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": True, "invert_sign": False, "nt": 61, + "shift": None, + "scale": None, "artifact_threshold": None, "nskip": 25, "whitening_range": 32, "binning_depth": 5, "sig_interp": 20, + "drift_smoothing": [0.5, 0.5, 0.5], "nt0min": None, "dmin": None, "dminx": 32, @@ -63,11 +67,14 @@ class Kilosort4Sorter(BaseSorter): "do_CAR": "Whether to perform common average reference. Default value: True.", "invert_sign": "Invert the sign of the data. Default value: False.", "nt": "Number of samples per waveform. Also size of symmetric padding for filtering. Default value: 61.", + "shift": "Scalar shift to apply to data before all other operations. Default None.", + "scale": "Scaling factor to apply to data before all other operations. Default None.", "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.", "nskip": "Batch stride for computing whitening matrix. Default value: 25.", "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.", "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.", "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.", + "drift_smoothing": "Amount of gaussian smoothing to apply to the spatiotemporal drift estimation, for x,y,time axes in units of registration blocks (for x,y axes) and batch size (for time axis). The x,y smoothing has no effect for `nblocks = 1`.", "nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.", "dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.", "dminx": "Horizontal spacing of template centers used for spike detection, in microns. Default value: 32.", @@ -153,6 +160,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): import torch import numpy as np + if verbose: + import logging + + logging.basicConfig(level=logging.INFO) + sorter_output_folder = sorter_output_folder.absolute() probe_filename = sorter_output_folder / "probe.prb" @@ -194,11 +206,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): data_dir = "" results_dir = sorter_output_folder filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir) - ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + if version.parse(cls.get_sorter_version()) >= version.parse("4.0.12"): + ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device, False) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact, _, _ = ( + get_run_parameters(ops) + ) + else: + ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device) + n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( + get_run_parameters(ops) + ) - n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = ( - get_run_parameters(ops) - ) # Set preprocessing and drift correction parameters if not params["skip_kilosort_preprocessing"]: ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)