Allow precomputing spike trains #2175
Conversation
Wow, to my surprise this is not faster!

```python
import time
import spikeinterface.core as si

t1 = time.perf_counter()
sorting = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")
for unit_id in sorting.unit_ids:
    sorting.get_unit_spike_train(unit_id)
t2 = time.perf_counter()

sorting = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")
sorting.precompute_spike_trains()
for unit_id in sorting.unit_ids:
    sorting.get_unit_spike_train(unit_id)
t3 = time.perf_counter()

print(f"Old way: {t2-t1:.1f} s")
print(f"New way: {t3-t2:.1f} s")
```
Ok, this is way better with the latest commit!

```python
import time
import numpy as np
import spikeinterface.core as si

sorting1 = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")
sorting2 = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")

t1 = time.perf_counter()
for unit_id in sorting1.unit_ids:
    sorting1.get_unit_spike_train(unit_id)
t2 = time.perf_counter()

sorting2.precompute_spike_trains()
for unit_id in sorting2.unit_ids:
    sorting2.get_unit_spike_train(unit_id)
t3 = time.perf_counter()

print(f"Old way: {t2-t1:.1f} s")
print(f"New way: {t3-t2:.1f} s")

for unit_id in sorting1.unit_ids:
    assert np.all(sorting1.get_unit_spike_train(unit_id) == sorting2.get_unit_spike_train(unit_id))
```
For reference, my sorting dataset contains 531 units for a total of 11,298,565 spikes (1h recording).
Just a question and suggestion @DradeAW
```
@@ -424,6 +427,31 @@ def get_all_spike_trains(self, outputs="unit_id"):
            spikes.append((spike_times, spike_labels))
        return spikes

    def precompute_spike_trains(self, from_spike_vector: bool = True):
```
Since this is in core, where numba is not automatically installed, wouldn't it be safer to have `from_spike_vector` default to `False`, so that a user who has only installed core doesn't instantly hit an assert error?
I added the parameter to be complete, but I actually don't know when it would be interesting to precompute spike trains not from the spike vector.
Can you think of a use case?
Not off the top of my head. But I'll think about it :)
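One way to handle the core-only concern raised above would be to resolve the default at call time depending on whether numba is importable. This is only a hypothetical sketch (the helper name and logic are mine, not part of the PR):

```python
# Hypothetical helper, not the PR's code: use the spike-vector path only when
# numba is available, so a core-only install never hits an assert error.
def _resolve_from_spike_vector(from_spike_vector=None):
    if from_spike_vector is None:
        try:
            import numba  # noqa: F401
            return True
        except ImportError:
            return False
    return from_spike_vector
```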
Hi Aurelien. I will not have time to give feedback on this very soon.
@samuelgarcia Would you prefer to split this into two functions with no parameters?
I removed the need for numba (there is now a fallback when it's not installed).
Some comments added.
Two big questions:
1. How does this interact with the caching option in `to_spike_vector` and `get_unit_spike_train`? Isn't it duplicating behavior?
2. How would this play out with multiprocessing? All this pre-computing is lost, isn't it?
1. It overrides it (forces a recompute). It isn't duplicating behavior, because it computes all spike trains at once rather than one unit at a time, which, as I showed above, is much faster.
2. I'm not sure, but I believe that if you get the spike train it should find it in the cache in all cases?
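The cache-first behaviour described here can be pictured with the small sketch below; the names are purely illustrative and do not reflect spikeinterface's actual internals:

```python
# Illustrative only: a dict acts as the spike-train cache. precompute_spike_trains
# would fill it in one pass; later lookups return from it instead of recomputing.
def get_unit_spike_train_cached(cache: dict, unit_id, compute_one_unit):
    if unit_id not in cache:
        cache[unit_id] = compute_one_unit(unit_id)
    return cache[unit_id]
```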
So are the functions introduced here doing a better job at calculating the …
```
@@ -449,6 +451,33 @@ def get_all_spike_trains(self, outputs="unit_id"):
            spikes.append((spike_times, spike_labels))
        return spikes

    def precompute_spike_trains(self, from_spike_vector=None):
```
I do not like this method name too much. Maybe `cache_all_spike_trains()` would be more explicit.
To me, I would expect a function `cache_all_spike_trains` to not perform any computation, but just move the result of a computation to the cache. But I agree that the name could maybe be improved.
There are probably a few places in …
Extract waveforms is based on the spike vector.
Latest benchmark: even faster, ~200x faster on a very big sorting, and ~35x faster for a sorting I typically use.
Thanks @samuelgarcia. I introduced some machinery to test this: I want to decouple the performance from the MEArec memory model and layout limitations. I will give it a test as soon as I can.
I found no difference (as I expected):

```python
import time
import numpy as np
import spikeinterface.core as si
from spikeinterface.core.generate import SortingGenerator

num_units = 1000
durations = [10 * 60 * 60.0]
seed = 25

sorting1 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)
sorting2 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)

t1 = time.perf_counter()
for unit_id in sorting1.unit_ids:
    sorting1.get_unit_spike_train(unit_id)
t2 = time.perf_counter()

sorting2.precompute_spike_trains()
for unit_id in sorting2.unit_ids:
    sorting2.get_unit_spike_train(unit_id)
t3 = time.perf_counter()

print(f"Good old : {t2-t1:.1f} s")
print(f"Pre-computing: {t3-t2:.1f} s")

for unit_id in sorting1.unit_ids:
    assert np.all(sorting1.get_unit_spike_train(unit_id) == sorting2.get_unit_spike_train(unit_id))
```

```
Good old : 1.6 s
Pre-computing: 1.6 s
```

This might be a phenomenon exclusive to MEArec; do you guys suspect that it happens for any other format? Can you cProfile your code to see where the time is spent?

It is very strange to me that we calculate the … I think that a function to transform …

The design principle is to corral complexity away from the core. Implementation details of specific formats should not leak up as general methods. All the current functionality can be corralled to NumpySorting, and currently baserecording has a …

Now, if most formats are like this I think it would make sense, but I really want to keep the core simple. I think the bar should be high.
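To answer the profiling question, something like the standard-library `cProfile` snippet below would show where the per-unit loop spends its time (the path is a placeholder, not one of the datasets mentioned in this thread):

```python
import cProfile
import pstats

import spikeinterface.core as si

# Placeholder path: point this at any saved sorting.
sorting = si.load_extractor("/path/to/sorting")

profiler = cProfile.Profile()
profiler.enable()
for unit_id in sorting.unit_ids:
    sorting.get_unit_spike_train(unit_id)
profiler.disable()

# Show the 15 most expensive calls by cumulative time.
pstats.Stats(profiler).sort_stats("cumulative").print_stats(15)
```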
Yes, pre-computing only makes things faster if you come from a spike vector (in other scenarios it's going to be the same speed). However the …
I believe there is a typo here ^^ Thanks for the input!
@DradeAW You are correct. My scenario of performance above is wrong. Here is the corrected scenario:

```python
import time
import numpy as np
import spikeinterface.core as si
from spikeinterface.core.generate import SortingGenerator

num_units = 1000
durations = [10 * 60 * 60.0]
seed = 25

sorting1 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)
sorting2 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)
sorting2.to_spike_vector()

t1 = time.perf_counter()
for unit_id in sorting1.unit_ids:
    sorting1.get_unit_spike_train(unit_id)
t2 = time.perf_counter()
print(f"Good old : {t2-t1:.1f} s")

t3 = time.perf_counter()
sorting2.precompute_spike_trains()
for unit_id in sorting2.unit_ids:
    sorting2.get_unit_spike_train(unit_id)
t4 = time.perf_counter()
print(f"Pre-computing: {t4-t3:.1f} s")

for unit_id in sorting1.unit_ids:
    assert np.all(sorting1.get_unit_spike_train(unit_id) == sorting2.get_unit_spike_train(unit_id))
```

```
Good old : 1.6 s
Using numba to compute spike trains
Pre-computing: 1.9 s
```

After caching:

```
Good old : 1.6 s
Using numba to compute spike trains
Pre-computing: 1.4 s
```

So, it is a bit slower on the first run, but after caching it becomes faster. I would label this as no gain but also no loss, even in the case with really fast IO (that is, the case of my generator). So I think this supports your case for adding this, as even in this extreme case it performs relatively well (or at least not badly).

So, yes, empiricism works; I am now on your side. This should work OK in most cases. I still don't like the cache option and would like to have a more readable version of the numba implementation, but that should be easy. Thanks for bearing with me.

There is an important point though that should not hold this PR. Without numba installed this is terribly slow:
In light of this, my current position is the following:
As I was writing this, I came to think the numba option is a better fit for another issue.
Then there is a bug, because this should not happen: it should be doing the same operation (a for loop getting each unit's spike train). This should be fixed before merging!
I'll add this to my todo list :)
Don't you think that numba is just way faster than looping? Can you elaborate on why this has to be a bug in the numpy version?
Numba is faster, but if numba isn't installed, then it should just do it the good old way. So if numba isn't installed, the behaviour before and after this PR shouldn't change!
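In other words, the expected no-numba fallback is just the plain per-unit loop that already works today; a minimal sketch of that behaviour:

```python
def precompute_without_numba(sorting):
    # The "good old way": ask the sorting extractor for each unit's spike train
    # directly, with no spike-vector masking involved.
    return {unit_id: sorting.get_unit_spike_train(unit_id) for unit_id in sorting.unit_ids}
```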
```
        # the trick here is to have a function getter
        vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain()
    else:
        vector_to_list_of_spiketrain = vector_to_list_of_spiketrain_numpy
```
Here, if you don't have numba you use the numpy version.
This is exactly the "old" way, of course.
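For readers without the full diff, a hedged sketch of what a lazily built numba kernel behind `get_numba_vector_to_list_of_spiketrain` could look like is shown below; the body is an assumption of mine, not the PR's actual implementation:

```python
import numpy as np


def get_numba_vector_to_list_of_spiketrain():
    # Import inside the getter so that core never requires numba at import time.
    import numba

    @numba.jit(nopython=True, nogil=True, cache=True)
    def _fill_spike_trains(sample_indices, unit_indices, spikes_concat, offsets, cursors):
        # One pass over the (time-sorted) spike vector: each spike is written
        # into its unit's slice, preserving temporal order within each unit.
        for i in range(sample_indices.size):
            u = unit_indices[i]
            spikes_concat[offsets[u] + cursors[u]] = sample_indices[i]
            cursors[u] += 1

    def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
        counts = np.bincount(unit_indices, minlength=num_units)
        offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))
        spikes_concat = np.empty(sample_indices.size, dtype=sample_indices.dtype)
        cursors = np.zeros(num_units, dtype=np.int64)
        _fill_spike_trains(sample_indices, unit_indices, spikes_concat, offsets, cursors)
        return [spikes_concat[offsets[u] : offsets[u] + counts[u]] for u in range(num_units)]

    return vector_to_list_of_spiketrain_numba
```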
```
    return spike_trains


def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units):
```
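As a point of reference, a masking-based numpy fallback with the signature above could be as simple as the sketch below (my reconstruction, not necessarily the PR's exact code). It also makes the performance concern concrete: it scans the whole spike vector once per unit.

```python
def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units):
    # sample_indices and unit_indices are the columns of the spike vector (numpy arrays).
    # One boolean mask per unit: simple, but roughly O(num_spikes * num_units) overall.
    return [sample_indices[unit_indices == unit_index] for unit_index in range(num_units)]
```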
If by the "old way" you mean using the numpy version, then that is what is super slow and I think should be avoided.
If by the "old way" you mean just calling `get_unit_spike_train` for every unit (which is what I meant), then I think it is indeed a bug, because the implementation is calling this function.
In my version, it was calling `get_unit_spike_train(unit_id)`. But the function Sam wrote should be doing the same thing `NumpySorting` does when you call `get_unit_spike_train`... so it should be the same performance.
If the sorting is a `NumpySorting`, that would be correct, and it makes sense: `NumpySorting` is slow at extracting spike trains because, well, the representation is just bad for that. That's what it is.
But if your sorting is not as bad as `NumpySorting` at calling `get_unit_spike_train`, then the current implementation is terrible without numba. That is, if you have a `spike_vector` cached in something that is not as bad as `NumpySorting` at extracting spikes, and you don't have numba, then you will be silently calling a terribly slow function.
Ah, I see the problem!
You are speaking about a (corner) case where the `get_unit_spike_train` function is very fast, yet for some reason the spike trains are not cached but the spike vector is.
Indeed you are correct, in this case it would be terribly slow.
I take issue with you calling that a corner case. Why are you so confident that sorting to build the `spike_vector` and then masking within the resulting structured array is a faster operation than just calling `get_unit_spike_train`?
My original appeal was to restrict this to formats that already have the `spike_vector` pre-calculated, like `NumpySorting` and, as you were saying, maybe Phy. There, you can be confident that calling this is unlikely to be bad. But then I tested this method with numba, and it works OK even for those cases with fairly fast IO, so that seems OK.
But for formats that don't have the `spike_vector` pre-calculated, this seems like a corner case:
- You already called `get_unit_spike_train` to get your `spike_vector` (so your data should be cached already, no need for this).
- Then, even though you already called `get_unit_spike_train`, somehow it is better to mask a very long vector (generating a spike vector in memory) and extract it from there.

So again, I think this works great for cases that already have the `spike_vector` pre-calculated, and it works OK even for other cases (if you have numba). But what if you don't have it? Or you don't have numba? Are you sure that `get_unit_spike_train` is that bad in general?
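For concreteness, the "mask a very long vector" operation being debated looks roughly like the sketch below (field names follow the structured array returned by `to_spike_vector()`; segments are ignored for brevity, and this is an illustration rather than the PR's code):

```python
def spike_train_from_vector(sorting, unit_id):
    # Build (or fetch) the whole time-sorted spike vector, then mask out one unit.
    spike_vector = sorting.to_spike_vector()
    unit_index = sorting.id_to_index(unit_id)
    mask = spike_vector["unit_index"] == unit_index
    return spike_vector["sample_index"][mask]
```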
Ah, I read the function again and realized it does not check whether the spike trains cache is already filled (this is in another PR). This should be added.
The corner case I was talking about is when `get_unit_spike_train` is very fast, yet the spike vector is cached but not the spike trains (since the spike vector is computed from the spike trains, they should already be cached).
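The missing check mentioned above could be as small as the hypothetical guard below (names are illustrative only, not the PR's code):

```python
def units_missing_from_cache(cached_spike_trains: dict, unit_ids):
    # Only these units would need to be recomputed from the spike vector.
    return [unit_id for unit_id in unit_ids if unit_id not in cached_spike_trains]
```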
@samuelgarcia I think you are thinking of another PR; this PR is different!
Allows for fast pre-computation of spike trains