Add Poisson statistics to generate_sorting
and optimize memory profile
#2226
Conversation
@h-mayorquin Curious what speeds you observe with:

```python
import numpy

seed = 0
random_number_generator = numpy.random.default_rng(seed=seed)

number_of_units = 1_000
firing_rates = 10.0  # Hz, scalar usage
# firing_rates = [10.0 for unit_index in range(number_of_units)]  # Hz, vector usage
duration = 60.0  # seconds

# sampling_frequency = None
sampling_frequency = 30_000  # Hz, if specified

# refractory_period = None
refractory_period = 4.0  # milliseconds


def _clean_refractory_period(original_spike_times: numpy.ndarray, refractory_period_seconds: float) -> numpy.ndarray:
    inter_spike_intervals = numpy.diff(original_spike_times, prepend=refractory_period_seconds)
    violations = inter_spike_intervals < refractory_period_seconds
    if not numpy.any(violations):
        return original_spike_times
    # Shift violating spikes forward until they satisfy the refractory period
    spike_time_shifts = refractory_period_seconds - inter_spike_intervals[violations]
    cleaned_spike_times = original_spike_times.copy()
    cleaned_spike_times[violations] += spike_time_shifts
    return cleaned_spike_times


if numpy.isscalar(firing_rates):
    number_of_spikes_per_unit = random_number_generator.poisson(lam=firing_rates * duration, size=number_of_units)
else:
    number_of_spikes_per_unit = numpy.empty(shape=number_of_units, dtype="uint16")
    for unit_index in range(number_of_units):
        number_of_spikes_per_unit[unit_index] = int(
            random_number_generator.poisson(lam=firing_rates[unit_index] * duration, size=1)
        )

spike_times = list()
if sampling_frequency is None:
    for number_of_spikes in number_of_spikes_per_unit:
        spikes = numpy.sort(random_number_generator.uniform(low=0, high=duration, size=number_of_spikes))
        if refractory_period is not None:
            spikes = _clean_refractory_period(
                original_spike_times=spikes, refractory_period_seconds=refractory_period / 1e3  # ms to s
            )
        spike_times.append(spikes)
else:
    for number_of_spikes in number_of_spikes_per_unit:
        spikes = numpy.sort(
            random_number_generator.integers(
                low=0, high=int(duration * sampling_frequency), size=number_of_spikes, dtype="uint64"
            )
        ) / sampling_frequency
        if refractory_period is not None:
            spikes = _clean_refractory_period(
                original_spike_times=spikes, refractory_period_seconds=refractory_period / 1e3  # ms to s
            )
        spike_times.append(spikes)
```
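As a quick sanity check of the count-based approach above, the Poisson draw recovers the requested mean rate. This is my own standalone sketch (not PR code), reusing the parameter names from the snippet:

```python
import numpy

# Sketch: Poisson spike counts should recover the requested firing rate.
rng = numpy.random.default_rng(seed=0)
firing_rate = 10.0  # Hz
duration = 60.0  # seconds
number_of_units = 1_000

counts = rng.poisson(lam=firing_rate * duration, size=number_of_units)
empirical_rate = counts.mean() / duration
print(empirical_rate)  # close to 10.0 Hz
```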
docstring cleanup :)
src/spikeinterface/core/generate.py
```
- The function uses a geometric distribution to simulate the discrete inter-spike intervals,
  based that would be an exponential process for continuous time.
```
This isn't quite clear. Maybe it is missing a word. I'm not quite sure how to fix.
You are correct, I think; this is not clear. I will expand on this. Thanks, also for all the other comments. They all make sense and are very helpful as usual.
I modified the docstring here, let me know what you think and if it is still unclear to you. Also, any other advice you might have is useful.
Co-authored-by: Zach McKenzie <[email protected]>
@CodyCBakerPhD That is the bottleneck: generating sorted, concatenated spikes and their corresponding units. This is how the output should look (more or less):

```
spike_frames
[  452 15511 37989 42417 62234 71248 74939]
unit indexes
[1 1 0 0 1 1 0]
```

The concatenated frames have to be sorted, but the frames of different units can happen at the same time. So you cannot sort until you concatenate, and the cumulative sum has to happen with the compact version. So it is hard to avoid the memory allocation or to sort smaller vectors. Reading your code, it seems that you just concatenated sorted spikes, and I can't see the vector of the units. Adding the concatenation of all the spikes and then sorting (without the units part), the code is already slower than the old implementation with a `%%timeit` decorator in a notebook.
To me, the difference between 10,000, 10,200 and 10,400 is negligible.
I meant that the resulting ISI might be erroneous.
Right, it was me who did not understand above. You are correct. Let's switch to 4 stds; I think it makes sense.
I still don't understand this. Lambda is the firing rate, right? Also, this is something that the current implementation is doing: generating the frames as integers. Is this being avoided right now? Maybe this will be helpful for my understanding.
Let's take an example where the generation would give:

But if you first convert to integers, then you have:
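The discretization issue being described can be reproduced directly: two spikes closer than one sampling period collapse onto the same integer frame, so the discrete ISI becomes zero even though the continuous ISI is positive. A sketch with assumed values:

```python
import numpy as np

sampling_frequency = 30_000  # Hz (assumed for the example)
# Two spikes 20 microseconds apart: less than one sampling period (~33 us)
spike_times = np.array([0.100000, 0.100020])  # seconds

spike_frames = (spike_times * sampling_frequency).astype(np.int64)
print(np.diff(spike_times))   # positive continuous ISI
print(np.diff(spike_frames))  # [0]: both spikes land on the same frame
```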
@DradeAW
Yeah, maybe try to run some experiments? Maybe you are right, but "noticeable" is not very actionable to me. Do you expect it to fail a goodness-of-fit test? Would the mean be far from the true mean? Would it be more or less skewed than the exponential? I think this would be the easiest thing to do: try the current implementation, try this one in different scenarios, and let's see if it does worse in any of them. That would be very useful.
I'm sorry, I misunderstood!
I did not mean this; I meant your expectations about how the exponential distribution should look. But now that I understand you are privileging the count divided by the long-term time, that makes sense.
I really don't understand what you are doing. We just want a unit with a given mean firing rate and refractory period, and your method just outputs a wrong firing rate.
Yes! It is the definition of the mean firing rate!!
I am implementing the Poisson statistics as described in books like:

The first references in Google indicate a similar implementation. The thing that you are proposing I have never seen. Maybe it is correct and makes a lot of sense to you, but this is not how it is done in the books I have read. Can you show me some reference or another library implementation, to see that your proposed convention is commonplace?
I've not read this anywhere; this is what I do by simply manipulating statistics :) It may not be the correct answer in some cases, but in the case of firing_rate = 99.9 Hz and refractory_period = 10 ms, this version simply fails.
xD I will ask Alessio and Sam what they prefer for this case next time we meet, as I don't think they will dare to read this super long thread that you and I did.
I have been passively enjoying the passionate exchange ;)
Interesting! Could I confirm that the only outstanding point of discussion is whether to rescale? On this, I am not familiar with the conventions in the field; some seem to do this rescaling [1], [2], whereas in other sources [3], [Dayan & Abbott] it is not mentioned.
@JoeZiminski this is a good summary. Thanks for the references. So Bartoz does indeed modify the rate, with no constraints. Andrew does modify the rate but excludes the case of the refractory period being larger than the inter-spike period. Plus, Andrew provides a reference to Stefan's paper, where they call this a Poisson process with dead time to get away from the fact that this is not a Poisson process anymore: https://link.springer.com/article/10.1007/s10827-011-0362-8 This answers my concerns. If we decide to go for the version in the aforementioned paper, I am fine with it. Thanks, @DradeAW, I learned something today. I never thought carefully before about when the assumptions of the Poisson renewal process that I learned from Abbott's treatment break down. I still think that the bursting model they have over there makes more sense, but this seems to just be another equally valid convention in the field. It also has the advantage that it is easier to implement, now that I think about it.
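The rescaling convention under discussion can be sketched numerically. This is my own illustration of the "Poisson process with dead time" idea, not code from the PR; the names (`corrected_lambda`, `dead_time`) are made up for the example:

```python
import numpy as np


def corrected_lambda(target_rate_hz: float, dead_time_s: float) -> float:
    # Mean ISI of the process is dead_time + 1/lambda; solving
    # dead_time + 1/lambda = 1/target_rate gives the rescaled rate.
    assert target_rate_hz * dead_time_s < 1.0, "target rate unreachable with this dead time"
    return target_rate_hz / (1.0 - target_rate_hz * dead_time_s)


rng = np.random.default_rng(seed=0)
target_rate = 50.0  # Hz
dead_time = 0.004   # 4 ms refractory period

lam = corrected_lambda(target_rate, dead_time)
isis = rng.exponential(scale=1.0 / lam, size=100_000) + dead_time
print(1.0 / isis.mean())  # close to 50 Hz
```

Note the firing_rate = 99.9 Hz / refractory_period = 10 ms case mentioned above sits right at the boundary: rate × dead_time = 0.999, so the corrected lambda blows up to roughly 99.9 kHz, which is why unconstrained rescaling fails there.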
Maybe this whole conversation is just my bad for not understanding this 😅 Sorry if I seemed a bit harsh, but criticizing my math when I spent many hours / days on the question tends to trigger me ahah
I think this is a very strong point. In fact, I feel convinced by it. The red herring was that we started discussing the statistics. You came to a PR that is called "add poisson statistics" and then told me that the ISI should not follow the theoretically expected distribution for a Poisson. It turns out that you have a good reason for it not to be Poisson, but I instead doubled down on "this is a Poisson distribution, why do you want me to make my statistics not Poissonian, you must be getting something wrong!". At the end I think we came to the right crux, and we should take pride in that. I asked you:
And I think that you have a great answer for my question there which is:
I do believe that user-centered decisions trump mathematical soundness, so that argument convinces me. And looking at it in retrospect, the discussion went to the distribution because I failed to read the following. You said:
I actually did not understand that you were asking me to measure the firing rate there (probably the use of the preposition "at" and the omission of the words "firing rate"; writing is hard). And then I started discussing the pdf of the distribution. But overall the discussion was good for me. I learned something new, and I think making this a Poisson distribution with dead times (as Stefan and co. call it) is better for maintenance and usability. So thanks for bearing with me; I think the library is better thanks to your efforts.
Alright, everything turned out well ^^ Sorry, I know I'm not always super great at expressing stuff (even in my native language 😅), but we both learned something! Although I'm still curious about something: is your implementation still Poissonian?
Not strictly; the support is changed (by the shifting), but the ISI is still exponential, and the ratio of the mean to the std should remain constant in the new support. That said, I am not sure that is not also the case in the Poisson process with dead time. I tried to fit an exponential with a shift and that did not work, but maybe the parameters need to be changed in some other way. I pushed a version of your suggestion now, but the tests are failing. It seems that those are the new metrics. I will take a look later.
And these are some related plots that I did after our discussion.

**Discretization**

This should answer @DradeAW's question about how the discretization behaves in the extreme case. The bins accumulate in

**How much does the method actually modify the firing rates?**

For myself, I was curious how much the instantaneous firing rate is modified depending on the refractory period. So, yes, the lesson is: quite a lot. For high firing rates, a small refractory period can easily double the firing rate.
```python
# We estimate how many spikes we will have in the duration
max_frames = duration * sampling_frequency
max_binomial_p = float(np.max(binomial_p))
num_spikes_expected = ceil(max_frames * max_binomial_p)
```
```diff
- num_spikes_expected = ceil(max_frames * max_binomial_p)
+ num_spikes_expected = int(np.ceil(max_frames * max_binomial_p))
```
Any interest in this instead? Then you could remove the `ceil` import from `math`. Or did you really want to use `math.ceil`?
What's the advantage of this?
Last time I checked, the math module functions are faster for scalars than the numpy functions, as they avoid the overhead. Speed won't matter that much at this scale though.
Honestly, the only advantage for this scalar is that you import one less function into code where it is only used once. But reducing imports is not necessarily a good reason, so my comment was more a question than a hard recommendation.
> Last time I checked the math module functions are faster for scalars than numpy function

Last time I checked, even `math.pi` was faster than `np.pi`, which I still don't understand ahah. I agree, `math` for scalars is better.
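The scalar-overhead claim is easy to measure. A rough micro-benchmark (my own sketch; absolute numbers vary by machine and library versions):

```python
import math
import timeit

import numpy as np

# Compare math.ceil vs np.ceil on a plain Python float; np.ceil pays
# ufunc dispatch overhead on scalars.
n = 100_000
t_math = timeit.timeit("math.ceil(1234.5)", globals={"math": math}, number=n)
t_numpy = timeit.timeit("np.ceil(1234.5)", globals={"np": np}, number=n)
print(f"math.ceil: {t_math:.4f} s, np.ceil: {t_numpy:.4f} s")
```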
@zm711 I see. Yes, importing from the standard library at will is my prior until proven otherwise.
Run the following script:

```python
import pkgutil
import sys
import timeit

# Get a list of all standard library modules
standard_lib_modules = [module for module in pkgutil.iter_modules() if module.name in sys.stdlib_module_names]

# Dictionary to store import times
import_times = {}
for module in standard_lib_modules:
    # Measure the import time
    time = timeit.timeit(f"import {module.name}", number=1)
    import_times[module.name] = time

# Print or process the import times
for module, time in import_times.items():
    print(f"{module}: {time} seconds")
```
You will see that importing from the standard library is at the scale of a main memory reference:
https://brenocon.com/dean_perf.html
Thanks @h-mayorquin! Makes sense.
src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Hi Ramon and Aurelien. I have to admit (with a lot of shame) that I did not have time to read it before today.
I added a function that generates spikes as a Poisson process, which, in my understanding, is the most common statistical model of firing rates.
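For reference, the textbook way to draw a homogeneous Poisson spike train is to accumulate exponential inter-spike intervals. A minimal sketch under that convention (my own illustration, not the PR implementation, which works from Poisson counts):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
firing_rate = 10.0  # Hz
duration = 60.0  # seconds

# Draw more ISIs than needed on average, then keep spikes inside the duration
n_draw = int(firing_rate * duration * 1.5) + 10
inter_spike_intervals = rng.exponential(scale=1.0 / firing_rate, size=n_draw)
spike_times = np.cumsum(inter_spike_intervals)
spike_times = spike_times[spike_times < duration]

print(len(spike_times) / duration)  # empirical rate, roughly 10 Hz
```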
This is also about 3 times faster: it generates a ten-hour sorting with 1000 units in 7 seconds instead of 20 with the current one, a ~3x speedup.
Plus, it is more memory efficient: it has around 70% of the memory requirements of the current function (I did try to make it less memory hungry, but it is HARD). It also seems that the current implementation has some kind of leak; see the temporal profile of memory utilization in the profile above.
The first three are the PR implementation; the last three are the current implementation.
[EDIT]:
Something important that I forgot to say: this will sometimes randomly (when your durations are too short and your firing rates too low) produce empty spike_trains, but I think that's fine. It seems like a relevant thing to discuss anyway:
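How often empty spike trains appear follows directly from the Poisson count distribution: P(N = 0) = exp(-rate × duration). A quick check with assumed low-rate values:

```python
import numpy as np

rate, duration = 0.1, 5.0  # 0.1 Hz over 5 seconds (assumed low-rate scenario)
p_empty = np.exp(-rate * duration)
print(p_empty)  # ~0.61: most trains are empty in this regime

rng = np.random.default_rng(seed=0)
counts = rng.poisson(lam=rate * duration, size=100_000)
print((counts == 0).mean())  # empirical fraction, close to p_empty
```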