Allow precomputing spike trains #2175

Merged
merged 42 commits into main from fast_vector_to_dict on Nov 22, 2023
Changes from 39 commits

Commits (42):
0164152  Allow precomputing spike trains (DradeAW, Nov 6, 2023)
a177516  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 6, 2023)
73fe91c  Making sure `numba` is installed (DradeAW, Nov 6, 2023)
4cf076a  Merge branch 'fast_vector_to_dict' of github.com:DradeAW/spikeinterfa… (DradeAW, Nov 6, 2023)
1a6086b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 6, 2023)
46d7c19  oops (DradeAW, Nov 6, 2023)
f1b5086  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 6, 2023)
dcee476  oops (DradeAW, Nov 6, 2023)
e0e9a86  Merge branch 'fast_vector_to_dict' of github.com:DradeAW/spikeinterfa… (DradeAW, Nov 6, 2023)
a9fcabf  Fix crash (DradeAW, Nov 6, 2023)
f52a4c3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 6, 2023)
c40c231  Better precompute spike trains (DradeAW, Nov 6, 2023)
ac6483a  Merge branch 'fast_vector_to_dict' of github.com:DradeAW/spikeinterfa… (DradeAW, Nov 6, 2023)
5b845d8  Nicer assert messages (DradeAW, Nov 6, 2023)
93aa65d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 6, 2023)
ca2f553  Small tweaks (DradeAW, Nov 7, 2023)
1814faf  Heberto's suggestions (DradeAW, Nov 9, 2023)
ab3dbbb  Make NumpySorting more efficient (DradeAW, Nov 9, 2023)
1259fb7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 9, 2023)
6f80180  Added docstring (DradeAW, Nov 9, 2023)
ed566aa  Merge branch 'fast_vector_to_dict' of github.com:DradeAW/spikeinterfa… (DradeAW, Nov 9, 2023)
278209a  Heberto's suggestions (DradeAW, Nov 9, 2023)
506185f  Merge branch 'main' into fast_vector_to_dict (h-mayorquin, Nov 9, 2023)
9d356e9  oops (DradeAW, Nov 9, 2023)
9197015  oops (DradeAW, Nov 9, 2023)
83ad66c  Heberto suggestion (DradeAW, Nov 10, 2023)
7af9a6a  Merge branch 'main' into fast_vector_to_dict (DradeAW, Nov 14, 2023)
adda632  Merge branch 'main' into fast_vector_to_dict (DradeAW, Nov 16, 2023)
a5e0c02  oops (DradeAW, Nov 16, 2023)
f2371c6  Improve numba kernel for spike_vector to spiketrain dict (samuelgarcia, Nov 17, 2023)
ed405bb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 17, 2023)
eba0251  oups (samuelgarcia, Nov 17, 2023)
67ce45c  Merge branch 'fast_vector_to_dict' of github.com:DradeAW/spikeinterfa… (samuelgarcia, Nov 17, 2023)
e17c42b  rename variables (samuelgarcia, Nov 17, 2023)
880d323  remove numba import useless (samuelgarcia, Nov 17, 2023)
32d75c0  Fixed import bug (DradeAW, Nov 17, 2023)
2f3b616  Merge branch 'main' into fast_vector_to_dict (DradeAW, Nov 17, 2023)
7ff37e6  Fixed bug + docstring (DradeAW, Nov 17, 2023)
3f71517  Merge branch 'fast_vector_to_dict' of github.com:DradeAW/spikeinterfa… (DradeAW, Nov 17, 2023)
8950077  Numba code more readable (DradeAW, Nov 20, 2023)
9ec1ed7  Merge branch 'main' into fast_vector_to_dict (DradeAW, Nov 22, 2023)
41874da  Update src/spikeinterface/core/basesorting.py (alejoe91, Nov 22, 2023)
2 changes: 2 additions & 0 deletions src/spikeinterface/core/__init__.py

@@ -94,6 +94,8 @@
      get_chunk_with_margin,
      order_channels_by_depth,
  )
+ from .sorting_tools import spike_vector_to_spike_trains
+
  from .waveform_tools import extract_waveforms_to_buffers
  from .snippets_tools import snippets_from_sorting
35 changes: 32 additions & 3 deletions src/spikeinterface/core/basesorting.py

@@ -6,6 +6,7 @@
  import numpy as np

  from .base import BaseExtractor, BaseSegment
+ from .sorting_tools import spike_vector_to_spike_trains
  from .waveform_tools import has_exceeding_spikes


@@ -130,9 +131,11 @@ def get_unit_spike_train(
              else:
                  spike_frames = self._cached_spike_trains[segment_index][unit_id]
              if start_frame is not None:
-                 spike_frames = spike_frames[spike_frames >= start_frame]
+                 start = np.searchsorted(spike_frames, start_frame)
+                 spike_frames = spike_frames[start:]
              if end_frame is not None:
-                 spike_frames = spike_frames[spike_frames < end_frame]
+                 end = np.searchsorted(spike_frames, end_frame)
+                 spike_frames = spike_frames[:end]
          else:
              segment = self._sorting_segments[segment_index]
              spike_frames = segment.get_unit_spike_train(

@@ -409,7 +412,6 @@ def frame_slice(self, start_frame, end_frame, check_spike_frames=True)
      def get_all_spike_trains(self, outputs="unit_id"):
          """
          Return all spike trains concatenated.
-
          This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead
          """

@@ -445,6 +447,33 @@ def get_all_spike_trains(self, outputs="unit_id")
              spikes.append((spike_times, spike_labels))
          return spikes

+     def precompute_spike_trains(self, from_spike_vector=None):
[Review thread on the method name]

Member: I do not like this method name too much. Maybe cache_all_spike_trains() would be more explicit.

Contributor (Author): To me, I would expect a function cache_all_spike_trains to not perform any computation, but just move the result of a computation to the cache. But I agree that the name could maybe be improved.

"""
Pre-computes and caches all spike trains for this sorting



Parameters
----------
from_spike_vector: None | bool, default: None
If None, then it is automatic dependin
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
If True, will compute it from the spike vector.
If False, will call `get_unit_spike_train` for each segment for each unit.
"""
+         unit_ids = self.unit_ids
+
+         if from_spike_vector is None:
+             # if spike vector is cached then use it
+             from_spike_vector = self._cached_spike_vector is not None
+
+         if from_spike_vector:
+             self._cached_spike_trains = spike_vector_to_spike_trains(self.to_spike_vector(concatenated=False), unit_ids)
+
+         else:
+             for segment_index in range(self.get_num_segments()):
+                 for unit_id in unit_ids:
+                     self.get_unit_spike_train(unit_id, segment_index=segment_index, use_cache=True)

      def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cache=True):
          """
          Construct a unique structured numpy vector concatenating all spikes
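
For context, a minimal usage sketch of the new method. The example data and the `NumpySorting.from_unit_dict` call mirror the test file further below; the timing benefit is only illustrative:

    import numpy as np
    from spikeinterface.core import NumpySorting

    sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000)

    # Cache every spike train up front; subsequent get_unit_spike_train calls
    # are then served from self._cached_spike_trains instead of re-slicing.
    sorting.precompute_spike_trains()

    for unit_id in sorting.unit_ids:
        spike_train = sorting.get_unit_spike_train(unit_id, segment_index=0)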
13 changes: 8 additions & 5 deletions src/spikeinterface/core/numpyextractors.py

@@ -347,13 +347,16 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame)
          s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1])
          self.spikes_in_seg = self.spikes[s0:s1]

+         start = 0 if start_frame is None else np.searchsorted(self.spikes_in_seg["sample_index"], start_frame)
+         end = (
+             len(self.spikes_in_seg)
+             if end_frame is None
+             else np.searchsorted(self.spikes_in_seg["sample_index"], end_frame)
+         )
+
          unit_index = self.unit_ids.index(unit_id)
-         times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index]["sample_index"]
+         times = self.spikes_in_seg[start:end][self.spikes_in_seg[start:end]["unit_index"] == unit_index]["sample_index"]
[Review thread on this line]

Collaborator: This would be easier if we did not use the concatenated version of to_spike_vector.

Contributor (Author): I do not understand?

Collaborator: Unrelated to this PR. But to_spike_vector has a concatenated keyword (to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cache=True)) to return a list of spike vectors per segment instead of all the spike vectors concatenated, so you could avoid the first np.searchsorted, which looks for the segment indexes.

Contributor (Author): Ah, ok! I didn't touch that, but I removed the "disgusting" times[times > start], which doesn't take advantage of the fact that the array is ordered!

Member: "disgusting" is a bit harsh, but I get the feeling. Note that start_frame / end_frame are very often None.

Member: @h-mayorquin Now it uses to_spike_vector(concatenated=False), which is better.

-         if start_frame is not None:
-             times = times[times >= start_frame]
-         if end_frame is not None:
-             times = times[times < end_frame]
          return times
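
The change above replaces boolean masking with np.searchsorted slices, exploiting the fact that sample_index is sorted within a segment. A standalone sketch of the idea, on toy data (not SpikeInterface API):

    import numpy as np

    sample_index = np.array([10, 25, 40, 60, 95])  # sorted spike frames in one segment
    start_frame, end_frame = 20, 90

    # Binary search: O(log n) on sorted data
    start = np.searchsorted(sample_index, start_frame)  # first index with value >= start_frame
    end = np.searchsorted(sample_index, end_frame)      # first index with value >= end_frame
    window = sample_index[start:end]                    # array([25, 40, 60])

    # Equivalent to the O(n) boolean mask it replaces
    mask = (sample_index >= start_frame) & (sample_index < end_frame)
    assert np.array_equal(window, sample_index[mask])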
90 changes: 90 additions & 0 deletions src/spikeinterface/core/sorting_tools.py

@@ -0,0 +1,90 @@
import numpy as np


def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict]:
    """
    Computes all spike trains for all units/segments from a spike vector list.

    Internally calls numba if numba is installed.

    Parameters
    ----------
    spike_vector: list[np.ndarray]
        List of spike vectors obtained with sorting.to_spike_vector(concatenated=False)
    unit_ids: np.array
        Unit ids

    Returns
    -------
    spike_trains: dict[dict]:
        A dict containing, for each segment, the spike trains of all units
        (as a dict: unit_id --> spike_train).
    """

    try:
        import numba

        HAVE_NUMBA = True
    except:
        HAVE_NUMBA = False

    if HAVE_NUMBA:
        # the trick here is to have a function getter
        vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain()
    else:
        vector_to_list_of_spiketrain = vector_to_list_of_spiketrain_numpy
[Review thread]

Collaborator: Here, if you don't have numba, you use the numpy version.

Member: This is exactly the "old" way, of course.
    num_units = unit_ids.size
    spike_trains = {}
    for segment_index, spikes in enumerate(spike_vector):
        sample_indices = np.array(spikes["sample_index"]).astype(np.int64, copy=False)
        unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False)
        list_of_spiketrains = vector_to_list_of_spiketrain(sample_indices, unit_indices, num_units)
        spike_trains[segment_index] = dict(zip(unit_ids, list_of_spiketrains))

    return spike_trains


def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units):
    """
    Slower implementation of vector_to_dict using a numpy boolean mask.
    This is for one segment.
    """
    spike_trains = []
    for u in range(num_units):
        spike_trains.append(sample_indices[unit_indices == u])
    return spike_trains

[Review thread on vector_to_list_of_spiketrain_numpy]

Collaborator (h-mayorquin, Nov 20, 2023): @samuelgarcia @DradeAW If by the "old way" you mean using the numpy version, then that is what is super-slow and I think should be avoided. If by the "old way" you mean just calling get_unit_spike_train for every unit (which is what I meant), then I think it is indeed a bug, because the implementation is calling this function.

Contributor (Author): In my version, it was calling get_unit_spike_train(unit_id). But the function Sam wrote should be doing the same thing NumpySorting does when you call get_unit_spike_train... so it should be the same performance.

Collaborator (h-mayorquin, Nov 20, 2023): If the sorting is a NumpySorting, that would be correct, and it makes sense: NumpySorting is slow at extracting spike trains because, well, the representation is just bad for that. That's what it is. But if your sorting is not as bad as NumpySorting at calling get_unit_spike_train, then the current implementation is terrible without numba. That is, if you have a spike_vector cached in something that is not as bad as NumpySorting at extracting spikes, and you don't have numba, then you will be silently calling a terribly slow function.

Contributor (Author): Ah, I see the problem! You are talking about a (corner) case where get_unit_spike_train is very fast, yet for some reason the spike trains are not cached but the spike vector is. Indeed you are correct: in this case it would be terribly slow.

Collaborator: I take issue with you calling that a corner case. Why are you so confident that building the spike_vector and then masking within the resulting structured array is faster than just calling get_unit_spike_train?

My original appeal was to restrict this to formats that already have the spike_vector pre-calculated, like NumpySorting and, as you were saying, maybe Phy. There, you can be confident that calling this is unlikely to be bad. But then I tested this method with numba, and it works OK even for cases with kind-of-fast I/O, so that seems OK.

But for formats that don't have the spike_vector pre-calculated, this seems like a corner case:

  1. You already called get_unit_spike_train to get your spike_vector (so your data should be cached already, no need for this).
  2. Then, you already called get_unit_spike_train, but somehow it is better to mask a very long vector (generating a copy in memory) and extract the spike train from there.

So again, I think this works great for cases that already have the spike_vector pre-calculated, and it works OK even for other cases (if you have numba). But if you don't have it? Or if you don't have numba? Are you sure that get_unit_spike_train is that bad in general?

Contributor (Author): Ah, I read the function again and realized it does not check whether the spike trains cache is already full (that is in another PR). This should be added.

The corner case I was talking about is when get_unit_spike_train is very fast, yet the spike vector is cached but not the spike trains (since the spike vector is computed from the spike trains, they should be cached).


def get_numba_vector_to_list_of_spiketrain():
if hasattr(get_numba_vector_to_list_of_spiketrain, "_cached_numba_function"):
return get_numba_vector_to_list_of_spiketrain._cached_numba_function

import numba

@numba.jit((numba.int64[::1], numba.int64[::1], numba.int64), nopython=True, nogil=True, cache=True)
def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
"""
Fast implementation of vector_to_dict using numba loop.
This is for one segment.
"""
num_spikes = sample_indices.size
num_spike_per_units = np.zeros(num_units, dtype=np.int32)
for s in range(num_spikes):
num_spike_per_units[unit_indices[s]] += 1

spike_trains = []
for u in range(num_units):
spike_trains.append(np.empty(num_spike_per_units[u], dtype=np.int64))

current_x = np.zeros(num_units, dtype=np.int64)
for s in range(num_spikes):
spike_trains[unit_indices[s]][current_x[unit_indices[s]]] = sample_indices[s]
h-mayorquin marked this conversation as resolved.
Show resolved Hide resolved
current_x[unit_indices[s]] += 1

return spike_trains

# Cache the compiled function
get_numba_vector_to_list_of_spiketrain._cached_numba_function = vector_to_list_of_spiketrain_numba

return vector_to_list_of_spiketrain_numba
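
A short usage sketch of the new helper, mirroring the test below. With numba installed the compiled kernel is used; otherwise the numpy fallback kicks in:

    import numpy as np
    from spikeinterface.core import NumpySorting
    from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains

    sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000)

    # One structured spike vector per segment (fields: sample_index, unit_index, segment_index)
    spike_vector = sorting.to_spike_vector(concatenated=False)

    # Result shape: {segment_index: {unit_id: spike_train}}
    spike_trains = spike_vector_to_spike_trains(spike_vector, sorting.unit_ids)
    print(spike_trains[0][1])  # [  0  51 108]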
8 changes: 7 additions & 1 deletion src/spikeinterface/core/tests/test_core_tools.py

@@ -1,13 +1,19 @@
  import platform
  from multiprocessing.shared_memory import SharedMemory
  from pathlib import Path
+ import importlib

  import pytest
  import numpy as np

- from spikeinterface.core.core_tools import write_binary_recording, write_memory_recording, recursive_path_modifier
+ from spikeinterface.core.core_tools import (
+     write_binary_recording,
+     write_memory_recording,
+     recursive_path_modifier,
+ )
  from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor
  from spikeinterface.core.generate import NoiseGeneratorRecording
+ from spikeinterface.core.numpyextractors import NumpySorting


  if hasattr(pytest, "global_test_folder"):
24 changes: 24 additions & 0 deletions src/spikeinterface/core/tests/test_sorting_tools.py

@@ -0,0 +1,24 @@
import importlib
import pytest
import numpy as np

from spikeinterface.core import NumpySorting

from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains


@pytest.mark.skipif(
    importlib.util.find_spec("numba") is None, reason="Testing `spike_vector_to_dict` requires Python package 'numba'."
)
def test_spike_vector_to_spike_trains():
    sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000)
    spike_vector = sorting.to_spike_vector(concatenated=False)
    spike_trains = spike_vector_to_spike_trains(spike_vector, sorting.unit_ids)

    assert len(spike_trains[0]) == sorting.get_num_units()
    for unit_index, unit_id in enumerate(sorting.unit_ids):
        assert np.array_equal(spike_trains[0][unit_id], sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0))


if __name__ == "__main__":
    test_spike_vector_to_spike_trains()