
Add Poisson statistics to generate_sorting and optimize memory profile #2226

Merged
merged 23 commits into SpikeInterface:main from improve_generate_sorting
Jan 22, 2024

Conversation

h-mayorquin
Collaborator

@h-mayorquin h-mayorquin commented Nov 18, 2023

I added a function that generates spikes as a Poisson process, which, to my understanding, is the most common statistical model of firing.

It is also about 3 times faster: it generates a ten-hour sorting with 1000 units in 7 seconds instead of 20 with the current implementation, a ~3x speedup.

--------------------------------------------------
`synthesize_poisson_spike_vector`
Mean time over 3 iterations: 6.94 seconds
Std over 3 iterations: 0.01 seconds
times=['6.92', '6.96', '6.95']
--------------------------------------------------
`synthesize_random_firings`
Mean time over 3 iterations: 21.12 seconds
Std over 3 iterations: 0.09 seconds
times=['21.05', '21.05', '21.25']

Speedup: 3.04

Plus, it is more memory efficient: it needs around 70 % of the memory required by the current function (I did try to make it less memory hungry, but that is HARD). It also seems that the current implementation has some kind of leak; see the temporal profile of memory utilization in the image below.

[image: memory profile over time for both implementations]

The first three runs are the PR implementation; the last three are the current implementation.

[EDIT]:
Something important that I forgot to say: when the durations are too short and the firing rates too low, this will sometimes (randomly) produce empty spike trains, but I think that's fine. It seems like a relevant thing to discuss anyway:

from spikeinterface.core.generate import generate_sorting


seed = 4
sorting = generate_sorting(num_units=2, durations=[1.0], sampling_frequency=30000.0, seed=seed)
sorting.get_unit_spike_train(0, return_times=True)
array([], dtype=float64)
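
For reference, here is a minimal sketch of the idea (illustrative only; the actual `synthesize_poisson_spike_vector` added by this PR differs in detail): spike frames can be generated directly in integer samples by drawing geometric inter-spike intervals, the discrete analogue of exponential ISIs. It also shows why short durations with low rates can yield empty trains: every drawn frame may fall past the end of the recording.

import numpy as np

def poisson_spike_frames(firing_rate_hz, duration_s, sampling_frequency, seed=0):
    # Illustrative helper, not the PR's implementation
    rng = np.random.default_rng(seed)
    num_frames = int(duration_s * sampling_frequency)
    p = firing_rate_hz / sampling_frequency  # probability of a spike per sample (valid while p << 1)
    expected_spikes = int(firing_rate_hz * duration_s)
    # Draw more ISIs than needed, then keep only the frames inside the duration
    isi_samples = rng.geometric(p=p, size=2 * expected_spikes + 10)
    frames = np.cumsum(isi_samples)
    return frames[frames < num_frames]

frames = poisson_spike_frames(firing_rate_hz=5.0, duration_s=10.0, sampling_frequency=30_000.0)
print(frames.size)  # ~50 spikes expected; can be 0 for very short durations or low rates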

@h-mayorquin h-mayorquin added the core (Changes to core module) label Nov 18, 2023
@h-mayorquin h-mayorquin marked this pull request as ready for review November 18, 2023 14:47
@h-mayorquin h-mayorquin changed the title from "Add posion statistics to generate_sorting" to "Add Poisson statistics to generate_sorting and optimize memory profile" Nov 18, 2023
@h-mayorquin h-mayorquin self-assigned this Nov 18, 2023
@CodyCBakerPhD
Collaborator

@h-mayorquin Curious what speeds you observe with

import numpy

seed = 0
random_number_generator = numpy.random.default_rng(seed=seed)

number_of_units = 1_000

firing_rates = 10.0  # Hz, scalar usage
# firing_rates = [10.0 for unit_index in range(number_of_units)]  # Hz, vector usage

duration = 60.0  # seconds

# sampling_frequency = None
sampling_frequency = 30_000  # Hz, if specified

# refractory_period = None
refractory_period = 4.0  # milliseconds


def _clean_refractory_period(original_spike_times: numpy.ndarray, refractory_period_seconds: float) -> numpy.ndarray:
    inter_spike_intervals = numpy.diff(original_spike_times, prepend=refractory_period_seconds)
    violations = inter_spike_intervals < refractory_period_seconds
    if not numpy.any(violations):
        return original_spike_times
    # Shift violating spikes forward so that each offending interval equals the refractory period
    spike_time_shifts = refractory_period_seconds - inter_spike_intervals[violations]
    cleaned_spike_times = original_spike_times.copy()
    cleaned_spike_times[violations] += spike_time_shifts
    return cleaned_spike_times


if numpy.isscalar(firing_rates):
    number_of_spikes_per_unit = random_number_generator.poisson(lam=firing_rates * duration, size=number_of_units)
else:
    number_of_spikes_per_unit = numpy.empty(shape=number_of_units, dtype="uint16")
    for unit_index in range(number_of_units):
        number_of_spikes_per_unit[unit_index] = int(
            random_number_generator.poisson(lam=firing_rates[unit_index] * duration, size=1)
        )

spike_times = list()
if sampling_frequency is None:
    for number_of_spikes in number_of_spikes_per_unit:
        spikes = numpy.sort(random_number_generator.uniform(low=0, high=duration, size=number_of_spikes))
        if refractory_period is not None:
            spikes = _clean_refractory_period(
                original_spike_times=spikes, refractory_period_seconds=refractory_period / 1e3
            )
        spike_times.append(spikes)
else:
    for number_of_spikes in number_of_spikes_per_unit:
        spikes = numpy.sort(
            random_number_generator.integers(
                low=0, high=int(duration * sampling_frequency), size=number_of_spikes, dtype="uint64"
            )
        ) / sampling_frequency
        if refractory_period is not None:
            spikes = _clean_refractory_period(
                original_spike_times=spikes, refractory_period_seconds=refractory_period / 1e3
            )
        spike_times.append(spikes)
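
# Hypothetical follow-up check (not part of the snippet above): the empirical rate of
# each generated train should sit close to the requested `firing_rates` (10 Hz here).
empirical_rates = [len(spikes) / duration for spikes in spike_times]
print(numpy.mean(empirical_rates), numpy.std(empirical_rates))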

Collaborator

@zm711 zm711 left a comment

docstring cleanup :)

Resolved (outdated) review comments on src/spikeinterface/core/basesorting.py and src/spikeinterface/core/generate.py.
Comment on lines 360 to 361
- The function uses a geometric distribution to simulate the discrete inter-spike intervals,
based that would be an exponential process for continuous time.
Collaborator

This isn't quite clear. Maybe it is missing a word. I'm not quite sure how to fix it.

Collaborator Author

@h-mayorquin h-mayorquin Nov 18, 2023

You are correct, I think; this is not clear. I will expand on it. Thanks also for all the other comments; they all make sense and are very helpful, as usual.

Collaborator Author

I modified the docstring here; let me know what you think and whether it is still unclear to you. Also, any other advice you might have is welcome.

@h-mayorquin
Collaborator Author

h-mayorquin commented Nov 20, 2023

@CodyCBakerPhD
Most of the computation in the function happens here:
https://github.com/catalystneuro/spikeinterface/blob/ac41b95eb31fc76d21fd2e8f0917092d086b481e/src/spikeinterface/core/generate.py#L400-L405

That is the bottleneck: generating sorted, concatenated spike frames and their corresponding unit indices. This is roughly what the output should look like:

spike_frames
[  452 15511 37989 42417 62234 71248 74939]
unit indexes
[1 1 0 0 1 1 0]

The concatenated frames have to be sorted, but frames from different units can occur at the same time. So you cannot sort until you concatenate, yet the cumulative sum has to happen on the compact per-unit version. That makes it hard to avoid the large memory allocation or to sort smaller vectors.

Reading your code, it seems that you just concatenate sorted spikes, and I can't see the vector of unit indices. After adding the concatenation of all the spikes and then sorting them (even without the units part), the code is already slower than the old implementation under %%timeit in a notebook.
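
For context, a minimal sketch of the kind of operation described above (illustrative names, not the actual lines in generate.py): per-unit cumulative sums give frames that are sorted within each unit, but the concatenated vector still has to be sorted globally, together with a matching vector of unit indices.

import numpy as np

rng = np.random.default_rng(0)
num_units, spikes_per_unit = 3, 5

# Frames per unit from cumulative sums of integer ISIs (sorted within each unit)
frames_per_unit = [np.cumsum(rng.integers(1, 10_000, size=spikes_per_unit)) for _ in range(num_units)]

# The expensive part: concatenate everything and sort globally, carrying the unit labels along
spike_frames = np.concatenate(frames_per_unit)
unit_indices = np.repeat(np.arange(num_units), spikes_per_unit)
order = np.argsort(spike_frames, kind="stable")
spike_frames, unit_indices = spike_frames[order], unit_indices[order]

print(spike_frames)
print(unit_indices)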

@DradeAW
Contributor

DradeAW commented Nov 20, 2023

Maybe I am not understanding what you mean by negligible?

To me, the difference between 10,000; 10,200; and 10,400 is negligible.

I am not sure what you mean by bias, it is surely not in a formal sense

I meant that the resulting ISIs might be erroneous.
Let's take the (extreme) example where lambda is significantly smaller than the delta time between two samples.
Then you're only generating zeros. But if you cast after the cumulative sum instead, you still add the sub-sample values, and after multiple spikes the sum advances to the next sample.

@h-mayorquin
Collaborator Author

To me, the difference between 10,000; 10,200; and 10,400 is negligible.

Right, it was me who did not understand above. You are correct. Let's switch to 4 stds; I think it makes sense.

I meant that the resulting ISIs might be erroneous.
Let's take the (extreme) example where lambda is significantly smaller than the delta time between two samples.
Then you're only generating zeros. But if you cast after the cumulative sum instead, you still add the sub-sample values, and after multiple spikes the sum advances to the next sample.

I still don't understand this. Lambda is the firing rate, right?
What would be an erroneous ISI? Maybe we can test this. Are you saying that the frames generated by this function, transformed to times, will not have an exponential distribution in some sense?

Also, generating the frames as integers is something that the current implementation already does. Is that being avoided right now? Maybe this will help me understand.

@DradeAW
Contributor

DradeAW commented Nov 20, 2023

Let's take an example where the generation would give [4.4, 3.3, 6.3]. The cumulative sum is [4.4, 7.7, 14.0], which after casting is [4, 7 (or 8), 14] (depending on the rounding rule).

But if you first convert to integers, then you have [4, 7, 13], which is clearly different, and what I was referring to when I said bias.
Of course, if the firing rate is low, you won't notice any difference. But with a high firing rate and high refractory period, I think it can be noticeable.
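
A tiny sketch of the two orderings being compared (illustrative, not the PR code):

import numpy as np

isis = np.array([4.4, 3.3, 6.3])  # continuous inter-spike intervals, in samples

cast_after = np.cumsum(isis).astype(int)   # [ 4,  7, 14]: sub-sample parts accumulate
cast_before = np.cumsum(isis.astype(int))  # [ 4,  7, 13]: sub-sample parts are discarded
print(cast_after, cast_before)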

@h-mayorquin
Collaborator Author

h-mayorquin commented Nov 20, 2023

@DradeAW
But I am not sure that what I am doing is equivalent. I am generating the integers already, where a spike can be (or not) on each of the ticks of the sampling rate. I am not converting to integers. If anything, the problem would be the other way around: how well a binomial distribution approximates a Poisson, and those assumptions hold very well here.

But with a high firing rate and high refractory period, I think it can be noticeable.

Yeah, maybe try to run some experiments? Maybe you are right, but "noticeable" is not very actionable to me. Do you expect it to fail a goodness-of-fit test? Would the mean be far from the true mean? Would it be more or less skewed than the exponential? I think that would be the easiest thing to do: try the current implementation and this one in different scenarios, and let's see if it does worse in any of them. That would be very useful.

@DradeAW
Contributor

DradeAW commented Nov 20, 2023

I am generating the integers already, where a spike can be (or not) on each of the ticks of the sampling rate

I'm sorry, I misunderstood!
Then yes, it is fine :)

@h-mayorquin
Collaborator Author

Ah, OK. Still, it would be useful to know if this breaks in some limit. Check it out:

[image]

Blue is empirical.

@h-mayorquin
Collaborator Author

Your math seems to be failing you

Here is the mathematical proof then:

The mean of an exponential distribution is E(X) = 1 / lambda = beta, so E(X + refractory_period) = E(X) + refractory_period.

Thus, E(X) + refractory_period = 1 / firing_rate <=> E(X) = 1 / firing_rate - refractory_period. QED

I did not mean that; I meant your expectations about how the exponential distribution should look. But now that I understand that you are privileging the count divided by long-term time, that makes sense.

@DradeAW
Contributor

DradeAW commented Nov 20, 2023

I really don't understand what you are doing,

We just want a unit with a given mean firing rate and refractory period, and your method just outputs a wrong firing rate.

@DradeAW
Contributor

DradeAW commented Nov 20, 2023

you are privileging the count divided by long-term time that makes sense

Yes! It is the definition of the mean firing rate!!

@h-mayorquin
Collaborator Author

I am implementing the Poisson statistics as described in books like:
https://mitpress.mit.edu/9780262041997/theoretical-neuroscience/#:~:text=Larry%20Abbott%20is%20Professor%20of,Unit%20at%20University%20College%20London.

Check out page 31:
[image: excerpt from page 31]

The first references on Google indicate a similar implementation.

The thing that you are proposing I have never seen. Maybe it is correct and it makes a lot of sense to you, but this is not how it is done in the books I have read. Can you show me some reference or another library implementation showing that your proposed convention is commonplace:

  • Increase the firing rate outside of the refractory period so as to keep the number of spikes constant over long intervals.

@DradeAW
Contributor

DradeAW commented Nov 20, 2023

I've not read this anywhere, this is what I do by simply manipulating statistics :)

It may not be the correct answer in some cases, but in the case of firing_rate = 99.9 Hz and refractory_period = 10 ms, this version simply fails.
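
To put rough numbers on this example (a back-of-the-envelope check, not from the thread): at a requested 99.9 Hz the mean exponential ISI is 1/99.9 ≈ 10.01 ms. Without rescaling, adding the 10 ms dead time gives a mean ISI of about 20.01 ms, i.e. an output rate of roughly 50 Hz instead of 99.9 Hz; with rescaling, the exponential part is left with a mean of only about 0.01 ms.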

@h-mayorquin
Collaborator Author

I've not read this anywhere, this is what I do by simply manipulating statistics :)

x D
Lol

I will ask Alessio and Sam what they prefer for this case the next time we meet, as I don't think they will dare to read this super long thread that you and I created.

@alejoe91
Member

I have been passively enjoying the passionate exchange ;)

@JoeZiminski
Collaborator

JoeZiminski commented Nov 20, 2023

Interesting! Could I confirm that the only outstanding point of discussion is whether to rescale $\lambda$ such that the passed firing rate to the function is the firing rate of the final spike train? Alternatively, the passed firing rate is the firing rate $\lambda$ used for the Poisson process prior to adding refractory effects. Are there any other differences between the implementations? (let me know if I have misunderstood)

On this, I am not familiar with the conventions in the field; some sources seem to do this rescaling [1, 2], whereas other sources [3, Dayan & Abbott] do not mention it.

@h-mayorquin
Collaborator Author

h-mayorquin commented Nov 20, 2023

Could I confirm that the only outstanding point of discussion is whether to rescale $\lambda$ such that the passed firing rate to the function is the firing rate of the final spike train? Alternatively, the passed firing rate is the firing rate $\lambda$ used for the Poisson process prior to adding refractory effects.

@JoeZiminski this is a good summary.

Thanks for the references. So Bartoz does indeed modify the rate, with no constraints. Andrew also modifies the rate, but excludes the case of the refractory period being larger than the inter-spike interval. Plus, Andrew provides a reference to Stefan's paper, where they call this a Poisson process with dead time to make clear that this is not a Poisson process anymore:

https://link.springer.com/article/10.1007/s10827-011-0362-8

This answers my concerns. If we decide to go for the version in the aforementioned paper, I am fine with it. Thanks, @DradeAW, I learned something today. I had never thought carefully before about when the assumptions of the Poisson renewal process that I learned from Abbott's treatment break down. I still think that the bursting model they have over there makes more sense, but this seems to be just another equally valid convention in the field. It also has the advantage that it is easier to implement, now that I think about it.
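
For illustration, a minimal sketch of the rescaled "Poisson process with dead time" convention under discussion (a hedged sketch with illustrative names, not the merged implementation):

import numpy as np

def dead_time_poisson_spike_times(firing_rate_hz, refractory_period_ms, duration_s, seed=0):
    # Exponential ISIs rescaled so the long-run rate matches firing_rate_hz, then shifted by the dead time
    rng = np.random.default_rng(seed)
    dead_time_s = refractory_period_ms / 1000.0
    mean_isi_s = 1.0 / firing_rate_hz
    assert mean_isi_s > dead_time_s, "firing rate too high for this refractory period"
    # Draw more ISIs than the expected count, then clip to the duration
    n = int(firing_rate_hz * duration_s * 1.5) + 10
    isis = dead_time_s + rng.exponential(scale=mean_isi_s - dead_time_s, size=n)
    spike_times = np.cumsum(isis)
    return spike_times[spike_times < duration_s]

spikes = dead_time_poisson_spike_times(firing_rate_hz=20.0, refractory_period_ms=4.0, duration_s=100.0)
print(spikes.size / 100.0)  # long-run rate should be close to the requested 20 Hz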

@DradeAW
Contributor

DradeAW commented Nov 20, 2023

only outstanding point of discussion is whether to rescale such that the passed firing rate to the function is the firing rate of the final spike train?

Maybe this whole conversation is just my bad for not understanding this 😅
But as a user, if I ask for a spike train with a given firing rate, I expect the output to be that firing rate?

Sorry if I seemed a bit harsh, but criticizing my math when I spent many hours / days on the question tends to trigger me ahah

@h-mayorquin
Collaborator Author

@DradeAW

But as a user, if I ask for a spike train with a given firing rate, I expect the output to be that firing rate?

I think this is a very strong point. In fact, I feel convinced by it.

The red herring was that we started discussing the statistics. You came to a PR that is called "add poisson statistics" and then told me that the ISIs should not follow the theoretically expected distribution for a Poisson process. It turns out that you have a good reason for it not to be Poisson, but I instead doubled down on "this is a Poisson distribution, why do you want me to make my statistics not Poissonian, you must be getting something wrong!". At the end I think we came to the right crux, and we should take pride in that:

I asked you:

Your proposal keeps the total spike count divided by time equal to the firing rate by sacrificing the instantaneous firing rate. Mine does the opposite and preserves the instantaneous firing rate as given, at the expense of the count. I care about the statistics being changed; why should I change my statistics to keep the count over longer averages like that?

And I think that you have a great answer for my question there which is:

But as a user, if I ask for a spike train with a given firing rate, I expect the output to be that firing rate?

I do believe that user-centered decisions trump mathematical soundness, so that argument convinces me.

And looking at it in retrospect, the discussion went to the distribution because I failed to read the following. You said:

This looks wrong, can you check that what you get is at 99.99 Hz?

I actually did not understand that you were asking me to measure the firing rate there (probably because of the preposition "at" and the omission of the words "firing rate"; writing is hard). And then I started discussing the pdf of the distribution.

But overall the discussion was good for me. I learned something new, and I think making this a Poisson distribution with dead time (as Stefan and co. call it) is better for maintenance and usability. So thanks for bearing with me; I think the library is better thanks to your efforts.

@DradeAW
Contributor

DradeAW commented Nov 22, 2023

Alright, everything turned out well ^^

Sorry I know I'm not always super great at expressing stuff (even in my native language 😅), but we both learned something!

Although I'm still curious about something: is your implementation still Poissonian?
I thought (but I might be wrong about this) that a Poisson distribution with a refractory period was not a true Poisson distribution?
Genuinely asking!

@h-mayorquin
Collaborator Author

Not strictly: the support is changed (by the shifting), but the ISI is still an exponential, and the mean and std ratio should remain constant on the new support.

That said, I am not sure that is not the case for the Poisson process with dead time. I tried to fit the exponential with a shift and that did not work, but maybe the parameters need to be changed in some other way.

I pushed a version of your suggestion now, but the tests are failing. It seems those are the new metrics. I will take a look later.

@h-mayorquin
Collaborator Author

And these are some related plots that I made after our discussion.

Discretization

This should answer @DradeAW's question about how the discretization behaves in the extreme case:
[image]

The bins accumulate at 1/sampling_rate because that is the smallest time step representable by the sorting.

How much does the method actually modify the firing rates

For myself, I was curious about how much the instantaneous firing rate is modified depending on the refractory period:
[image]
(note: the color map is cut at 10, but the values go much higher)

[image]

[image]

So, yes, the lesson is: quite a lot. For high firing rates, a small refractory period can easily double the firing rate.

# We estimate how many spikes we will have in the duration
max_frames = duration * sampling_frequency
max_binomial_p = float(np.max(binomial_p))
num_spikes_expected = ceil(max_frames * max_binomial_p)
Collaborator

Suggested change
num_spikes_expected = ceil(max_frames * max_binomial_p)
num_spikes_expected = int(np.ceil(max_frames * max_binomial_p))

Any interest in this instead and then you can remove the ceil import from math? Or did you really only want to use math.ceil?

Collaborator Author

What's the advantage of this?

Last time I checked, the math module functions are faster for scalars than the numpy functions, as they avoid the overhead. Speed won't matter that much at this scale, though.
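
For what it's worth, a throwaway micro-benchmark sketch of the scalar overhead (assuming timeit's default of one million calls per statement):

import timeit

math_time = timeit.timeit("ceil(3.7)", setup="from math import ceil")
numpy_time = timeit.timeit("np.ceil(3.7)", setup="import numpy as np")
print(f"math.ceil : {math_time:.3f} s per 1e6 calls")
print(f"numpy.ceil: {numpy_time:.3f} s per 1e6 calls")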

Collaborator

Honestly, the only advantage for this scalar is that you import one less function into code where it is used only once. But reducing imports is not necessarily a good reason, so my comment was more a question than a hard recommendation.

Contributor

Last time I checked, the math module functions are faster for scalars than the numpy functions

Last time I checked, even math.pi was faster than np.pi, which I still don't understand ahah
I agree, math for scalars is better

Collaborator Author

@zm711 I see. Yes, importing from the standard library at will is my prior until proven otherwise.

Run the following script:

import pkgutil
import timeit
import sys 

# Get a list of all standard library modules
standard_lib_modules = [module for module in pkgutil.iter_modules() if module.name in sys.stdlib_module_names]

# Dictionary to store import times
import_times = {}

for module in standard_lib_modules:
    # Measure the import time
    time = timeit.timeit(f"import {module.name}", number=1)
    import_times[module.name] = time

# Print or process the import times
for module, time in import_times.items():
    print(f"{module}: {time} seconds")

You will see that importing from the standard library is at the scale of a main memory reference:
https://brenocon.com/dean_perf.html

Collaborator

Thanks @h-mayorquin! Makes sense.

@alejoe91 alejoe91 added this to the 0.100.0 milestone Jan 9, 2024
@alejoe91 alejoe91 added the hybrid (Related to Hybrid testing) label Jan 19, 2024
@samuelgarcia
Member

Hi Ramon and Aurelien.
Thanks a lot for this PR and this discussion.
This is really an excellent discussion and piece of work.

I have to admit (with a lot of shame) that I did not have time to read it before today.
So let's merge this now.

@alejoe91 alejoe91 merged commit 581d8d1 into SpikeInterface:main Jan 22, 2024
11 checks passed
@alejoe91 alejoe91 deleted the improve_generate_sorting branch January 22, 2024 14:31