Allow precomputing spike trains #2175
Changes from 39 commits
Changed hunk in `get_unit_spike_train`:

```diff
@@ -347,13 +347,16 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
             s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1])
             self.spikes_in_seg = self.spikes[s0:s1]
 
+        start = 0 if start_frame is None else np.searchsorted(self.spikes_in_seg["sample_index"], start_frame)
+        end = (
+            len(self.spikes_in_seg)
+            if end_frame is None
+            else np.searchsorted(self.spikes_in_seg["sample_index"], end_frame)
+        )
+
         unit_index = self.unit_ids.index(unit_id)
-        times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index]["sample_index"]
+        times = self.spikes_in_seg[start:end][self.spikes_in_seg[start:end]["unit_index"] == unit_index]["sample_index"]
 
-        if start_frame is not None:
-            times = times[times >= start_frame]
-        if end_frame is not None:
-            times = times[times < end_frame]
         return times
```

Review thread on the changed `times = ...` line (several comments are truncated in the page capture):

- "This will be easier if we did not use the concatenated version of […]"
- "I do not understand?"
- "Unrelated to this PR. But the […]"
- "Ah ok! I didn't touch that, but I removed the disgusting […]"
- ""disgusting" is a bit hard but I get the feeling. Note that generally this […]"
- "@h-mayorquin […]"
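The change replaces post-hoc boolean filtering with a binary-search window on the sorted `sample_index` field. A minimal sketch of that technique, using made-up toy data (not from the PR), showing the two approaches are equivalent:

```python
import numpy as np

# Toy stand-in for spikes_in_seg["sample_index"]: sample indices are sorted.
sample_index = np.array([10, 25, 40, 60, 95, 120])

start_frame, end_frame = 20, 100

# Binary search gives the half-open window [start, end) in O(log n).
start = np.searchsorted(sample_index, start_frame)
end = np.searchsorted(sample_index, end_frame)

# Slicing the window is equivalent to the old O(n) boolean filtering.
windowed = sample_index[start:end]
assert np.array_equal(
    windowed, sample_index[(sample_index >= start_frame) & (sample_index < end_frame)]
)
print(windowed)  # [25 40 60 95]
```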
New file (`@@ -0,0 +1,90 @@`): `sorting_tools.py`, imported in the test below as `spikeinterface.core.sorting_tools`:

```python
import numpy as np


def spike_vector_to_spike_trains(spike_vector: list[np.ndarray], unit_ids: np.ndarray) -> dict[dict]:
    """
    Computes all spike trains for all units/segments from a spike vector list.

    Internally calls numba if numba is installed.

    Parameters
    ----------
    spike_vector: list[np.ndarray]
        List of spike vectors obtained with sorting.to_spike_vector(concatenated=False)
    unit_ids: np.ndarray
        Unit ids

    Returns
    -------
    spike_trains: dict[dict]:
        A dict containing, for each segment, the spike trains of all units
        (as a dict: unit_id --> spike_train).
    """

    try:
        import numba

        HAVE_NUMBA = True
    except ImportError:
        HAVE_NUMBA = False

    if HAVE_NUMBA:
        # the trick here is to have a function getter
        vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain()
    else:
        vector_to_list_of_spiketrain = vector_to_list_of_spiketrain_numpy

    num_units = unit_ids.size
    spike_trains = {}
    for segment_index, spikes in enumerate(spike_vector):
        sample_indices = np.array(spikes["sample_index"]).astype(np.int64, copy=False)
        unit_indices = np.array(spikes["unit_index"]).astype(np.int64, copy=False)
        list_of_spiketrains = vector_to_list_of_spiketrain(sample_indices, unit_indices, num_units)
        spike_trains[segment_index] = dict(zip(unit_ids, list_of_spiketrains))

    return spike_trains
```

Review thread on the numba/numpy dispatch:

- "Here, if you don't have numba you use the numpy version."
- "This is exactly the 'old' way of course."
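For orientation, each element of `spike_vector` is a structured array with "sample_index" and "unit_index" fields, and the function regroups it per unit. A minimal usage sketch with a hand-built segment (the explicit dtype here is an assumption for illustration; real spike vectors come from `sorting.to_spike_vector(concatenated=False)`):

```python
import numpy as np

# Hypothetical one-segment spike vector, sorted by sample index, mixing two units.
dtype = [("sample_index", "int64"), ("unit_index", "int64")]
segment = np.array([(0, 0), (23, 1), (51, 0), (87, 1), (108, 0)], dtype=dtype)
spike_vector = [segment]  # one structured array per segment

unit_ids = np.array([1, 5])  # unit_index 0 -> unit_id 1, unit_index 1 -> unit_id 5

spike_trains = spike_vector_to_spike_trains(spike_vector, unit_ids)
print(spike_trains[0][1])  # [  0  51 108]
print(spike_trains[0][5])  # [23 87]
```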
The numpy fallback:

```python
def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units):
    """
    Slower implementation of vector_to_dict using numpy boolean mask.
    This is for one segment.
    """
    spike_trains = []
    for u in range(num_units):
        spike_trains.append(sample_indices[unit_indices == u])
    return spike_trains
```

Review thread on the numpy fallback (several comments are truncated in the page capture):

- "If the 'old way' you mean using the numpy version then that is what is super-slow and I think should be avoided. If by the 'old way' you mean just calling […]"
- "In my version, it was calling […] But the function Sam wrote should be doing the same thing the […]"
- "If the sorting is a NumpySorting that would be correct, and it makes sense: NumpySorting is slow at extracting spike trains because, well, the representation is just bad for that. That's what it is. But if your sorting is not as bad as […] That is, if you have a […]"
- "Ah I see the problem! You are speaking about a (corner) case where the […]"
- "I take issue with you calling that […] My original appeal was to restrict this to methods that already have the […] But for functions that don't have […] So again, I think this works great for cases that already have […]"
- "Ah I read again the function, and realized it does not check whether the spike trains cache is already full (this is in another PR). This should be added. The corner case I was talking about is when […]"
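As a side note on the fallback's cost: the boolean-mask version scans all spikes once per unit, so it is O(num_units × num_spikes). A hypothetical middle ground, not part of the PR, groups spikes with one stable argsort instead:

```python
import numpy as np

def vector_to_list_of_spiketrain_argsort(sample_indices, unit_indices, num_units):
    # Stable sort by unit groups each unit's spikes together while keeping
    # their within-unit (time) order intact.
    order = np.argsort(unit_indices, kind="stable")
    sorted_units = unit_indices[order]
    sorted_samples = sample_indices[order]
    # Bucket boundaries for units 0..num_units-1 in the sorted array.
    bounds = np.searchsorted(sorted_units, np.arange(num_units + 1))
    return [sorted_samples[bounds[u]:bounds[u + 1]] for u in range(num_units)]
```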
The numba path, with the compiled kernel memoized on the getter:

```python
def get_numba_vector_to_list_of_spiketrain():
    if hasattr(get_numba_vector_to_list_of_spiketrain, "_cached_numba_function"):
        return get_numba_vector_to_list_of_spiketrain._cached_numba_function

    import numba

    @numba.jit((numba.int64[::1], numba.int64[::1], numba.int64), nopython=True, nogil=True, cache=True)
    def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
        """
        Fast implementation of vector_to_dict using numba loop.
        This is for one segment.
        """
        # First pass: count spikes per unit so each output array can be preallocated.
        num_spikes = sample_indices.size
        num_spike_per_units = np.zeros(num_units, dtype=np.int32)
        for s in range(num_spikes):
            num_spike_per_units[unit_indices[s]] += 1

        spike_trains = []
        for u in range(num_units):
            spike_trains.append(np.empty(num_spike_per_units[u], dtype=np.int64))

        # Second pass: scatter each spike into its unit's array, tracking a
        # per-unit write cursor.
        current_x = np.zeros(num_units, dtype=np.int64)
        for s in range(num_spikes):
            spike_trains[unit_indices[s]][current_x[unit_indices[s]]] = sample_indices[s]
            current_x[unit_indices[s]] += 1

        return spike_trains

    # Cache the compiled function
    get_numba_vector_to_list_of_spiketrain._cached_numba_function = vector_to_list_of_spiketrain_numba

    return vector_to_list_of_spiketrain_numba
```
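The kernel is effectively a counting sort (one O(num_spikes) pass to size the buckets, one to fill them), and the getter stores the compiled function as an attribute on itself so JIT compilation is set up only once per process. A minimal sketch of that memoization pattern in isolation (toy names, not from the PR):

```python
def get_expensive_function():
    # Reuse the previously built function if one was already made.
    if hasattr(get_expensive_function, "_cached"):
        return get_expensive_function._cached

    def expensive_function(x):
        return x * 2  # stand-in for an expensively compiled kernel

    # Plain function attributes act as a process-wide cache.
    get_expensive_function._cached = expensive_function
    return expensive_function
```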
New test file (`@@ -0,0 +1,24 @@`):

```python
import importlib
import pytest
import numpy as np

from spikeinterface.core import NumpySorting

from spikeinterface.core.sorting_tools import spike_vector_to_spike_trains


@pytest.mark.skipif(
    importlib.util.find_spec("numba") is None,
    reason="Testing `spike_vector_to_spike_trains` requires Python package 'numba'.",
)
def test_spike_vector_to_spike_trains():
    sorting = NumpySorting.from_unit_dict({1: np.array([0, 51, 108]), 5: np.array([23, 87])}, 30_000)
    spike_vector = sorting.to_spike_vector(concatenated=False)
    spike_trains = spike_vector_to_spike_trains(spike_vector, sorting.unit_ids)

    assert len(spike_trains[0]) == sorting.get_num_units()
    for unit_index, unit_id in enumerate(sorting.unit_ids):
        assert np.array_equal(spike_trains[0][unit_id], sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0))


if __name__ == "__main__":
    test_spike_vector_to_spike_trains()
```
Review thread on a method name (the method itself is not shown in this capture):

- "I do not like too much this method name. Maybe `cache_all_spike_trains()` would be more explicit."
- "To me, I would expect a function `cache_all_spike_trains` to not perform any computation, but just move the result of a computation to the cache. But I agree that the name could maybe be improved."
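The disagreement is about semantics: does a `cache_*` name promise computation plus caching, or caching only? A purely hypothetical sketch of the two readings (all names and attributes invented for illustration; not the PR's API):

```python
class SortingSketch:
    """Toy stand-in for a sorting object with a spike-train cache."""

    def __init__(self):
        self._cached_spike_trains = None  # hypothetical cache attribute

    def precompute_spike_trains(self):
        # Reading 1: the name suggests computing everything, then storing it.
        self._cached_spike_trains = self._compute_all_spike_trains()

    def cache_all_spike_trains(self, spike_trains):
        # Reading 2 (the objection above): a "cache_*" name suggests no
        # computation, only moving an already-computed result into the cache.
        self._cached_spike_trains = spike_trains

    def _compute_all_spike_trains(self):
        return {}  # placeholder for the actual computation
```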