From 6c076de15d41f6ea93d5f0ed9c5e2ffe8493732d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sat, 18 Nov 2023 15:35:28 +0100 Subject: [PATCH 1/4] add sorting generator --- src/spikeinterface/core/basesorting.py | 6 +- src/spikeinterface/core/generate.py | 110 ++++++++++++++++++ .../core/tests/test_generate.py | 28 +++++ 3 files changed, 141 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 3c976c3de3..54a633644a 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -31,10 +31,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List): def __repr__(self): clsname = self.__class__.__name__ - nseg = self.get_num_segments() - nunits = self.get_num_units() + num_segments = self.get_num_segments() + num_units = self.get_num_units() sf_khz = self.get_sampling_frequency() / 1000.0 - txt = f"{clsname}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz" + txt = f"{clsname}: {num_units} units - {num_segments} segments - {sf_khz:0.1f}kHz" if "file_path" in self._kwargs: txt += "\n file_path: {}".format(self._kwargs["file_path"]) return txt diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1c8661d12d..adb9a8d53f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -566,6 +566,116 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol return spike_train +from spikeinterface.core.basesorting import BaseSortingSegment, BaseSorting + + +class SortingGenerator(BaseSorting): + def __init__( + self, + num_units: int = 5, + sampling_frequency: float = 30000.0, # in Hz + durations: List[float] = [10.325, 3.5], #  in s for 2 segments + firing_rates: float | np.ndarray = 3.0, + refractory_period_ms: float | np.ndarray = 4.0, # in ms + seed: Optional[int] = None, + ): + unit_ids = np.arange(num_units) + super().__init__(sampling_frequency, unit_ids) + + self.num_units = num_units + self.num_segments = len(durations) + self.firing_rates = firing_rates + self.durations = durations + self.refactory_period_ms = refractory_period_ms + + seed = _ensure_seed(seed) + self.seed = seed + + for segment_index in range(self.num_segments): + segment_seed = self.seed + segment_index + segment = SortingGeneratorSegment( + num_units=num_units, + sampling_frequency=sampling_frequency, + duration=durations[segment_index], + firing_rates=firing_rates, + refractory_period_ms=refractory_period_ms, + seed=segment_seed, + t_start=None, + ) + self.add_sorting_segment(segment) + + self._kwargs = { + "num_units": num_units, + "sampling_frequency": sampling_frequency, + "durations": durations, + "firing_rates": firing_rates, + "refactory_period_ms": refractory_period_ms, + "seed": seed, + } + + +class SortingGeneratorSegment(BaseSortingSegment): + def __init__( + self, + num_units: int, + sampling_frequency: float, + duration: float, + firing_rates: float | np.ndarray, + refractory_period_ms: float | np.ndarray, + seed: int, + t_start: Optional[float] = None, + ): + self.num_units = num_units + self.duration = duration + self.sampling_frequency = sampling_frequency + + if np.isscalar(firing_rates): + firing_rates = np.full(num_units, firing_rates, dtype="float64") + + self.firing_rates = firing_rates + + if np.isscalar(refractory_period_ms): + refractory_period_ms = np.full(num_units, refractory_period_ms, dtype="float64") + + self.refractory_period_seconds = refractory_period_ms / 1000.0 + self.segment_seed = seed + self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} + self.num_samples = math.ceil(sampling_frequency * duration) + super().__init__(t_start) + + def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray: + unit_seed = self.units_seed[unit_id] + unit_index = self.parent_extractor.id_to_index(unit_id) + + rng = np.random.default_rng(seed=unit_seed) + + # Poisson process statistics + num_spikes_expected = math.ceil(self.firing_rates[unit_id] * self.duration) + num_spikes_std = math.ceil(np.sqrt(num_spikes_expected)) + num_spikes_max = num_spikes_expected + 2 * num_spikes_std + + p_geometric = 1.0 - np.exp(-self.firing_rates[unit_index] / self.sampling_frequency) + + inter_spike_frames = rng.geometric(p=p_geometric, size=num_spikes_max) + spike_frames = np.cumsum(inter_spike_frames, out=inter_spike_frames) + + refactory_period_frames = int(self.refractory_period_seconds[unit_index] * self.sampling_frequency) + spike_frames[1:] += refactory_period_frames + + if start_frame is not None: + start_index = np.searchsorted(spike_frames, start_frame, side="left") + else: + start_index = 0 + + if end_frame is not None: + end_index = np.searchsorted(spike_frames[start_index:], end_frame, side="right") + else: + end_index = int(self.duration * self.sampling_frequency) + + spike_frames = spike_frames[start_index:end_index] + return spike_frames + + ## Noise generator zone ## diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 7b51abcccb..ea2edae6e2 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -10,6 +10,7 @@ generate_recording, generate_sorting, NoiseGeneratorRecording, + SortingGenerator, generate_recording_by_size, InjectTemplatesRecording, generate_single_fake_waveform, @@ -90,6 +91,33 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float: return memory +def test_memory_sorting_generator(): + # Test that get_traces does not consume more memory than allocated. + + bytes_to_MiB_factor = 1024**2 + relative_tolerance = 0.05 # relative tolerance of 5 per cent + + sampling_frequency = 30000 # Hz + durations = [60.0] + num_units = 1000 + seed = 0 + + before_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor + sorting = SortingGenerator( + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=durations, + seed=seed, + ) + after_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor + memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB + ratio = memory_usage_MiB / before_instanciation_MiB + expected_allocation_MiB = 0 + assert ( + ratio <= 1.0 + relative_tolerance + ), f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" + + def test_noise_generator_memory(): # Test that get_traces does not consume more memory than allocated. From b1e0b9abbc342fcf2a62e97b11e11e8eeb769188 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sat, 18 Nov 2023 15:42:56 +0100 Subject: [PATCH 2/4] add more tests --- .../core/tests/test_generate.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index ea2edae6e2..2a720170d9 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -118,6 +118,46 @@ def test_memory_sorting_generator(): ), f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" +def test_sorting_generator_consisency_across_calls(): + sampling_frequency = 30000 # Hz + durations = [1.0] + num_units = 3 + seed = 0 + + sorting = SortingGenerator( + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=durations, + seed=seed, + ) + + for unit_id in sorting.get_unit_ids(): + spike_train = sorting.get_unit_spike_train(unit_id=unit_id) + spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id) + + assert np.allclose(spike_train, spike_train_again) + + +def test_sorting_generator_consisency_within_trains(): + sampling_frequency = 30000 # Hz + durations = [1.0] + num_units = 3 + seed = 0 + + sorting = SortingGenerator( + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=durations, + seed=seed, + ) + + for unit_id in sorting.get_unit_ids(): + spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000) + spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000) + + assert np.allclose(spike_train, spike_train_again) + + def test_noise_generator_memory(): # Test that get_traces does not consume more memory than allocated. From 3e5eedceb03d6d98a2a2519588d239755f3a6eea Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 10 May 2024 08:49:20 -0600 Subject: [PATCH 3/4] added docstring --- src/spikeinterface/core/generate.py | 108 +++++++++++++++++++++------- 1 file changed, 84 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6995906802..caf525c19f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -687,10 +687,10 @@ def synthesize_poisson_spike_vector( # Calculate the number of frames in the refractory period refractory_period_seconds = refractory_period_ms / 1000.0 - refactory_period_frames = int(refractory_period_seconds * sampling_frequency) + refractory_period_frames = int(refractory_period_seconds * sampling_frequency) - is_refactory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates) - if is_refactory_period_too_long: + is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates) + if is_refractory_period_too_long: raise ValueError( f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}" ) @@ -709,9 +709,9 @@ def synthesize_poisson_spike_vector( binomial_p_modified = modified_firing_rate / sampling_frequency binomial_p_modified = np.minimum(binomial_p_modified, 1.0) - # Generate inter spike frames, add the refactory samples and accumulate for sorted spike frames + # Generate inter spike frames, add the refractory samples and accumulate for sorted spike frames inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max)) - inter_spike_frames[:, 1:] += refactory_period_frames + inter_spike_frames[:, 1:] += refractory_period_frames spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames) spike_frames = spike_frames.ravel() @@ -982,13 +982,57 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol class SortingGenerator(BaseSorting): def __init__( self, - num_units: int = 5, - sampling_frequency: float = 30000.0, # in Hz + num_units: int = 20, + sampling_frequency: float = 30_000.0, # in Hz durations: List[float] = [10.325, 3.5], #  in s for 2 segments firing_rates: float | np.ndarray = 3.0, refractory_period_ms: float | np.ndarray = 4.0, # in ms - seed: Optional[int] = None, + seed: int = 0, ): + """ + A class for lazily generate synthetic sorting objects with Poisson spike trains. + + We have two ways of representing spike trains in SpikeInterface: + + - Spike vector (sample_index, unit_index) + - Dictionary of unit_id to spike times + + This class simulates a sorting object that uses a representation based on unit IDs to lists of spike times, + rather than pre-computed spike vectors. It is intended for testing performance differences and functionalities + in data handling and analysis frameworks. For the normal use case of sorting objects with spike_vectors use the + `generate_sorting` function. + + Parameters + ---------- + num_units : int, optional + The number of distinct units (neurons) to simulate. Default is 20. + sampling_frequency : float, optional + The sampling frequency of the spike data in Hz. Default is 30_000.0. + durations : list of float, optional + A list containing the duration in seconds for each segment of the sorting data. Default is [10.325, 3.5], + corresponding to 2 segments. + firing_rates : float or np.ndarray, optional + The firing rate(s) in Hz, which can be specified as a single value applicable to all units or as an array + with individual firing rates for each unit. Default is 3.0. + refractory_period_ms : float or np.ndarray, optional + The refractory period in milliseconds. Can be specified either as a single value for all units or as an + array with different values for each unit. Default is 4.0. + seed : int, default: 0 + The seed for the random number generator to ensure reproducibility. + + Raises + ------ + ValueError + If the refractory period is too long for the given firing rates, which could result in unrealistic + physiological conditions. + + Notes + ----- + This generator simulates the spike trains using a Poisson process. It takes into account the refractory periods + by adjusting the firing rates accordingly. See the notes on `synthesize_poisson_spike_vector` for more details. + + """ + unit_ids = np.arange(num_units) super().__init__(sampling_frequency, unit_ids) @@ -996,7 +1040,13 @@ def __init__( self.num_segments = len(durations) self.firing_rates = firing_rates self.durations = durations - self.refactory_period_ms = refractory_period_ms + self.refractory_period_seconds = refractory_period_ms / 1000.0 + + is_refractory_period_too_long = np.any(self.refractory_period_seconds >= 1.0 / firing_rates) + if is_refractory_period_too_long: + raise ValueError( + f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}" + ) seed = _ensure_seed(seed) self.seed = seed @@ -1008,7 +1058,7 @@ def __init__( sampling_frequency=sampling_frequency, duration=durations[segment_index], firing_rates=firing_rates, - refractory_period_ms=refractory_period_ms, + refractory_period_seconds=self.refractory_period_seconds, seed=segment_seed, t_start=None, ) @@ -1019,7 +1069,7 @@ def __init__( "sampling_frequency": sampling_frequency, "durations": durations, "firing_rates": firing_rates, - "refactory_period_ms": refractory_period_ms, + "refractory_period_ms": refractory_period_ms, "seed": seed, } @@ -1031,23 +1081,23 @@ def __init__( sampling_frequency: float, duration: float, firing_rates: float | np.ndarray, - refractory_period_ms: float | np.ndarray, + refractory_period_seconds: float | np.ndarray, seed: int, t_start: Optional[float] = None, ): self.num_units = num_units self.duration = duration self.sampling_frequency = sampling_frequency + self.refractory_period_seconds = refractory_period_seconds if np.isscalar(firing_rates): firing_rates = np.full(num_units, firing_rates, dtype="float64") self.firing_rates = firing_rates - if np.isscalar(refractory_period_ms): - refractory_period_ms = np.full(num_units, refractory_period_ms, dtype="float64") + if np.isscalar(self.refractory_period_seconds): + self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") - self.refractory_period_seconds = refractory_period_ms / 1000.0 self.segment_seed = seed self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} self.num_samples = math.ceil(sampling_frequency * duration) @@ -1059,18 +1109,28 @@ def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_fram rng = np.random.default_rng(seed=unit_seed) - # Poisson process statistics - num_spikes_expected = math.ceil(self.firing_rates[unit_id] * self.duration) - num_spikes_std = math.ceil(np.sqrt(num_spikes_expected)) - num_spikes_max = num_spikes_expected + 2 * num_spikes_std + firing_rate = self.firing_rates[unit_index] + refractory_period = self.refractory_period_seconds[unit_index] + + # p is the probably of an spike per tick of the sampling frequency + binomial_p = firing_rate / self.sampling_frequency + # We estimate how many spikes we will have in the duration + max_frames = int(self.duration * self.sampling_frequency) - 1 + max_binomial_p = float(np.max(binomial_p)) + num_spikes_expected = ceil(max_frames * max_binomial_p) + num_spikes_std = int(np.sqrt(num_spikes_expected * (1 - max_binomial_p))) + num_spikes_max = num_spikes_expected + 4 * num_spikes_std - p_geometric = 1.0 - np.exp(-self.firing_rates[unit_index] / self.sampling_frequency) + # Increase the firing rate to take into account the refractory period + modified_firing_rate = firing_rate / (1 - firing_rate * refractory_period) + binomial_p_modified = modified_firing_rate / self.sampling_frequency + binomial_p_modified = np.minimum(binomial_p_modified, 1.0) - inter_spike_frames = rng.geometric(p=p_geometric, size=num_spikes_max) - spike_frames = np.cumsum(inter_spike_frames, out=inter_spike_frames) + inter_spike_frames = rng.geometric(p=binomial_p_modified, size=num_spikes_max) + spike_frames = np.cumsum(inter_spike_frames) - refactory_period_frames = int(self.refractory_period_seconds[unit_index] * self.sampling_frequency) - spike_frames[1:] += refactory_period_frames + refractory_period_frames = int(refractory_period * self.sampling_frequency) + spike_frames[1:] += refractory_period_frames if start_frame is not None: start_index = np.searchsorted(spike_frames, start_frame, side="left") From 8ae10044f9077201647e8a70c242641c57f5ac05 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Sep 2024 10:14:37 +0200 Subject: [PATCH 4/4] Update src/spikeinterface/core/generate.py Co-authored-by: Garcia Samuel --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d8d0a9f999..1cc9cfa760 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1216,7 +1216,7 @@ def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_fram start_index = 0 if end_frame is not None: - end_index = np.searchsorted(spike_frames[start_index:], end_frame, side="right") + end_index = np.searchsorted(spike_frames[start_index:], end_frame, side="left") else: end_index = int(self.duration * self.sampling_frequency)