Skip to content

Commit

Permalink
Merge pull request #1958 from zm711/isi-refactor
Browse files Browse the repository at this point in the history
Refactor ISI calculation numpy and numba
  • Loading branch information
alejoe91 authored Oct 26, 2023
2 parents 67869c5 + 3f82c59 commit 503d7c8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 75 deletions.
1 change: 0 additions & 1 deletion src/spikeinterface/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

from .isi import (
ISIHistogramsCalculator,
compute_isi_histograms_from_spiketrain,
compute_isi_histograms,
compute_isi_histograms_numpy,
compute_isi_histograms_numba,
Expand Down
88 changes: 14 additions & 74 deletions src/spikeinterface/postprocessing/isi.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,61 +65,6 @@ def get_extension_function():
WaveformExtractor.register_extension(ISIHistogramsCalculator)


def compute_isi_histograms_from_spiketrain(spike_train: np.ndarray, max_time: int, bin_size: int, sampling_f: float):
    """
    Computes the Inter-Spike Intervals histogram from a given spike train.

    This implementation only works if you have numba installed, to accelerate the
    computation time.

    Parameters
    ----------
    spike_train: np.ndarray
        The ordered spike train to compute the ISI.
    max_time: int
        Compute the ISI from 0 to +max_time (in sampling time).
    bin_size: int
        Size of a bin (in sampling time).
    sampling_f: float
        Sampling rate/frequency (in Hz).

    Returns
    -------
    tuple (ISI, bins)
        ISI: np.ndarray[int64]
            The computed ISI histogram.
        bins: np.ndarray[float64]
            The bins for the ISI histogram.

    Raises
    ------
    ImportError
        If numba is not installed; the accelerated kernel cannot run without it.
    """
    if not HAVE_NUMBA:
        # Fail loudly instead of printing and returning the scalar 0: callers
        # expect an (ISI, bins) tuple, so the old `return 0` only deferred the
        # error to a confusing unpacking failure downstream.
        raise ImportError("compute_isi_histograms_from_spiketrain requires numba.")

    # The numba kernel is compiled for an int64 spike train; cast defensively.
    return _compute_isi_histograms_from_spiketrain(spike_train.astype(np.int64), max_time, bin_size, sampling_f)


if HAVE_NUMBA:

    # Numba kernel backing compute_isi_histograms_from_spiketrain.
    # The explicit eager signature (contiguous int64 spike train, int32
    # max_time/bin_size, float32 sampling frequency) compiles the function
    # at definition time; nopython/nogil/cache keep it fast across runs.
    @numba.jit((numba.int64[::1], numba.int32, numba.int32, numba.float32), nopython=True, nogil=True, cache=True)
    def _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f):
        # Number of histogram bins covering [0, max_time) in sample units.
        n_bins = int(max_time / bin_size)

        # Bin edges converted from samples to milliseconds (n_bins + 1 edges).
        bins = np.arange(0, max_time + bin_size, bin_size) * 1e3 / sampling_f
        ISI = np.zeros(n_bins, dtype=np.int64)

        # Accumulate each consecutive-spike interval into its bin; intervals
        # >= max_time fall outside the histogram range and are skipped.
        for i in range(1, len(spike_train)):
            diff = spike_train[i] - spike_train[i - 1]

            if diff >= max_time:
                continue

            bin = int(diff / bin_size)
            ISI[bin] += 1

        return ISI, bins


def compute_isi_histograms(
waveform_or_sorting_extractor,
load_if_exists=False,
Expand All @@ -140,7 +85,7 @@ def compute_isi_histograms(
bin_ms : float, optional
The bin size in ms, by default 1.0.
method : str, optional
"auto" | "numpy" | "numba". If _auto" and numba is installed, numba is used, by default "auto"
"auto" | "numpy" | "numba". If "auto" and numba is installed, numba is used, by default "auto"
Returns
-------
Expand Down Expand Up @@ -191,21 +136,19 @@ def compute_isi_histograms_numpy(sorting, window_ms: float = 50.0, bin_ms: float
"""
fs = sorting.get_sampling_frequency()
num_units = len(sorting.unit_ids)

assert bin_ms * 1e-3 >= 1 / fs, f"bin size must be larger than the sampling period {1e3 / fs}"
assert bin_ms <= window_ms
window_size = int(round(fs * window_ms * 1e-3))
bin_size = int(round(fs * bin_ms * 1e-3))
window_size -= window_size % bin_size
num_bins = int(window_size / bin_size)
assert num_bins >= 1

ISIs = np.zeros((num_units, num_bins), dtype=np.int64)
bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs
ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64)

# TODO: There might be a better way than a double for loop?
for i, unit_id in enumerate(sorting.unit_ids):
for seg_index in range(sorting.get_num_segments()):
spike_train = sorting.get_unit_spike_train(unit_id, segment_index=seg_index)
ISI = np.histogram(np.diff(spike_train), bins=num_bins, range=(0, window_size - 1))[0]
ISI = np.histogram(np.diff(spike_train), bins=bins)[0]
ISIs[i] += ISI

return ISIs, bins
Expand All @@ -224,18 +167,18 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float

assert HAVE_NUMBA
fs = sorting.get_sampling_frequency()
assert bin_ms * 1e-3 >= 1 / fs, f"the bin_ms must be larger than the sampling period: {1e3 / fs}"
assert bin_ms <= window_ms
num_units = len(sorting.unit_ids)

window_size = int(round(fs * window_ms * 1e-3))
bin_size = int(round(fs * bin_ms * 1e-3))
window_size -= window_size % bin_size
num_bins = int(window_size / bin_size)
assert num_bins >= 1

bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs
spikes = sorting.to_spike_vector(concatenated=False)

ISIs = np.zeros((num_units, num_bins), dtype=np.int64)
ISIs = np.zeros((num_units, len(bins) - 1), dtype=np.int64)

for seg_index in range(sorting.get_num_segments()):
spike_times = spikes[seg_index]["sample_index"].astype(np.int64)
Expand All @@ -245,9 +188,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float
ISIs,
spike_times,
spike_labels,
window_size,
bin_size,
fs,
bins,
)

return ISIs, bins
Expand All @@ -256,16 +197,15 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float
if HAVE_NUMBA:

@numba.jit(
(numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.int32, numba.int32, numba.float32),
(numba.int64[:, ::1], numba.int64[::1], numba.int32[::1], numba.float64[::1]),
nopython=True,
nogil=True,
cache=True,
parallel=True,
)
def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, max_time, bin_size, sampling_f):
def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins):
n_units = ISIs.shape[0]

for i in numba.prange(n_units):
units_loop = numba.prange(n_units) if n_units > 300 else range(n_units)
for i in units_loop:
spike_train = spike_trains[spike_clusters == i]

ISIs[i] += _compute_isi_histograms_from_spiketrain(spike_train, max_time, bin_size, sampling_f)[0]
ISIs[i] += np.histogram(np.diff(spike_train), bins=bins)[0]

0 comments on commit 503d7c8

Please sign in to comment.