Add Poisson statistics to generate_sorting and optimize memory profile #2226

Merged 23 commits on Jan 22, 2024
Changes from 10 commits
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -263,7 +263,7 @@ def get_unit_property(self, unit_id, key):

def get_total_num_spikes(self):
warnings.warn(
"Sorting.get_total_num_spikes() is deprecated, se sorting.count_num_spikes_per_unit()",
"Sorting.get_total_num_spikes() is deprecated and will be removed in spikeinterface 0.102, se sorting.count_num_spikes_per_unit()",
DeprecationWarning,
stacklevel=2,
)
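For reference, a minimal sketch of the replacement call named in the updated warning (the generate_sorting arguments are illustrative; count_num_spikes_per_unit is the method the message points to):

from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=5, durations=[10.0], seed=0)
counts = sorting.count_num_spikes_per_unit()  # replaces get_total_num_spikes()
print(counts)  # mapping of unit_id -> spike count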
102 changes: 90 additions & 12 deletions src/spikeinterface/core/generate.py
@@ -3,6 +3,7 @@
import numpy as np
from typing import Union, Optional, List, Literal
import warnings
from math import ceil


from .numpyextractors import NumpyRecording, NumpySorting
@@ -168,7 +169,7 @@ def generate_sorting(
spikes = []
for segment_index in range(num_segments):
num_samples = int(sampling_frequency * durations[segment_index])
times, labels = synthesize_random_firings(
samples, labels = synthesize_poisson_spike_vector(
num_units=num_units,
sampling_frequency=sampling_frequency,
duration=durations[segment_index],
@@ -179,11 +180,11 @@

if empty_units is not None:
keep = ~np.isin(labels, empty_units)
times = times[keep]
samples = samples[keep]
labels = labels[keep]

spikes_in_seg = np.zeros(times.size, dtype=minimum_spike_dtype)
spikes_in_seg["sample_index"] = times
spikes_in_seg = np.zeros(samples.size, dtype=minimum_spike_dtype)
spikes_in_seg["sample_index"] = samples
spikes_in_seg["unit_index"] = labels
spikes_in_seg["segment_index"] = segment_index
spikes.append(spikes_in_seg)
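For context, a self-contained sketch of the per-segment assembly above. The field names follow the diff; the exact integer widths of minimum_spike_dtype are an assumption here:

import numpy as np

minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]

samples = np.array([10, 55, 120])  # spike frames within one segment
labels = np.array([0, 2, 1])       # unit index of each spike

keep = ~np.isin(labels, [2])       # drop spikes from empty_units, e.g. unit 2
samples, labels = samples[keep], labels[keep]

spikes_in_seg = np.zeros(samples.size, dtype=minimum_spike_dtype)
spikes_in_seg["sample_index"] = samples
spikes_in_seg["unit_index"] = labels
spikes_in_seg["segment_index"] = 0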
@@ -317,6 +318,91 @@ def generate_snippets(


## spiketrain zone ##
def synthesize_poisson_spike_vector(
num_units=20,
sampling_frequency=30000.0,
duration=60,
refractory_period_ms=4.0,
firing_rates=3.0,
seed=0,
):
"""
Generate random spike frames for neuronal units based on a Poisson process.

This function simulates the spike frames of a number of neuronal units, each firing according to
a Poisson process. The spike times are discretized to align with a given sampling frequency.

Parameters
----------
num_units : int, optional
Number of neuronal units to simulate (default is 20).
sampling_frequency : float, optional
Sampling frequency in Hz (default is 30000.0).
duration : float, optional
Duration of the simulation in seconds (default is 60).
refractory_period_ms : float, optional
Refractory period between spikes in milliseconds (default is 4.0).
firing_rates : float or array_like, optional
Firing rate(s) in Hz. Can be a single value or an array of firing rates for each unit
(default is 3.0).
seed : int, optional
Seed for random number generator (default is 0).

Returns
-------
spike_frames : ndarray
1D array of spike frames.
unit_indices : ndarray
1D array of unit indices corresponding to each spike.

Notes
-----
- The function uses a geometric distribution to simulate the discrete inter-spike intervals;
this is the discrete-time analogue of the exponential distribution that governs inter-spike
intervals in a continuous-time Poisson process.
[Review comment]
Collaborator: This isn't quite clear. Maybe it is missing a word. I'm not quite sure how to fix it.

h-mayorquin (Collaborator, Author) on Nov 18, 2023: You are correct, I think; this is not clear. I will expand on this. Thanks also for all the other comments: they all make sense and are very helpful, as usual.

h-mayorquin (Collaborator, Author): I modified the docstring here; let me know what you think and whether it is still unclear to you. Any other advice you might have is also useful.
"""

rng = np.random.default_rng(seed=seed)

if np.isscalar(firing_rates):
firing_rates = np.full(num_units, firing_rates, dtype="float64")

refractory_period_seconds = refractory_period_ms / 1000.0
refractory_period_frames = int(refractory_period_seconds * sampling_frequency)

# Equivalence between the exponential (continuous) and geometric (discrete) distributions, matched on the mean
geometric_p = 1 - np.exp(-firing_rates / sampling_frequency)
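# (If T ~ Exponential(rate) in seconds, then ceil(T * sampling_frequency) ~ Geometric(p) in samples
# with p = 1 - exp(-rate / sampling_frequency), hence the expression above.)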

# We generate as many spikes as the expected count plus two standard deviations, to be sure we have enough
max_frames = duration * sampling_frequency
max_p = geometric_p.max()
num_spikes_expected = ceil(max_frames * max_p)
num_spikes_std = int(np.sqrt(num_spikes_expected * (1 - max_p)))
num_spikes_max = num_spikes_expected + 2 * num_spikes_std
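# (The spike count over max_frames samples is Binomial(max_frames, max_p), with mean max_frames * max_p
# and std sqrt(mean * (1 - max_p)); mean + 2 std covers roughly 97.7% of realizations.)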

# Generate inter-spike frames, add the refractory period, and accumulate into sorted spike frames
inter_spike_frames = rng.geometric(p=geometric_p[:, np.newaxis], size=(num_units, num_spikes_max))
inter_spike_frames[:, 1:] += refractory_period_frames
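# (Adding the refractory period after sampling lengthens every interval, so the realized rate is
# slightly below the requested one: the mean ISI becomes 1 / p + refractory_period_frames samples.)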

spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames)
spike_frames = spike_frames.ravel()

# Map each spike in the flattened array to its corresponding unit index
unit_indices = np.repeat(np.arange(num_units, dtype="uint16"), num_spikes_max)

# Eliminate spikes that are beyond the duration
mask = spike_frames <= max_frames
num_correct_frames = np.sum(mask)
spike_frames[:num_correct_frames] = spike_frames[mask] # Avoids a malloc
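# (The fancy-indexed spike_frames[mask] still allocates a temporary, but writing it into the head
# of the existing array avoids keeping a second full-size copy alive.)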
unit_indices = unit_indices[mask]

spike_frames = spike_frames[:num_correct_frames]
# Stable sort should use timsort or radix sort, which is good for integers and presorted data. I profiled; re-profile if in doubt.
sort_indices = np.argsort(spike_frames, kind="stable")

unit_indices = unit_indices[sort_indices]
spike_frames = spike_frames[sort_indices]

return spike_frames, unit_indices
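As a quick sanity check of the function above, a minimal sketch that verifies the realized firing rate is close to the requested one (the import path assumes the function is reachable from spikeinterface.core.generate, as this diff suggests):

import numpy as np
from spikeinterface.core.generate import synthesize_poisson_spike_vector

duration, rate, num_units = 60.0, 3.0, 4
spike_frames, unit_indices = synthesize_poisson_spike_vector(
    num_units=num_units, duration=duration, firing_rates=rate, seed=0
)
for unit in range(num_units):
    # Realized rate should be near 3 Hz, slightly lower because of the refractory period
    print(unit, np.sum(unit_indices == unit) / duration)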


def synthesize_random_firings(
@@ -360,14 +446,6 @@ def synthesize_random_firings(

rng = np.random.default_rng(seed=seed)

# unit_seeds = [rng.integers(0, 2 ** 63) for i in range(num_units)]

# if seed is not None:
# np.random.seed(seed)
# seeds = np.random.RandomState(seed=seed).randint(0, 2147483647, num_units)
# else:
# seeds = np.random.randint(0, 2147483647, num_units)

if np.isscalar(firing_rates):
firing_rates = np.full(num_units, firing_rates, dtype="float64")
