
Merge pull request #2175 from DradeAW/fast_vector_to_dict
Allow precomputing spike trains
alejoe91 authored Nov 22, 2023
2 parents 5d7b64e + 41874da commit 029c24a
Showing 6 changed files with 164 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/spikeinterface/core/__init__.py
@@ -94,6 +94,8 @@
get_chunk_with_margin,
order_channels_by_depth,
)
from .sorting_tools import spike_vector_to_spike_trains

from .waveform_tools import extract_waveforms_to_buffers
from .snippets_tools import snippets_from_sorting

35 changes: 32 additions & 3 deletions src/spikeinterface/core/basesorting.py
@@ -6,6 +6,7 @@
import numpy as np

from .base import BaseExtractor, BaseSegment
from .sorting_tools import spike_vector_to_spike_trains
from .waveform_tools import has_exceeding_spikes


@@ -130,9 +131,11 @@ def get_unit_spike_train(
            else:
                spike_frames = self._cached_spike_trains[segment_index][unit_id]
            if start_frame is not None:
-               spike_frames = spike_frames[spike_frames >= start_frame]
+               start = np.searchsorted(spike_frames, start_frame)
+               spike_frames = spike_frames[start:]
            if end_frame is not None:
-               spike_frames = spike_frames[spike_frames < end_frame]
+               end = np.searchsorted(spike_frames, end_frame)
+               spike_frames = spike_frames[:end]
        else:
            segment = self._sorting_segments[segment_index]
            spike_frames = segment.get_unit_spike_train(
@@ -431,7 +434,6 @@ def frame_slice(self, start_frame, end_frame, check_spike_frames=True):
def get_all_spike_trains(self, outputs="unit_id"):
"""
Return all spike trains concatenated.
        This is deprecated and will be removed in spikeinterface 0.102; use sorting.to_spike_vector() instead.
"""

@@ -467,6 +469,33 @@ def get_all_spike_trains(self, outputs="unit_id"):
spikes.append((spike_times, spike_labels))
return spikes

    def precompute_spike_trains(self, from_spike_vector=None):
        """
        Pre-computes and caches all spike trains for this sorting.

        Parameters
        ----------
        from_spike_vector: None | bool, default: None
            If None, then it is automatic depending on whether the spike vector is cached.
            If True, will compute it from the spike vector.
            If False, will call `get_unit_spike_train` for each segment for each unit.
        """
        unit_ids = self.unit_ids

        if from_spike_vector is None:
            # if spike vector is cached then use it
            from_spike_vector = self._cached_spike_vector is not None

        if from_spike_vector:
            self._cached_spike_trains = spike_vector_to_spike_trains(self.to_spike_vector(concatenated=False), unit_ids)

        else:
            for segment_index in range(self.get_num_segments()):
                for unit_id in unit_ids:
                    self.get_unit_spike_train(unit_id, segment_index=segment_index, use_cache=True)

    def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cache=True):
        """
        Construct a unique structured numpy vector concatenating all spikes
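For context, here is a minimal usage sketch of the new caching path (not part of the diff; the toy sorting and its unit ids are illustrative only):

```python
import numpy as np
from spikeinterface.core import NumpySorting

# Toy sorting with two units sampled at 30 kHz (illustrative values only)
sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000)

# Caching the spike vector first lets precompute_spike_trains pick the vectorized path automatically
sorting.to_spike_vector()
sorting.precompute_spike_trains()  # from_spike_vector=None -> uses the cached spike vector

# Subsequent queries are served from the per-unit cache
train = sorting.get_unit_spike_train(unit_id=1, segment_index=0)
```

Leaving `from_spike_vector=None` lets the method choose the fast `spike_vector_to_spike_trains` path whenever the spike vector is already cached, and fall back to per-unit `get_unit_spike_train` calls otherwise.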
13 changes: 8 additions & 5 deletions src/spikeinterface/core/numpyextractors.py
@@ -347,13 +347,16 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
            s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1])
            self.spikes_in_seg = self.spikes[s0:s1]

+       start = 0 if start_frame is None else np.searchsorted(self.spikes_in_seg["sample_index"], start_frame)
+       end = (
+           len(self.spikes_in_seg)
+           if end_frame is None
+           else np.searchsorted(self.spikes_in_seg["sample_index"], end_frame)
+       )

        unit_index = self.unit_ids.index(unit_id)
-       times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index]["sample_index"]
+       times = self.spikes_in_seg[start:end][self.spikes_in_seg[start:end]["unit_index"] == unit_index]["sample_index"]

-       if start_frame is not None:
-           times = times[times >= start_frame]
-       if end_frame is not None:
-           times = times[times < end_frame]
        return times


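The change above replaces per-query boolean masking with a slice whose bounds come from a binary search on the sorted `sample_index` field. A small standalone sketch of the idea (array values are illustrative, not taken from the diff):

```python
import numpy as np

sample_index = np.array([10, 25, 60, 90, 130, 200])  # sorted spike times for one segment

# Old approach: a boolean mask scans the whole array for every query
mask = (sample_index >= 50) & (sample_index < 150)
old_result = sample_index[mask]

# New approach: binary search finds the slice bounds, then slicing returns a view
start, end = np.searchsorted(sample_index, [50, 150])
new_result = sample_index[start:end]

assert np.array_equal(old_result, new_result)  # both give [60, 90, 130]
```

This works because spike vectors are stored in increasing sample order within each segment, which is exactly what `np.searchsorted` requires.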
91 changes: 91 additions & 0 deletions src/spikeinterface/core/sorting_tools.py
@@ -0,0 +1,91 @@
import numpy as np


def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict]:
"""
Computes all spike trains for all units/segments from a spike vector list.
Internally calls numba if numba is installed.
Parameters
----------
spike_vector: list[np.ndarray]
List of spike vectors optained with sorting.to_spike_vector(concatenated=False)
unit_ids: np.array
Unit ids
Returns
-------
spike_trains: dict[dict]:
A dict containing, for each segment, the spike trains of all units
(as a dict: unit_id --> spike_train).
"""

try:
import numba

HAVE_NUMBA = True
except:
HAVE_NUMBA = False

if HAVE_NUMBA:
# the trick here is to have a function getter
vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain()
else:
vector_to_list_of_spiketrain = vector_to_list_of_spiketrain_numpy

num_units = unit_ids.size
spike_trains = {}
for segment_index, spikes in enumerate(spike_vector):
sample_indices = np.array(spikes["sample_index"]).astype(np.int64, copy=False)
unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False)
list_of_spiketrains = vector_to_list_of_spiketrain(sample_indices, unit_indices, num_units)
spike_trains[segment_index] = dict(zip(unit_ids, list_of_spiketrains))

return spike_trains


def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units):
"""
    Slower implementation of vector_to_dict using numpy boolean mask.
This is for one segment.
"""
spike_trains = []
for u in range(num_units):
spike_trains.append(sample_indices[unit_indices == u])
return spike_trains


def get_numba_vector_to_list_of_spiketrain():
if hasattr(get_numba_vector_to_list_of_spiketrain, "_cached_numba_function"):
return get_numba_vector_to_list_of_spiketrain._cached_numba_function

import numba

@numba.jit((numba.int64[::1], numba.int64[::1], numba.int64), nopython=True, nogil=True, cache=True)
def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
"""
Fast implementation of vector_to_dict using numba loop.
This is for one segment.
"""
num_spikes = sample_indices.size
num_spike_per_units = np.zeros(num_units, dtype=np.int32)
for s in range(num_spikes):
num_spike_per_units[unit_indices[s]] += 1

spike_trains = []
for u in range(num_units):
spike_trains.append(np.empty(num_spike_per_units[u], dtype=np.int64))

current_x = np.zeros(num_units, dtype=np.int64)
for s in range(num_spikes):
unit_index = unit_indices[s]
spike_trains[unit_index][current_x[unit_index]] = sample_indices[s]
current_x[unit_index] += 1

return spike_trains

# Cache the compiled function
get_numba_vector_to_list_of_spiketrain._cached_numba_function = vector_to_list_of_spiketrain_numba

return vector_to_list_of_spiketrain_numba
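A quick illustration of how the new helper is intended to be called (it mirrors the unit test added below; the toy data is illustrative):

```python
import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains

sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000)

# One spike vector per segment in, one dict of per-unit spike trains per segment out
spike_vector = sorting.to_spike_vector(concatenated=False)
spike_trains = spike_vector_to_spike_trains(spike_vector, sorting.unit_ids)

print(spike_trains[0][1])  # [  0  51 108]
print(spike_trains[0][5])  # [23 87]
```

Note the function-getter trick: `get_numba_vector_to_list_of_spiketrain` compiles the numba kernel the first time it is needed and stores it as a function attribute, so later calls reuse the compiled version; without numba installed, the numpy fallback keeps the example above working, just more slowly.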
8 changes: 7 additions & 1 deletion src/spikeinterface/core/tests/test_core_tools.py
@@ -1,13 +1,19 @@
import platform
from multiprocessing.shared_memory import SharedMemory
from pathlib import Path
import importlib

import pytest
import numpy as np

-from spikeinterface.core.core_tools import write_binary_recording, write_memory_recording, recursive_path_modifier
+from spikeinterface.core.core_tools import (
+    write_binary_recording,
+    write_memory_recording,
+    recursive_path_modifier,
+)
from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor
from spikeinterface.core.generate import NoiseGeneratorRecording
from spikeinterface.core.numpyextractors import NumpySorting


if hasattr(pytest, "global_test_folder"):
24 changes: 24 additions & 0 deletions src/spikeinterface/core/tests/test_sorting_tools.py
@@ -0,0 +1,24 @@
import importlib
import pytest
import numpy as np

from spikeinterface.core import NumpySorting

from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains


@pytest.mark.skipif(
    importlib.util.find_spec("numba") is None, reason="Testing `spike_vector_to_dict` requires Python package 'numba'."
)
def test_spike_vector_to_spike_trains():
    sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000)
    spike_vector = sorting.to_spike_vector(concatenated=False)
    spike_trains = spike_vector_to_spike_trains(spike_vector, sorting.unit_ids)

    assert len(spike_trains[0]) == sorting.get_num_units()
    for unit_index, unit_id in enumerate(sorting.unit_ids):
        assert np.array_equal(spike_trains[0][unit_id], sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0))


if __name__ == "__main__":
    test_spike_vector_to_spike_trains()
