Allow precomputing spike trains #2175

Merged: 42 commits, Nov 22, 2023

Conversation

@DradeAW (Contributor) commented Nov 6, 2023

Allows for fast pre-computation of spike trains

@DradeAW (Contributor, Author) commented Nov 6, 2023

Wow, to my surprise this is not faster!
I don't understand why; maybe I made a mistake somewhere...

import time
import spikeinterface.core as si


t1 = time.perf_counter()

# old way: fetch each unit's spike train one by one
sorting = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")
for unit_id in sorting.unit_ids:
    sorting.get_unit_spike_train(unit_id)

t2 = time.perf_counter()

# new way: precompute all spike trains at once, then fetch from the cache
sorting = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")
sorting.precompute_spike_trains()
for unit_id in sorting.unit_ids:
    sorting.get_unit_spike_train(unit_id)

t3 = time.perf_counter()

print(f"Old way: {t2-t1:.1f} s")
print(f"New way: {t3-t2:.1f} s")
Old way: 16.5 s
New way: 18.3 s

@DradeAW (Contributor, Author) commented Nov 6, 2023

Ok this is way better with the latest commit!

import time
import numpy as np
import spikeinterface.core as si



sorting1 = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")
sorting2 = si.load_extractor("/mnt/raid0/data/MEArec/1h_3000cells/analyses/ks2_5_pj7-3/sorting")

t1 = time.perf_counter()

for unit_id in sorting1.unit_ids:
	sorting1.get_unit_spike_train(unit_id)

t2 = time.perf_counter()

sorting2.precompute_spike_trains()
for unit_id in sorting2.unit_ids:
	sorting2.get_unit_spike_train(unit_id)

t3 = time.perf_counter()

print(f"Old way: {t2-t1:.1f} s")
print(f"New way: {t3-t2:.1f} s")

for unit_id in sorting1.unit_ids:
	assert np.all(sorting1.get_unit_spike_train(unit_id) == sorting2.get_unit_spike_train(unit_id))
Old way: 16.3 s
New way: 1.4 s

@DradeAW (Contributor, Author) commented Nov 6, 2023

For reference, my sorting dataset contains 531 units for a total of 11,298,565 spikes (1h recording)

@alejoe91 added the "core: Changes to core module" label on Nov 6, 2023
@zm711 (Collaborator) left a comment:

Just a question and suggestion @DradeAW

Review thread on src/spikeinterface/core/core_tools.py (outdated, resolved)
@@ -424,6 +427,31 @@ def get_all_spike_trains(self, outputs="unit_id"):
spikes.append((spike_times, spike_labels))
return spikes

def precompute_spike_trains(self, from_spike_vector: bool = True):
Collaborator:

Since this is in core, where numba is not automatically installed, wouldn't it be safer to have from_spike_vector default to False, so that a user who has only installed core doesn't instantly hit an assertion error?

Contributor (Author):

I added the parameter for completeness, but I actually don't know when it would be useful to precompute spike trains from anything other than the spike vector.
Can you think of a use case?

Collaborator:

Not off the top of my head. But I'll think about it :)

@samuelgarcia (Member) commented:

Hi Aurelien,
thanks for this.

I won't have time to give feedback on this very soon.
But globally:

  • Thanks a lot for this speedup; this was on my todo list.
  • I'm not sure I like the idea of having numba in core... let's see. It can be used optionally if installed, but then we need it for testing.
  • I'm not sure I like the semantics of one function. In short, many Sorting classes are already spike-train centric (most of them) and some are spike-vector first (NumpySorting, ...). We need this function for two clearly distinct cases:
    1. when spike-train centric, to preload spike trains into memory
    2. when spike-vector centric, to make the conversion all at once. The caching used to be spike train by spike train, which is also a valid scenario, I think.
    The semantics have to be very clear for these two cases. This is now done with one function and one option.
  • I like the idea of explicitly calling this function.

@DradeAW (Contributor, Author) commented Nov 7, 2023

not sure I like the semantics of one function [...] We need this function for two clearly distinct cases

@samuelgarcia Would you prefer to split this into two functions with no parameters?

@DradeAW (Contributor, Author) commented Nov 7, 2023

not sure I like the idea of having numba in core... let's see. It can be used optionally if installed, but then we need it for testing.

I removed the need for numba (there is now a fallback if it's not installed).
For the tests, I've added a skip if numba is not installed, so it won't make the tests fail.

@h-mayorquin (Collaborator) left a comment:

Some comments added.

Two big questions:

How does this interact with the caching option in to_spike_vector and get_unit_spike_train? Isn't it duplicating behavior?

How would this play out with multiprocessing? All this pre-computing is lost, isn't it?

Four review threads on src/spikeinterface/core/core_tools.py (outdated, resolved)
@DradeAW (Contributor, Author) commented Nov 9, 2023

How does this interact with the caching option in to_spike_vector and get_unit_spike_train?

It overrides it (forces a recompute).

Isn't it duplicating behavior?

No, because it computes all spike trains at once rather than unit by unit, which (as I showed above) is much faster.

How would this play out with multiprocessing? All this pre-computing is lost, isn't it?

I'm not sure, but I believe that if you get the spike train, it should find it in the cache in all cases.
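The "all at once" idea can be sketched in plain numpy (the function name below is ours, not the actual spikeinterface implementation): a spike vector stores time-sorted (sample_index, unit_index) pairs, and one stable argsort on the unit indices demixes it into all per-unit spike trains in a single pass, instead of masking the full vector once per unit.

```python
import numpy as np

def demix_spike_vector(sample_indices, unit_indices, num_units):
    """Split one time-sorted spike vector into per-unit spike trains.

    A stable sort on unit_indices groups spikes by unit while keeping
    temporal order within each unit; np.split then cuts the grouped
    array at the unit boundaries. Cost is one O(n log n) pass instead
    of O(num_units * n) with a boolean mask per unit.
    """
    order = np.argsort(unit_indices, kind="stable")
    grouped = sample_indices[order]
    counts = np.bincount(unit_indices, minlength=num_units)
    boundaries = np.cumsum(counts)[:-1]  # split points between units
    return np.split(grouped, boundaries)

# toy spike vector: 6 spikes from 3 units, sorted by time
samples = np.array([10, 12, 15, 20, 21, 30])
units = np.array([1, 0, 2, 0, 1, 0])
trains = demix_spike_vector(samples, units, num_units=3)
# trains[0] -> [12, 20, 30], trains[1] -> [10, 21], trains[2] -> [15]
```

This is only a sketch of why one grouped pass beats a per-unit mask loop; the PR's actual implementation uses a numba kernel for the same demixing.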

@h-mayorquin (Collaborator) commented:

No, because it computes all spike trains at once rather than unit by unit, which (as I showed above) is much faster.

So are the functions introduced here doing a better job at calculating the spike_vector than to_spike_vector?

@@ -449,6 +451,33 @@ def get_all_spike_trains(self, outputs="unit_id"):
spikes.append((spike_times, spike_labels))
return spikes

def precompute_spike_trains(self, from_spike_vector=None):
Member:

I don't like this method name too much.
Maybe cache_all_spike_trains() would be more explicit.

Contributor (Author):

To me, a function named cache_all_spike_trains would be expected not to perform any computation, but just to move the result of a computation into the cache.
But I agree the name could be improved.

@DradeAW (Contributor, Author) commented Nov 17, 2023

There are probably a few places in core that could be improved by calling this function, because they need the spike trains of all unit ids (maybe when extracting waveforms?)

@zm711 (Collaborator) left a comment:

Cleanup of docstrings.

Four review threads on src/spikeinterface/core/sorting_tools.py (resolved; three outdated)
@samuelgarcia (Member) commented:

There are probably a few places in core that could be improved by calling this function, because they need the spike trains of all unit ids (maybe when extracting waveforms?)

Waveform extraction is based on the spike vector.
But maybe elsewhere.

@DradeAW (Contributor, Author) commented Nov 17, 2023

Latest benchmark: even faster.

~200x faster on a very big sorting:

Old way: 634.3 s
New way: 2.9 s

~35x faster on a sorting I typically use:

Old way: 24.9 s
New way: 0.7 s

@h-mayorquin (Collaborator) commented Nov 18, 2023

Thanks @samuelgarcia. I introduced some machinery to test this:

#2227

I want to decouple the performance from the MEArec memory-model limitations.

I will give it a test as soon as I can.

@h-mayorquin (Collaborator) commented:

I found no difference (as I expected):

import time
import numpy as np
import spikeinterface.core as si
from spikeinterface.core.generate import SortingGenerator


num_units = 1000
durations = [10 * 60 * 60.0]
seed = 25

sorting1 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)
sorting2 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)

t1 = time.perf_counter()

for unit_id in sorting1.unit_ids:
	sorting1.get_unit_spike_train(unit_id)

t2 = time.perf_counter()

sorting2.precompute_spike_trains()
for unit_id in sorting2.unit_ids:
	sorting2.get_unit_spike_train(unit_id)

t3 = time.perf_counter()


print(f"Good old : {t2-t1:.1f} s")
print(f"Pre-computing: {t3-t2:.1f} s")

for unit_id in sorting1.unit_ids:
	assert np.all(sorting1.get_unit_spike_train(unit_id) == sorting2.get_unit_spike_train(unit_id))

Good old : 1.6 s
Pre-computing: 1.6 s

This might be a phenomenon exclusive to MEArec. Do you suspect that it happens for any other format? Can you cProfile your code to see where the time is spent?

It is very strange to me that we calculate the spike trains through the spike_vector, which is a terrible representation for that: you mix your spike trains into one long vector (with a sort as the cost) and then you use a numba function to demix it because, well, it was not made for that.

I think a function to transform a spike_vector into spike trains through numba makes a lot of sense in NumpySorting. That is where it should be most useful, and it makes thematic sense. Especially if this is a phenomenon exclusive to MEArec, I don't think it is worth adding more complexity to the core just for this case.

The design principle is to corral complexity away from the core. Implementation details of specific formats should not leak upward as general methods. All the current functionality can be corralled into NumpySorting, and currently baserecording has a .to_numpy_sorting method that enables this easily. That is the place to handle functionality that is useful when you have all the data in memory. The core is for lazy operations.

Now, if most formats are like this, I think it would make sense, but I really want to keep the core simple. I think the bar should be high.

@DradeAW (Contributor, Author) commented Nov 20, 2023

Yes, pre-computing only makes things faster if you come from a spike vector (in other scenarios it's going to be the same speed).

However, NumpySorting is not the only sorting using a spike vector (e.g. PhySortingExtractor).
Plus, Sam wanted to focus on making sortings more spike-vector centered.
Plus, NumpySorting is the default save method, so it might be used a lot.

Currently baserecording has a .to_numpy_sorting

I believe there is a typo here ^^

Thanks for the input!
This should not be taken lightly, indeed. However, I believe that retrieving spike trains is a core functionality of a sorting object and should therefore be optimized (I posted the massive gain for huge sorting objects).

@h-mayorquin (Collaborator) commented:

@DradeAW You are correct. My scenario of performance above is wrong. Here is the corrected scenario:

import time
import numpy as np
import spikeinterface.core as si
from spikeinterface.core.generate import SortingGenerator


num_units = 1000
durations = [10 * 60 * 60.0]
seed = 25

sorting1 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)
sorting2 = SortingGenerator(num_units=num_units, durations=durations, seed=seed)
sorting2.to_spike_vector()

t1 = time.perf_counter()

for unit_id in sorting1.unit_ids:
	sorting1.get_unit_spike_train(unit_id)

t2 = time.perf_counter()
print(f"Good old : {t2-t1:.1f} s")


t3 = time.perf_counter()
sorting2.precompute_spike_trains()
for unit_id in sorting2.unit_ids:
	sorting2.get_unit_spike_train(unit_id)

t4 = time.perf_counter()


print(f"Pre-computing: {t4-t3:.1f} s")

for unit_id in sorting1.unit_ids:
	assert np.all(sorting1.get_unit_spike_train(unit_id) == sorting2.get_unit_spike_train(unit_id))

Good old : 1.6 s
Using numba to compute spike trains
Pre-computing: 1.9 s
After caching:
Good old : 1.6 s
Using numba to compute spike trains
Pre-computing: 1.4 s

So, it is a bit slower on the first run, but after caching it becomes faster. I would label this as no gain but also no loss, even in the case with really fast IO (that is, the case of my generator). So I think this supports your case for adding it, as even in this extreme case it performs relatively well (or at least not badly). So, yes, empiricism works; I am now on your side: this should work OK in most cases. I still don't like the cache option and would like a more readable version of the numba implementation, but that should be easy. Thanks for bearing with me.

There is an important point, though, that should not hold up this PR. Without numba installed, this is terribly slow:

not using numba
Good old : 1.8 s
Pre-computing: 76.9 s

In light of this, my current position is the following:

  • If we are going to have this, we should have numba in core. Otherwise, we are adding a very slow option that will get triggered without the user knowing about it. Given that the point of adding this is better performance, I would not like to run the super-slow function silently. Another option is to throw an assertion letting the user know that this method is not available if they don't have numba. But at that point, I don't really see why not just add numba to core. It is not heavy, it is fast to import, it would enable other improvements that I have in mind, and it would avoid all the annoyance of importing numba in a special way.

As I was writing this, I realized the numba question is a better fit for another issue.

@DradeAW (Contributor, Author) commented Nov 20, 2023

Without numba installed this is terribly slow:

Then there is a bug, because this should not happen: it should be doing the same operation (a for loop getting each unit's spike train). This should be fixed before merging!

I still don't like the cache option and would like to have a more readable version of the numba implementation but that should be easy.

I'll add this on my todo list :)

@h-mayorquin (Collaborator) commented:

Without numba installed this is terribly slow:

Then there is a bug, because this should not happen: it should be doing the same operation (a for loop getting each unit's spike train). This should be fixed before merging!

Don't you believe that numba is just way faster than looping? Can you elaborate on why this has to be a bug in the numpy version?

@DradeAW (Contributor, Author) commented Nov 20, 2023

Numba is faster, but if numba isn't installed, then it should just do it the good old way.

So if numba isn't installed, the behaviour before and after this PR shouldn't change!

    # the trick here is to have a function getter
    vector_to_list_of_spiketrain = get_numba_vector_to_list_of_spiketrain()
else:
    vector_to_list_of_spiketrain = vector_to_list_of_spiketrain_numpy
Collaborator:

Here, if you don't have numba you use the numpy version.

Member:

This is exactly the "old" way, of course.

return spike_trains


def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units):
@h-mayorquin (Collaborator) commented Nov 20, 2023:

@samuelgarcia @DradeAW

If by the "old way" you mean using the numpy version, then that is what is super slow, and I think it should be avoided.

If by the "old way" you mean just calling get_unit_spike_train for every unit (which is what I meant), then I think it is indeed a bug, because the implementation is calling this function.

Contributor (Author):

In my version, it was calling get_unit_spike_train(unit_id).

But the function Sam wrote should be doing the same thing NumpySorting does when you call get_unit_spike_train... so it should have the same performance.

@h-mayorquin (Collaborator) commented Nov 20, 2023:

If the sorting is a NumpySorting, that would be correct, and it makes sense: NumpySorting is slow at extracting spike trains because, well, the representation is just bad for that. That's what it is.

But if your sorting is not as bad as NumpySorting at calling get_unit_spike_train, then the current implementation is terrible without numba.

That is, if you have a spike_vector cached in something that is not as bad as NumpySorting at extracting spikes, and you don't have numba, then you will silently be calling a terribly slow function.

Contributor (Author):

Ah, I see the problem!

You are speaking about a (corner) case where the get_unit_spike_train function is very fast, yet for some reason the spike trains are not cached but the spike vector is.
Indeed, you are correct: in this case it would be terribly slow.

Collaborator:

I take issue with you calling that a corner case. Why are you so confident that sorting to build the spike_vector and then masking the resulting structured array is faster than just calling get_unit_spike_train?

My original appeal was to restrict this to formats that already have the spike_vector pre-calculated, like NumpySorting and, as you were saying, maybe Phy. There, you can be confident that calling this is unlikely to be bad. But then I tested this method with numba and it works OK even for those cases with kind-of-fast IO, so that seems fine.

But for formats that don't have the spike_vector pre-calculated, this seems like a corner case:

  1. You already called get_unit_spike_train to get your spike_vector (so your data should be cached already; no need for this).
  2. Then, you already called get_unit_spike_train, but somehow it is better to mask a very long vector (generating a spike vector in memory) and extract the trains from there.

So again, I think this works great for formats that already have the spike_vector pre-calculated, and it works OK even for other cases (if you have numba). But if you don't have it? Or if you don't have numba? Are you sure that get_unit_spike_train is that bad in general?

Contributor (Author):

Ah, I read the function again and realized it does not check whether the spike-train cache is already full (that is in another PR). This should be added.

The corner case I was talking about is when get_unit_spike_train is very fast, yet the spike vector is cached but the spike trains are not (since the spike vector is computed from the spike trains, they should normally be cached).
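A simplified sketch of the getter pattern under discussion (the names mirror the quoted diff, but the bodies are our assumptions, not the actual spikeinterface code): the numba kernel is compiled lazily inside the getter so that merely importing the module never requires numba, and the plain-numpy masking version is the silent fallback whose cost this thread is debating.

```python
import numpy as np

def vector_to_list_of_spiketrain_numpy(sample_indices, unit_indices, num_units):
    """Plain-numpy fallback: one boolean mask over the full vector per unit.
    O(num_units * num_spikes), which is what gets slow with many units."""
    return [sample_indices[unit_indices == u] for u in range(num_units)]

def get_vector_to_list_of_spiketrain():
    """Return the fastest available demixer. The numba kernel is defined
    inside the getter, so importing this module never requires numba."""
    try:
        import numba
    except ImportError:
        return vector_to_list_of_spiketrain_numpy

    @numba.njit(nogil=True)
    def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
        # first pass counts spikes per unit, second pass fills the trains
        counts = np.zeros(num_units, dtype=np.int64)
        for u in unit_indices:
            counts[u] += 1
        trains = [np.empty(counts[u], dtype=np.int64) for u in range(num_units)]
        cursors = np.zeros(num_units, dtype=np.int64)
        for i in range(sample_indices.size):
            u = unit_indices[i]
            trains[u][cursors[u]] = sample_indices[i]
            cursors[u] += 1
        return trains

    return vector_to_list_of_spiketrain_numba
```

When numba is missing, the getter hands back the O(num_units × num_spikes) masking loop, which is exactly the slow path measured in the benchmark above.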

@samuelgarcia (Member) commented:

This PR should be closed and the discussion should happen in #2209!
#2209 includes all of this concept, rewritten.

@DradeAW (Contributor, Author) commented Nov 20, 2023

@samuelgarcia I think you are thinking of another PR.

This PR is different!

@alejoe91 merged commit 029c24a into SpikeInterface:main on Nov 22, 2023
9 checks passed