Add spike-train based lazy SortingGenerator #2227

Merged · 7 commits · Sep 25, 2024
180 changes: 175 additions & 5 deletions src/spikeinterface/core/generate.py
@@ -742,10 +742,10 @@ def synthesize_poisson_spike_vector(

# Calculate the number of frames in the refractory period
refractory_period_seconds = refractory_period_ms / 1000.0
- refactory_period_frames = int(refractory_period_seconds * sampling_frequency)
+ refractory_period_frames = int(refractory_period_seconds * sampling_frequency)

- is_refactory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
- if is_refactory_period_too_long:
+ is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
+ if is_refractory_period_too_long:
raise ValueError(
f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}"
)
@@ -764,9 +764,9 @@ def synthesize_poisson_spike_vector(
binomial_p_modified = modified_firing_rate / sampling_frequency
binomial_p_modified = np.minimum(binomial_p_modified, 1.0)

- # Generate inter spike frames, add the refactory samples and accumulate for sorted spike frames
+ # Generate inter spike frames, add the refractory samples and accumulate for sorted spike frames
inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max))
- inter_spike_frames[:, 1:] += refactory_period_frames
+ inter_spike_frames[:, 1:] += refractory_period_frames
spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames)
spike_frames = spike_frames.ravel()

@@ -1054,6 +1054,176 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
return spike_train


from spikeinterface.core.basesorting import BaseSortingSegment, BaseSorting


class SortingGenerator(BaseSorting):
def __init__(
self,
num_units: int = 20,
sampling_frequency: float = 30_000.0, # in Hz
durations: List[float] = [10.325, 3.5], #  in s for 2 segments
firing_rates: float | np.ndarray = 3.0,
refractory_period_ms: float | np.ndarray = 4.0, # in ms
seed: int = 0,
):
"""
A class to lazily generate synthetic sorting objects with Poisson spike trains.

We have two ways of representing spike trains in SpikeInterface:

- Spike vector (sample_index, unit_index)
- Dictionary of unit_id to spike times

This class simulates a sorting object that uses a representation based on unit IDs to lists of spike times,
rather than pre-computed spike vectors. It is intended for testing performance differences and functionalities
in data handling and analysis frameworks. For the normal use case of sorting objects with spike_vectors use the
`generate_sorting` function.

Parameters
----------
num_units : int, optional
The number of distinct units (neurons) to simulate. Default is 20.
sampling_frequency : float, optional
The sampling frequency of the spike data in Hz. Default is 30_000.0.
durations : list of float, optional
A list containing the duration in seconds for each segment of the sorting data. Default is [10.325, 3.5],
corresponding to 2 segments.
firing_rates : float or np.ndarray, optional
The firing rate(s) in Hz, which can be specified as a single value applicable to all units or as an array
with individual firing rates for each unit. Default is 3.0.
refractory_period_ms : float or np.ndarray, optional
The refractory period in milliseconds. Can be specified either as a single value for all units or as an
array with different values for each unit. Default is 4.0.
seed : int, default: 0
The seed for the random number generator to ensure reproducibility.

Raises
------
ValueError
If the refractory period is too long for the given firing rates, which could result in unrealistic
physiological conditions.

Notes
-----
This generator simulates the spike trains using a Poisson process. It takes into account the refractory periods
by adjusting the firing rates accordingly. See the notes on `synthesize_poisson_spike_vector` for more details.

"""

unit_ids = np.arange(num_units)
Collaborator: Will we have a docstring for this? For example, you have durations as a list of 2 floats for 2 segments, so it might be beneficial for future users to know that len(durations) == n_segments. And what are the implications of giving one firing rate vs. an array of firing rates if I as a user try something like this?

Collaborator (Author): Yeah, we should. The durations convention is pervasive in a bunch of functions; we should document it. Less pervasive, but not exclusive to this one, is firing rates as an array or float.

Member: @h-mayorquin can you add a docstring here? Maybe explaining the difference with generate_sorting.

Collaborator (Author): Added.

Collaborator (Author): I also corrected a typo of refactory to refractory.

super().__init__(sampling_frequency, unit_ids)

self.num_units = num_units
self.num_segments = len(durations)
self.firing_rates = firing_rates
self.durations = durations
self.refractory_period_seconds = refractory_period_ms / 1000.0

is_refractory_period_too_long = np.any(self.refractory_period_seconds >= 1.0 / firing_rates)
if is_refractory_period_too_long:
raise ValueError(
f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}"
)

seed = _ensure_seed(seed)
self.seed = seed

for segment_index in range(self.num_segments):
segment_seed = self.seed + segment_index
segment = SortingGeneratorSegment(
num_units=num_units,
sampling_frequency=sampling_frequency,
duration=durations[segment_index],
firing_rates=firing_rates,
refractory_period_seconds=self.refractory_period_seconds,
seed=segment_seed,
t_start=None,
)
self.add_sorting_segment(segment)

self._kwargs = {
"num_units": num_units,
"sampling_frequency": sampling_frequency,
"durations": durations,
"firing_rates": firing_rates,
"refractory_period_ms": refractory_period_ms,
"seed": seed,
}
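For illustration, a minimal usage sketch of the class above, answering the review question about durations and per-unit firing rates (values are hypothetical; it assumes SortingGenerator is exported from spikeinterface.core, as the test imports below suggest):

import numpy as np
from spikeinterface.core import SortingGenerator

# Two segments: len(durations) == num_segments, as discussed in the review thread
sorting = SortingGenerator(
    num_units=5,
    sampling_frequency=30_000.0,
    durations=[10.0, 5.0],
    firing_rates=np.full(5, 3.0),  # an array gives per-unit rates; a scalar applies to all units
    refractory_period_ms=4.0,
    seed=42,
)

# Spike trains are generated lazily, per unit and per segment, only when requested
for unit_id in sorting.get_unit_ids():
    spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0)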


class SortingGeneratorSegment(BaseSortingSegment):
def __init__(
self,
num_units: int,
sampling_frequency: float,
duration: float,
firing_rates: float | np.ndarray,
refractory_period_seconds: float | np.ndarray,
seed: int,
t_start: Optional[float] = None,
):
self.num_units = num_units
self.duration = duration
self.sampling_frequency = sampling_frequency
self.refractory_period_seconds = refractory_period_seconds

if np.isscalar(firing_rates):
firing_rates = np.full(num_units, firing_rates, dtype="float64")

self.firing_rates = firing_rates

if np.isscalar(self.refractory_period_seconds):
self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64")

self.segment_seed = seed
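# Derive a deterministic per-unit seed so that every call to get_unit_spike_train regenerates the same train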
self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)}
self.num_samples = math.ceil(sampling_frequency * duration)
super().__init__(t_start)

def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray:
unit_seed = self.units_seed[unit_id]
unit_index = self.parent_extractor.id_to_index(unit_id)

rng = np.random.default_rng(seed=unit_seed)

firing_rate = self.firing_rates[unit_index]
refractory_period = self.refractory_period_seconds[unit_index]

# p is the probability of a spike per tick of the sampling frequency
binomial_p = firing_rate / self.sampling_frequency
# We estimate how many spikes we will have in the duration
max_frames = int(self.duration * self.sampling_frequency) - 1
max_binomial_p = float(np.max(binomial_p))
num_spikes_expected = ceil(max_frames * max_binomial_p)
num_spikes_std = int(np.sqrt(num_spikes_expected * (1 - max_binomial_p)))
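# pad the estimate by 4 standard deviations to get a safe upper bound on the number of spikes to draw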
num_spikes_max = num_spikes_expected + 4 * num_spikes_std

# Increase the firing rate to take into account the refractory period
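# (the unit is silent for a fraction firing_rate * refractory_period of the time,
# so the underlying rate is scaled up to preserve the requested average rate)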
modified_firing_rate = firing_rate / (1 - firing_rate * refractory_period)
binomial_p_modified = modified_firing_rate / self.sampling_frequency
binomial_p_modified = np.minimum(binomial_p_modified, 1.0)

refractory_period_frames = int(refractory_period * self.sampling_frequency)

# Add the refractory frames to each inter-spike interval (after the first) before
# accumulating, mirroring synthesize_poisson_spike_vector above
inter_spike_frames = rng.geometric(p=binomial_p_modified, size=num_spikes_max)
inter_spike_frames[1:] += refractory_period_frames
spike_frames = np.cumsum(inter_spike_frames)

if start_frame is not None:
start_index = np.searchsorted(spike_frames, start_frame, side="left")
else:
start_index = 0

if end_frame is not None:
# searchsorted on the slice returns an index relative to the slice, so offset it by start_index
end_index = start_index + np.searchsorted(spike_frames[start_index:], end_frame, side="right")
else:
# no end_frame: take all generated spikes
end_index = spike_frames.size

spike_frames = spike_frames[start_index:end_index]
return spike_frames
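A small sanity check of the windowed retrieval above (a sketch, not part of the PR; it assumes SortingGenerator is exported from spikeinterface.core, as in the test imports below):

import numpy as np
from spikeinterface.core import SortingGenerator

sorting = SortingGenerator(num_units=2, durations=[5.0], seed=0)

# Full train for unit 0, then a window covering the first second (30_000 frames)
full = sorting.get_unit_spike_train(unit_id=0)
first_second = sorting.get_unit_spike_train(unit_id=0, start_frame=0, end_frame=30_000)

# The window is a prefix of the full train, and regeneration is deterministic
assert np.array_equal(first_second, full[: first_second.size])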


## Noise generator zone ##
class NoiseGeneratorRecording(BaseRecording):
"""
68 changes: 68 additions & 0 deletions src/spikeinterface/core/tests/test_generate.py
@@ -10,6 +10,7 @@
generate_recording,
generate_sorting,
NoiseGeneratorRecording,
+ SortingGenerator,
TransformSorting,
generate_recording_by_size,
InjectTemplatesRecording,
@@ -94,6 +95,73 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float:
return memory


def test_memory_sorting_generator():
# Test that the SortingGenerator is lazy and does not allocate spike trains at instantiation.

bytes_to_MiB_factor = 1024**2
relative_tolerance = 0.05 # relative tolerance of 5 per cent

sampling_frequency = 30000 # Hz
durations = [60.0]
num_units = 1000
seed = 0

before_instantiation_MiB = measure_memory_allocation() / bytes_to_MiB_factor
sorting = SortingGenerator(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=durations,
seed=seed,
)
after_instantiation_MiB = measure_memory_allocation() / bytes_to_MiB_factor
memory_usage_MiB = after_instantiation_MiB - before_instantiation_MiB
ratio = memory_usage_MiB / before_instantiation_MiB
expected_allocation_MiB = 0
assert (
ratio <= 1.0 + relative_tolerance
), f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}"


def test_sorting_generator_consistency_across_calls():
sampling_frequency = 30000 # Hz
durations = [1.0]
num_units = 3
seed = 0

sorting = SortingGenerator(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=durations,
seed=seed,
)

for unit_id in sorting.get_unit_ids():
spike_train = sorting.get_unit_spike_train(unit_id=unit_id)
spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id)

assert np.allclose(spike_train, spike_train_again)


def test_sorting_generator_consistency_within_trains():
sampling_frequency = 30000 # Hz
durations = [1.0]
num_units = 3
seed = 0

sorting = SortingGenerator(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=durations,
seed=seed,
)

for unit_id in sorting.get_unit_ids():
spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000)
spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000)

assert np.allclose(spike_train, spike_train_again)
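
A further property worth checking, sketched here as a possible addition rather than part of the PR (the tolerance is an assumption): the refractory-period correction should keep the empirical firing rate close to the requested one.

def test_sorting_generator_empirical_rate():
    # Sketch: the refractory-corrected Poisson process should produce
    # roughly the requested firing rate over a long segment
    sampling_frequency = 30000.0  # Hz
    duration = 100.0  # s
    firing_rate = 10.0  # Hz

    sorting = SortingGenerator(
        num_units=1,
        sampling_frequency=sampling_frequency,
        durations=[duration],
        firing_rates=firing_rate,
        refractory_period_ms=4.0,
        seed=0,
    )

    num_frames = int(sampling_frequency * duration)
    spike_train = sorting.get_unit_spike_train(unit_id=0, start_frame=0, end_frame=num_frames)
    empirical_rate = spike_train.size / duration

    # Loose statistical bound: within 10% of the requested rate
    assert abs(empirical_rate - firing_rate) / firing_rate < 0.1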


def test_noise_generator_memory():
# Test that get_traces does not consume more memory than allocated.
