Add Poisson statistics to generate_sorting and optimize memory profile #2226

Merged · 23 commits · Jan 22, 2024
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -266,7 +266,7 @@ def get_unit_property(self, unit_id, key):

     def get_total_num_spikes(self):
         warnings.warn(
-            "Sorting.get_total_num_spikes() is deprecated, se sorting.count_num_spikes_per_unit()",
+            "Sorting.get_total_num_spikes() is deprecated and will be removed in spikeinterface 0.102, use sorting.count_num_spikes_per_unit()",
             DeprecationWarning,
             stacklevel=2,
         )
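
For downstream code, a minimal migration sketch (assuming an existing sorting object; the dict-style return of count_num_spikes_per_unit reflects current behavior):

    # Before (now emits a DeprecationWarning):
    # total = sorting.get_total_num_spikes()

    # After: per-unit counts, summed if a total is needed
    counts_per_unit = sorting.count_num_spikes_per_unit()
    total = sum(counts_per_unit.values())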
125 changes: 113 additions & 12 deletions src/spikeinterface/core/generate.py
@@ -3,6 +3,7 @@
 import numpy as np
 from typing import Union, Optional, List, Literal
 import warnings
+from math import ceil


 from .numpyextractors import NumpyRecording, NumpySorting
@@ -135,7 +136,7 @@ def generate_sorting(
     spikes = []
     for segment_index in range(num_segments):
         num_samples = int(sampling_frequency * durations[segment_index])
-        times, labels = synthesize_random_firings(
+        samples, labels = synthesize_poisson_spike_vector(
             num_units=num_units,
             sampling_frequency=sampling_frequency,
             duration=durations[segment_index],
@@ -146,11 +147,11 @@

         if empty_units is not None:
             keep = ~np.isin(labels, empty_units)
-            times = times[keep]
+            samples = samples[keep]
             labels = labels[keep]

-        spikes_in_seg = np.zeros(times.size, dtype=minimum_spike_dtype)
-        spikes_in_seg["sample_index"] = times
+        spikes_in_seg = np.zeros(samples.size, dtype=minimum_spike_dtype)
+        spikes_in_seg["sample_index"] = samples
         spikes_in_seg["unit_index"] = labels
         spikes_in_seg["segment_index"] = segment_index
         spikes.append(spikes_in_seg)
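
The public entry point keeps its signature; only the spike-train backend changes. A minimal usage sketch (assuming generate_sorting is importable from spikeinterface.core with the parameters shown above; defaults may differ by version):

    from spikeinterface.core import generate_sorting

    # Spike trains are now drawn by the Poisson generator under the hood
    sorting = generate_sorting(
        num_units=10,
        sampling_frequency=30000.0,
        durations=[60.0],
        firing_rates=5.0,
        seed=0,
    )
    print(sorting.count_num_spikes_per_unit())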
@@ -284,6 +285,114 @@ def generate_snippets(


## spiketrain zone ##
def synthesize_poisson_spike_vector(
    num_units=20,
    sampling_frequency=30000.0,
    duration=60.0,
    refractory_period_ms=4.0,
    firing_rates=3.0,
    seed=0,
):
    """
    Generate random spike frames for neuronal units using a Poisson process.

    This function simulates the spike activity of multiple neuronal units. Each unit's spiking behavior
    is modeled as a Poisson process, with spike times discretized according to the specified sampling frequency.
    The function accounts for refractory periods in spike generation, and allows specifying either a uniform
    firing rate for all units or distinct firing rates for each unit.

    Parameters
    ----------
    num_units : int, default: 20
        Number of neuronal units to simulate
    sampling_frequency : float, default: 30000.0
        Sampling frequency in Hz
    duration : float, default: 60.0
        Duration of the simulation in seconds
    refractory_period_ms : float, default: 4.0
        Refractory period between spikes in milliseconds
    firing_rates : float or array_like, default: 3.0
        Firing rate(s) in Hz. Can be a single value for all units or an array of firing rates with
        each element being the firing rate for one unit
    seed : int, default: 0
        Seed for random number generator

    Returns
    -------
    spike_frames : ndarray
        1D array of spike frames.
    unit_indices : ndarray
        1D array of unit indices corresponding to each spike.

    Notes
    -----
    - The inter-spike intervals are simulated using a geometric distribution, representing the discrete
      counterpart to the exponential distribution of intervals in a continuous-time Poisson process.
    - The refractory period is enforced by adding a fixed number of frames to each neuron's inter-spike
      interval, ensuring no two spikes occur within this period for any single neuron.
    - The effective firing rate is adjusted upwards to compensate for the refractory period, following the
      model in [1]. This adjustment ensures the overall firing rate remains consistent with the specified
      firing_rates, despite the enforced refractory period.

    References
    ----------
    [1] Deger, M., Helias, M., Boucsein, C., & Rotter, S. (2012). Statistical properties of superimposed
        stationary spike trains. Journal of Computational Neuroscience, 32(3), 443–463.
    """

    rng = np.random.default_rng(seed=seed)

    if np.isscalar(firing_rates):
        firing_rates = np.full(num_units, firing_rates, dtype="float64")

    # Calculate the number of frames in the refractory period
    refractory_period_seconds = refractory_period_ms / 1000.0
    refractory_period_frames = int(refractory_period_seconds * sampling_frequency)

    is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
    if is_refractory_period_too_long:
        raise ValueError(
            f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}"
        )

    # p is the probability of a spike per tick of the sampling frequency
    binomial_p = firing_rates / sampling_frequency
    # We estimate how many spikes we will have in the duration
    max_frames = duration * sampling_frequency
    max_binomial_p = float(np.max(binomial_p))
    num_spikes_expected = ceil(max_frames * max_binomial_p)

[Review thread on the line above]

Collaborator:

Suggested change:
-    num_spikes_expected = ceil(max_frames * max_binomial_p)
+    num_spikes_expected = int(np.ceil(max_frames * max_binomial_p))

Any interest in this instead, so that you can remove the ceil import from math? Or did you really only want to use math.ceil?

Collaborator Author:

What's the advantage of this?

Last time I checked, the math module functions are faster for scalars than the numpy functions, as they avoid the overhead. Speed won't matter that much at this scale, though.

Collaborator:

Honestly, the only advantage for this scalar is that you import one less function into code where it is used only once. But reducing imports is not necessarily a good reason, so my comment was more a question than a hard recommendation.

Contributor:

> Last time I checked, the math module functions are faster for scalars than the numpy functions

Last time I checked, even math.pi was faster than np.pi, which I still don't understand, haha. I agree, math for scalars is better.

Collaborator Author:

@zm711 I see. Yes, importing from the standard library at will is my prior until proven otherwise.

Run the following script:

    import pkgutil
    import timeit
    import sys

    # Get a list of all standard library modules
    standard_lib_modules = [module for module in pkgutil.iter_modules() if module.name in sys.stdlib_module_names]

    # Dictionary to store import times
    import_times = {}

    for module in standard_lib_modules:
        # Measure the import time
        time = timeit.timeit(f"import {module.name}", number=1)
        import_times[module.name] = time

    # Print or process the import times
    for module, time in import_times.items():
        print(f"{module}: {time} seconds")

You will see that importing from the standard library is on the scale of a main memory reference:
https://brenocon.com/dean_perf.html

Collaborator:

Thanks @h-mayorquin! Makes sense.

[End of review thread]
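
The scalar-speed claim is easy to spot-check. A minimal, illustrative timeit comparison (not from the PR; absolute timings vary by machine and NumPy version):

    import timeit

    # math.ceil takes the fast scalar path; np.ceil sends a Python float through ufunc dispatch
    t_math = timeit.timeit("ceil(12345.67)", setup="from math import ceil", number=1_000_000)
    t_np = timeit.timeit("np.ceil(12345.67)", setup="import numpy as np", number=1_000_000)
    print(f"math.ceil: {t_math:.3f} s per 1e6 calls")
    print(f"np.ceil:   {t_np:.3f} s per 1e6 calls")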

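    # Buffer sizing below: mean + 4 standard deviations of the binomial spike count, so the
    # preallocated per-unit buffer falls short only with vanishingly small probability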
    num_spikes_std = int(np.sqrt(num_spikes_expected * (1 - max_binomial_p)))
    num_spikes_max = num_spikes_expected + 4 * num_spikes_std

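    # With an absolute refractory period tau (in seconds), a unit can only fire during a fraction
    # (1 - r * tau) of the time, so the free firing rate must be r / (1 - r * tau) for the observed
    # rate to come out at the requested r (Deger et al., 2012, cited in the docstring)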
    # Increase the firing rate to take into account the refractory period
    modified_firing_rate = firing_rates / (1 - firing_rates * refractory_period_seconds)
    binomial_p_modified = modified_firing_rate / sampling_frequency
    binomial_p_modified = np.minimum(binomial_p_modified, 1.0)

    # Generate inter-spike frames, add the refractory samples, and accumulate for sorted spike frames
    inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max))
    inter_spike_frames[:, 1:] += refractory_period_frames
    spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames)
    spike_frames = spike_frames.ravel()

    # We map the corresponding unit indices
    unit_indices = np.repeat(np.arange(num_units, dtype="uint16"), num_spikes_max)

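    # The masked survivors are written back into the front of the same buffer (and sliced below)
    # rather than into a fresh allocation, keeping the peak memory footprint down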
    # Eliminate spikes that are beyond the duration
    mask = spike_frames <= max_frames
    num_correct_frames = np.sum(mask)
    spike_frames[:num_correct_frames] = spike_frames[mask]  # Avoids a malloc
    unit_indices = unit_indices[mask]

    # Sort globally
    spike_frames = spike_frames[:num_correct_frames]
    sort_indices = np.argsort(spike_frames, kind="stable")  # I profiled the different kinds, this is the fastest.

    unit_indices = unit_indices[sort_indices]
    spike_frames = spike_frames[sort_indices]

    return spike_frames, unit_indices
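
A quick empirical check of the docstring's rate claim; a minimal sketch, assuming the function is imported from spikeinterface.core.generate, where this diff adds it:

    import numpy as np
    from spikeinterface.core.generate import synthesize_poisson_spike_vector

    duration = 60.0
    spike_frames, unit_indices = synthesize_poisson_spike_vector(
        num_units=4,
        sampling_frequency=30000.0,
        duration=duration,
        refractory_period_ms=4.0,
        firing_rates=10.0,
        seed=42,
    )

    # Output is globally sorted, ready to fill the minimum_spike_dtype fields
    assert np.all(np.diff(spike_frames) >= 0)

    # Empirical rate per unit should land close to the requested 10 Hz
    for unit in range(4):
        print(f"unit {unit}: {np.sum(unit_indices == unit) / duration:.2f} Hz")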


@@ -327,14 +436,6 @@ def synthesize_random_firings(

     rng = np.random.default_rng(seed=seed)

-    # unit_seeds = [rng.integers(0, 2 ** 63) for i in range(num_units)]
-
-    # if seed is not None:
-    #     np.random.seed(seed)
-    #     seeds = np.random.RandomState(seed=seed).randint(0, 2147483647, num_units)
-    # else:
-    #     seeds = np.random.randint(0, 2147483647, num_units)
-
     if np.isscalar(firing_rates):
         firing_rates = np.full(num_units, firing_rates, dtype="float64")

@@ -395,7 +395,7 @@ def test_calculate_sd_ratio(waveform_extractor_simple):
     )

     assert np.all(list(sd_ratio.keys()) == waveform_extractor_simple.unit_ids)
-    # assert np.allclose(list(sd_ratio.values()), 1, atol=0.2, rtol=0)
+    assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0)


if __name__ == "__main__":