Skip to content

Commit

Permalink
Merge pull request #2227 from h-mayorquin/add_sorting_generator
Browse files Browse the repository at this point in the history
Add spike-train based lazy SortingGenerator
  • Loading branch information
alejoe91 authored Sep 25, 2024
2 parents d904363 + 8ae1004 commit b2ea8c5
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 5 deletions.
180 changes: 175 additions & 5 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,10 +742,10 @@ def synthesize_poisson_spike_vector(

# Calculate the number of frames in the refractory period
refractory_period_seconds = refractory_period_ms / 1000.0
refactory_period_frames = int(refractory_period_seconds * sampling_frequency)
refractory_period_frames = int(refractory_period_seconds * sampling_frequency)

is_refactory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
if is_refactory_period_too_long:
is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates)
if is_refractory_period_too_long:
raise ValueError(
f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}"
)
Expand All @@ -764,9 +764,9 @@ def synthesize_poisson_spike_vector(
binomial_p_modified = modified_firing_rate / sampling_frequency
binomial_p_modified = np.minimum(binomial_p_modified, 1.0)

# Generate inter spike frames, add the refactory samples and accumulate for sorted spike frames
# Generate inter spike frames, add the refractory samples and accumulate for sorted spike frames
inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max))
inter_spike_frames[:, 1:] += refactory_period_frames
inter_spike_frames[:, 1:] += refractory_period_frames
spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames)
spike_frames = spike_frames.ravel()

Expand Down Expand Up @@ -1054,6 +1054,176 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
return spike_train


from spikeinterface.core.basesorting import BaseSortingSegment, BaseSorting


class SortingGenerator(BaseSorting):
    def __init__(
        self,
        num_units: int = 20,
        sampling_frequency: float = 30_000.0,  # Hz
        durations: List[float] = [10.325, 3.5],  # seconds, one entry per segment
        firing_rates: float | np.ndarray = 3.0,
        refractory_period_ms: float | np.ndarray = 4.0,  # milliseconds
        seed: int = 0,
    ):
        """Lazily generated synthetic sorting with Poisson spike trains.

        SpikeInterface has two internal spike-train representations:

        - a spike vector of (sample_index, unit_index) pairs
        - a mapping from unit id to a list of spike times

        This class simulates the second representation: spike trains are produced
        on demand, per unit, instead of being pre-computed as a spike vector. It is
        intended for testing performance and functionality of data-handling code.
        For the usual spike-vector-based sorting objects use `generate_sorting`.

        Parameters
        ----------
        num_units : int, optional
            Number of distinct units (neurons) to simulate. Default is 20.
        sampling_frequency : float, optional
            Sampling frequency in Hz. Default is 30_000.0.
        durations : list of float, optional
            Duration in seconds of each segment. Default is [10.325, 3.5]
            (two segments).
        firing_rates : float or np.ndarray, optional
            Firing rate(s) in Hz, either one value for all units or one per unit.
            Default is 3.0.
        refractory_period_ms : float or np.ndarray, optional
            Refractory period in milliseconds, either one value for all units or
            one per unit. Default is 4.0.
        seed : int, default: 0
            Seed for the random number generator, for reproducibility.

        Raises
        ------
        ValueError
            If the refractory period is too long for the given firing rates.

        Notes
        -----
        Spike trains are drawn from a Poisson process whose rate is adjusted to
        compensate for the refractory period. See the notes on
        `synthesize_poisson_spike_vector` for details.
        """
        unit_ids = np.arange(num_units)
        super().__init__(sampling_frequency, unit_ids)

        self.num_units = num_units
        self.num_segments = len(durations)
        self.firing_rates = firing_rates
        self.durations = durations
        self.refractory_period_seconds = refractory_period_ms / 1000.0

        # A unit whose refractory period exceeds its mean inter-spike interval
        # could never sustain its requested firing rate.
        if np.any(self.refractory_period_seconds >= 1.0 / firing_rates):
            raise ValueError(
                f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}"
            )

        self.seed = _ensure_seed(seed)

        # One segment per duration; each segment gets a distinct derived seed so
        # segments are independent but reproducible.
        for segment_index, duration in enumerate(durations):
            self.add_sorting_segment(
                SortingGeneratorSegment(
                    num_units=num_units,
                    sampling_frequency=sampling_frequency,
                    duration=duration,
                    firing_rates=firing_rates,
                    refractory_period_seconds=self.refractory_period_seconds,
                    seed=self.seed + segment_index,
                    t_start=None,
                )
            )

        self._kwargs = dict(
            num_units=num_units,
            sampling_frequency=sampling_frequency,
            durations=durations,
            firing_rates=firing_rates,
            refractory_period_ms=refractory_period_ms,
            seed=self.seed,
        )


class SortingGeneratorSegment(BaseSortingSegment):
    """One segment of a `SortingGenerator`: spike trains are generated lazily,
    per unit, from a deterministic per-unit seed."""

    def __init__(
        self,
        num_units: int,
        sampling_frequency: float,
        duration: float,
        firing_rates: float | np.ndarray,
        refractory_period_seconds: float | np.ndarray,
        seed: int,
        t_start: Optional[float] = None,
    ):
        self.num_units = num_units
        self.duration = duration
        self.sampling_frequency = sampling_frequency
        self.refractory_period_seconds = refractory_period_seconds

        # Broadcast scalar parameters to one value per unit.
        if np.isscalar(firing_rates):
            firing_rates = np.full(num_units, firing_rates, dtype="float64")

        self.firing_rates = firing_rates

        if np.isscalar(self.refractory_period_seconds):
            self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64")

        # Each unit gets its own seed derived from the segment seed so a unit's
        # spike train is identical across calls (hash(int) == int in CPython).
        self.segment_seed = seed
        self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)}
        self.num_samples = math.ceil(sampling_frequency * duration)
        super().__init__(t_start)

    def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray:
        """Generate (deterministically) and return the spike frames of `unit_id`
        restricted to [start_frame, end_frame)."""
        unit_seed = self.units_seed[unit_id]
        unit_index = self.parent_extractor.id_to_index(unit_id)

        rng = np.random.default_rng(seed=unit_seed)

        firing_rate = self.firing_rates[unit_index]
        refractory_period = self.refractory_period_seconds[unit_index]

        # p is the probability of a spike per tick of the sampling frequency
        binomial_p = firing_rate / self.sampling_frequency
        # Estimate how many spikes to draw: expectation plus 4 standard deviations
        # so the buffer covers the segment with high probability.
        max_frames = int(self.duration * self.sampling_frequency) - 1
        max_binomial_p = float(np.max(binomial_p))
        num_spikes_expected = math.ceil(max_frames * max_binomial_p)
        num_spikes_std = int(np.sqrt(num_spikes_expected * (1 - max_binomial_p)))
        num_spikes_max = num_spikes_expected + 4 * num_spikes_std

        # Increase the firing rate to take into account the refractory period
        modified_firing_rate = firing_rate / (1 - firing_rate * refractory_period)
        binomial_p_modified = modified_firing_rate / self.sampling_frequency
        binomial_p_modified = np.minimum(binomial_p_modified, 1.0)

        # Geometric inter-spike intervals + cumsum give sorted spike frames.
        inter_spike_frames = rng.geometric(p=binomial_p_modified, size=num_spikes_max)
        spike_frames = np.cumsum(inter_spike_frames)

        refractory_period_frames = int(refractory_period * self.sampling_frequency)
        spike_frames[1:] += refractory_period_frames

        if start_frame is not None:
            start_index = np.searchsorted(spike_frames, start_frame, side="left")
        else:
            start_index = 0

        if end_frame is not None:
            # BUGFIX: searchsorted on the slice returns an index relative to the
            # slice; it must be offset by start_index before indexing the full
            # array, otherwise any start_frame > 0 yields a wrong (often empty) train.
            end_index = start_index + np.searchsorted(spike_frames[start_index:], end_frame, side="left")
        else:
            # BUGFIX: the previous code used a frame count (duration * fs) as an
            # array index, which never trimmed spikes drawn past the segment end.
            # Clip the train to the segment's sample count instead.
            end_index = np.searchsorted(spike_frames, self.num_samples, side="left")

        return spike_frames[start_index:end_index]


## Noise generator zone ##
class NoiseGeneratorRecording(BaseRecording):
"""
Expand Down
68 changes: 68 additions & 0 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
generate_recording,
generate_sorting,
NoiseGeneratorRecording,
SortingGenerator,
TransformSorting,
generate_recording_by_size,
InjectTemplatesRecording,
Expand Down Expand Up @@ -94,6 +95,73 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float:
return memory


def test_memory_sorting_generator():
    # Check that instantiating a SortingGenerator allocates essentially no memory,
    # since its spike trains are generated lazily.
    # NOTE(review): the original comment mentioned get_traces — it was copied from
    # the recording memory test and did not describe this test.

    bytes_to_MiB_factor = 1024**2
    relative_tolerance = 0.05  # relative tolerance of 5 per cent

    sampling_frequency = 30000  # Hz
    durations = [60.0]
    num_units = 1000
    seed = 0

    before_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor
    sorting = SortingGenerator(
        num_units=num_units,
        sampling_frequency=sampling_frequency,
        durations=durations,
        seed=seed,
    )
    after_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor
    memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB
    # NOTE(review): the assertion below bounds the allocation relative to the
    # process's baseline memory (<= ~5% growth), not against
    # expected_allocation_MiB, which is only used in the failure message.
    ratio = memory_usage_MiB / before_instanciation_MiB
    expected_allocation_MiB = 0
    assert (
        ratio <= 1.0 + relative_tolerance
    ), f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}"


def test_sorting_generator_consisency_across_calls():
    # Calling get_unit_spike_train twice for the same unit must return
    # identical spike trains (the generator is seeded and deterministic).
    sorting = SortingGenerator(
        num_units=3,
        sampling_frequency=30000,  # Hz
        durations=[1.0],
        seed=0,
    )

    for unit_id in sorting.get_unit_ids():
        first_call = sorting.get_unit_spike_train(unit_id=unit_id)
        second_call = sorting.get_unit_spike_train(unit_id=unit_id)
        assert np.allclose(first_call, second_call)


def test_sorting_generator_consisency_within_trains():
    # Requesting the same [start_frame, end_frame) window twice for the same
    # unit must return identical spike trains.
    sorting = SortingGenerator(
        num_units=3,
        sampling_frequency=30000,  # Hz
        durations=[1.0],
        seed=0,
    )

    for unit_id in sorting.get_unit_ids():
        first_call = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000)
        second_call = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000)
        assert np.allclose(first_call, second_call)


def test_noise_generator_memory():
# Test that get_traces does not consume more memory than allocated.

Expand Down

0 comments on commit b2ea8c5

Please sign in to comment.