diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index d7e1ffac01..33e0ff6c03 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -38,7 +38,6 @@ from .isi import ( ISIHistogramsCalculator, - compute_isi_histograms_from_spiketrain, compute_isi_histograms, compute_isi_histograms_numpy, compute_isi_histograms_numba, diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index aec70141cf..e98e64f753 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -65,61 +65,6 @@ def get_extension_function(): WaveformExtractor.register_extension(ISIHistogramsCalculator) -def compute_isi_histograms_from_spiketrain(spike_train: np.ndarray, max_time: int, bin_size: int, sampling_f: float): - """ - Computes the Inter-Spike Intervals histogram from a given spike train. - - This implementation only works if you have numba installed, to accelerate the - computation time. - - Parameters - ---------- - spike_train: np.ndarray - The ordered spike train to compute the ISI. - max_time: int - Compute the ISI from 0 to +max_time (in sampling time). - bin_size: int - Size of a bin (in sampling time). - sampling_f: float - Sampling rate/frequency (in Hz). - - Returns - ------- - tuple (ISI, bins) - ISI: np.ndarray[int64] - The computed ISI histogram. - bins: np.ndarray[float64] - The bins for the ISI histogram. - """ - if not HAVE_NUMBA: - print("Error: numba is not installed.") - print("compute_ISI_from_spiketrain cannot run without numba.") - return 0 - - return _compute_isi_histograms_from_spiketrain(spike_train.astype(np.int64), max_time, bin_size, sampling_f) - - -if HAVE_NUMBA: - - @numba.jit((numba.int64[::1], numba.int32, numba.int32, numba.float32), nopython=True, nogil=True, cache=True) - def _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f): - n_bins = int(max_time / bin_size) - - bins = np.arange(0, max_time + bin_size, bin_size) * 1e3 / sampling_f - ISI = np.zeros(n_bins, dtype=np.int64) - - for i in range(1, len(spike_train)): - diff = spike_train[i] - spike_train[i - 1] - - if diff >= max_time: - continue - - bin = int(diff / bin_size) - ISI[bin] += 1 - - return ISI, bins - - def compute_isi_histograms( waveform_or_sorting_extractor, load_if_exists=False, @@ -140,7 +85,7 @@ def compute_isi_histograms( bin_ms : float, optional The bin size in ms, by default 1.0. method : str, optional - "auto" | "numpy" | "numba". If _auto" and numba is installed, numba is used, by default "auto" + "auto" | "numpy" | "numba". If "auto" and numba is installed, numba is used, by default "auto" Returns ------- @@ -191,21 +136,19 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float """ fs = sorting.get_sampling_frequency() num_units = len(sorting.unit_ids) - + assert bin_ms * 1e-3 >= 1 / fs, f"bin size must be larger than the sampling period {1e3 / fs}" + assert bin_ms <= window_ms window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - num_bins = int(window_size / bin_size) - assert num_bins >= 1 - - ISIs = np.zeros((num_units, num_bins), dtype=np.int64) bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs + ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) # TODO: There might be a better way than a double for loop? for i, unit_id in enumerate(sorting.unit_ids): for seg_index in range(sorting.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id, segment_index=seg_index) - ISI = np.histogram(np.diff(spike_train), bins=num_bins, range=(0, window_size - 1))[0] + ISI = np.histogram(np.diff(spike_train), bins=bins)[0] ISIs[i] += ISI return ISIs, bins @@ -224,18 +167,18 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float assert HAVE_NUMBA fs = sorting.get_sampling_frequency() + assert bin_ms * 1e-3 >= 1 / fs, f"the bin_ms must be larger than the sampling period: {1e3 / fs}" + assert bin_ms <= window_ms num_units = len(sorting.unit_ids) window_size = int(round(fs * window_ms * 1e-3)) bin_size = int(round(fs * bin_ms * 1e-3)) window_size -= window_size % bin_size - num_bins = int(window_size / bin_size) - assert num_bins >= 1 bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs spikes = sorting.to_spike_vector(concatenated=False) - ISIs = np.zeros((num_units, num_bins), dtype=np.int64) + ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"].astype(np.int64) @@ -245,9 +188,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float ISIs, spike_times, spike_labels, - window_size, - bin_size, - fs, + bins, ) return ISIs, bins @@ -256,16 +197,15 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float if HAVE_NUMBA: @numba.jit( - (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.int32, numba.int32, numba.float32), + (numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.float64[::1]), nopython=True, nogil=True, cache=True, - parallel=True, ) - def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, max_time, bin_size, sampling_f): + def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins): n_units = ISIs.shape[0] - for i in numba.prange(n_units): + units_loop = numba.prange(n_units) if n_units > 300 else range(n_units) + for i in units_loop: spike_train = spike_trains[spike_clusters == i] - - ISIs[i] += _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f)[0] + ISIs[i] += np.histogram(np.diff(spike_train), bins=bins)[0]