Skip to content

Commit

Permalink
Merge pull request #2226 from catalystneuro/improve_generate_sorting
Browse files Browse the repository at this point in the history
Add Poisson statistics to `generate_sorting` and optimize memory profile
  • Loading branch information
alejoe91 authored Jan 22, 2024
2 parents 51a9a7e + db1568b commit 581d8d1
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def get_unit_property(self, unit_id, key):

def get_total_num_spikes(self):
warnings.warn(
"Sorting.get_total_num_spikes() is deprecated, se sorting.count_num_spikes_per_unit()",
"Sorting.get_total_num_spikes() is deprecated and will be removed in spikeinterface 0.102, use sorting.count_num_spikes_per_unit()",
DeprecationWarning,
stacklevel=2,
)
Expand Down
125 changes: 113 additions & 12 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from typing import Union, Optional, List, Literal
import warnings
from math import ceil

from .basesorting import SpikeVectorSortingSegment
from .numpyextractors import NumpySorting
Expand Down Expand Up @@ -135,7 +136,7 @@ def generate_sorting(
spikes = []
for segment_index in range(num_segments):
num_samples = int(sampling_frequency * durations[segment_index])
times, labels = synthesize_random_firings(
samples, labels = synthesize_poisson_spike_vector(
num_units=num_units,
sampling_frequency=sampling_frequency,
duration=durations[segment_index],
Expand All @@ -146,11 +147,11 @@ def generate_sorting(

if empty_units is not None:
keep = ~np.isin(labels, empty_units)
times = times[keep]
samples = samples[keep]
labels = labels[keep]

spikes_in_seg = np.zeros(times.size, dtype=minimum_spike_dtype)
spikes_in_seg["sample_index"] = times
spikes_in_seg = np.zeros(samples.size, dtype=minimum_spike_dtype)
spikes_in_seg["sample_index"] = samples
spikes_in_seg["unit_index"] = labels
spikes_in_seg["segment_index"] = segment_index
spikes.append(spikes_in_seg)
Expand Down Expand Up @@ -606,6 +607,114 @@ def generate_snippets(


## spiketrain zone ##
def synthesize_poisson_spike_vector(
    num_units=20,
    sampling_frequency=30000.0,
    duration=60.0,
    refractory_period_ms=4.0,
    firing_rates=3.0,
    seed=0,
):
    """
    Generate random spike frames for neuronal units using a Poisson process.

    This function simulates the spike activity of multiple neuronal units. Each unit's spiking behavior
    is modeled as a Poisson process, with spike times discretized according to the specified sampling frequency.
    The function accounts for refractory periods in spike generation, and allows specifying either a uniform
    firing rate for all units or distinct firing rates for each unit.

    Parameters
    ----------
    num_units : int, default: 20
        Number of neuronal units to simulate
    sampling_frequency : float, default: 30000.0
        Sampling frequency in Hz
    duration : float, default: 60.0
        Duration of the simulation in seconds
    refractory_period_ms : float, default: 4.0
        Refractory period between spikes in milliseconds
    firing_rates : float or array_like, default: 3.0
        Firing rate(s) in Hz. Can be a single value for all units or an array (or list) of firing rates
        with each element being the firing rate for one unit
    seed : int, default: 0
        Seed for random number generator

    Returns
    -------
    spike_frames : ndarray
        1D array of spike frames, sorted in increasing order. Every frame is strictly smaller than
        ``int(duration * sampling_frequency)``, i.e. it is a valid sample index of the simulated segment.
    unit_indices : ndarray
        1D array (dtype uint16) of unit indices corresponding to each spike.

    Raises
    ------
    ValueError
        If the refractory period is longer than or equal to the mean inter-spike interval
        (``1 / firing_rate``) of any unit.

    Notes
    -----
    - The inter-spike intervals are simulated using a geometric distribution, representing the discrete
      counterpart to the exponential distribution of intervals in a continuous-time Poisson process.
    - The refractory period is enforced by adding a fixed number of frames to each neuron's inter-spike interval,
      ensuring no two spikes occur within this period for any single neuron.
    - The effective firing rate is adjusted upwards to compensate for the refractory period, following the model
      in [1]. This adjustment ensures the overall firing rate remains consistent with the specified
      `firing_rates`, despite the enforced refractory period.

    References
    ----------
    [1] Deger, M., Helias, M., Boucsein, C., & Rotter, S. (2012). Statistical properties of superimposed stationary
        spike trains. Journal of Computational Neuroscience, 32(3), 443–463.
    """

    rng = np.random.default_rng(seed=seed)

    # Accept scalars, 0-d arrays, lists and ndarrays alike; a bare list would
    # otherwise fail on the element-wise arithmetic below.
    firing_rates = np.asarray(firing_rates, dtype="float64")
    if firing_rates.ndim == 0:
        firing_rates = np.full(num_units, float(firing_rates), dtype="float64")

    # Calculate the number of frames in the refractory period
    refractory_period_seconds = refractory_period_ms / 1000.0
    refractory_period_frames = int(refractory_period_seconds * sampling_frequency)

    is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
    if is_refractory_period_too_long:
        raise ValueError(
            f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}"
        )

    # p is the probability of a spike per tick of the sampling frequency
    binomial_p = firing_rates / sampling_frequency
    # We estimate how many spikes we will have in the duration (mean + 4 std of the
    # fastest unit) so that a single rectangular draw covers every unit.
    max_frames = duration * sampling_frequency
    max_binomial_p = float(np.max(binomial_p))
    num_spikes_expected = ceil(max_frames * max_binomial_p)
    num_spikes_std = int(np.sqrt(num_spikes_expected * (1 - max_binomial_p)))
    num_spikes_max = num_spikes_expected + 4 * num_spikes_std

    # Increase the firing rate to take into account the refractory period
    modified_firing_rate = firing_rates / (1 - firing_rates * refractory_period_seconds)
    binomial_p_modified = modified_firing_rate / sampling_frequency
    binomial_p_modified = np.minimum(binomial_p_modified, 1.0)

    # Generate inter spike frames, add the refractory samples and accumulate for sorted spike frames.
    # The cumulative sum is written in place to avoid a second (num_units, num_spikes_max) buffer.
    inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max))
    inter_spike_frames[:, 1:] += refractory_period_frames
    spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames)
    spike_frames = spike_frames.ravel()

    # We map the corresponding unit indices
    unit_indices = np.repeat(np.arange(num_units, dtype="uint16"), num_spikes_max)

    # Eliminate spikes that are beyond the duration. Valid sample indices are
    # 0 .. num_frames - 1, so the comparison must be strict: a spike at index
    # num_frames would point one sample past the end of the segment.
    num_frames = int(max_frames)
    mask = spike_frames < num_frames
    spike_frames = spike_frames[mask]
    unit_indices = unit_indices[mask]

    # Sort globally. A stable sort keeps simultaneous spikes ordered by unit index.
    sort_indices = np.argsort(spike_frames, kind="stable")

    unit_indices = unit_indices[sort_indices]
    spike_frames = spike_frames[sort_indices]

    return spike_frames, unit_indices


def synthesize_random_firings(
Expand Down Expand Up @@ -649,14 +758,6 @@ def synthesize_random_firings(

rng = np.random.default_rng(seed=seed)

# unit_seeds = [rng.integers(0, 2 ** 63) for i in range(num_units)]

# if seed is not None:
# np.random.seed(seed)
# seeds = np.random.RandomState(seed=seed).randint(0, 2147483647, num_units)
# else:
# seeds = np.random.randint(0, 2147483647, num_units)

if np.isscalar(firing_rates):
firing_rates = np.full(num_units, firing_rates, dtype="float64")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_calculate_sd_ratio(waveform_extractor_simple):
)

assert np.all(list(sd_ratio.keys()) == waveform_extractor_simple.unit_ids)
# assert np.allclose(list(sd_ratio.values()), 1, atol=0.2, rtol=0)
assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0)


if __name__ == "__main__":
Expand Down

0 comments on commit 581d8d1

Please sign in to comment.