From 6590e0f1845a62300d617d3896969ec303bebc46 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 6 Sep 2024 14:24:16 +0200
Subject: [PATCH] nodepipeline add skip_after_n_peaks option

---
 src/spikeinterface/core/node_pipeline.py     | 22 +++++++++--
 .../core/tests/test_node_pipeline.py         | 37 +++++++++++++++++--
 2 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index e72f87f794..d04ad59f46 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -497,6 +497,7 @@ def run_node_pipeline(
     folder=None,
     names=None,
     verbose=False,
+    skip_after_n_peaks=None,
 ):
     """
     Common function to run pipeline with peak detector or already detected peak.
@@ -507,6 +508,11 @@
     job_kwargs = fix_job_kwargs(job_kwargs)
     assert all(isinstance(node, PipelineNode) for node in nodes)
 
+    if skip_after_n_peaks is not None:
+        skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"]
+    else:
+        skip_after_n_peaks_per_worker = None
+
     if gather_mode == "memory":
         gather_func = GatherToMemory()
     elif gather_mode == "npy":
@@ -514,7 +520,7 @@
     else:
         raise ValueError(f"wrong gather_mode : {gather_mode}")
 
-    init_args = (recording, nodes)
+    init_args = (recording, nodes, skip_after_n_peaks_per_worker)
 
     processor = ChunkRecordingExecutor(
         recording,
@@ -533,12 +539,14 @@
     return outs
 
 
-def _init_peak_pipeline(recording, nodes):
+def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
     # create a local dict per worker
     worker_ctx = {}
     worker_ctx["recording"] = recording
     worker_ctx["nodes"] = nodes
     worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
+    worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
+    worker_ctx["num_peaks"] = 0
     return worker_ctx
 
 
@@ -546,6 +554,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
     recording = worker_ctx["recording"]
     max_margin = worker_ctx["max_margin"]
     nodes = worker_ctx["nodes"]
+    skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]
 
     recording_segment = recording._recording_segments[segment_index]
     node0 = nodes[0]
@@ -557,7 +566,11 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
     else:
         # PeakDetector always need traces
        load_trace_and_compute = True
-
+
+    if skip_after_n_peaks_per_worker is not None:
+        if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker:
+            load_trace_and_compute = False
+
     if load_trace_and_compute:
         traces_chunk, left_margin, right_margin = get_chunk_with_margin(
             recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
@@ -590,6 +603,9 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
             node_output = node.compute(traces_chunk, *node_input_args)
             pipeline_outputs[node] = node_output
 
+            if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
+                worker_ctx["num_peaks"] += node_output[0].size
+
     # propagate the output
     pipeline_outputs_tuple = tuple()
     for node in nodes:
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index a2919f5673..f31757d6bc 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -83,7 +83,7 @@ def test_run_node_pipeline(cache_folder_creation):
     extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index")
 
     peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype)
-    print(peaks.size)
+    # print(peaks.size)
 
     peak_retriever = PeakRetriever(recording, peaks)
     # this test when no spikes in last chunks
@@ -191,6 +191,37 @@ def test_run_node_pipeline(cache_folder_creation):
         unpickled_node = pickle.loads(pickled_node)
 
 
+def test_skip_after_n_peaks():
+    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
+
+    # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
+    job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
+
+    spikes = sorting.to_spike_vector()
+
+    # create peaks from spikes
+    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory")
+    sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs)
+    extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index")
+
+    peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype)
+    # print(peaks.size)
+
+    node0 = PeakRetriever(recording, peaks)
+    node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True)
+    nodes = [node0, node1]
+
+    skip_after_n_peaks = 30
+    some_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks)
+
+    assert some_amplitudes.size >= skip_after_n_peaks
+    assert some_amplitudes.size < spikes.size
+
+
+
 if __name__ == "__main__":
-    folder = Path("./cache_folder/core")
-    test_run_node_pipeline(folder)
+    # folder = Path("./cache_folder/core")
+    # test_run_node_pipeline(folder)
+
+    test_skip_after_n_peaks()
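
Usage note (not part of the patch itself): the sketch below shows how the new skip_after_n_peaks argument would be passed to run_node_pipeline, following the spirit of the added test but using only existing spikeinterface.core objects (PeakRetriever, ExtractDenseWaveforms, base_peak_dtype) rather than the test-only AmplitudeExtractionNode helper. It is a minimal sketch under those assumptions; the peak budget of 100, the waveform window and the recording parameters are arbitrary illustration values, and the budget is approximate because it is divided across workers and only checked once per chunk.

import numpy as np

from spikeinterface.core import (
    generate_ground_truth_recording,
    create_sorting_analyzer,
    get_template_extremum_channel,
)
from spikeinterface.core.node_pipeline import (
    run_node_pipeline,
    PeakRetriever,
    ExtractDenseWaveforms,
    base_peak_dtype,
)

recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

# Build a peak vector from the ground-truth spikes, placing each peak on the
# extremum channel of its unit (same idea as the added test).
sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory")
sorting_analyzer.compute(["random_spikes", "templates"])
extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index")

spikes = sorting.to_spike_vector()
peaks = np.zeros(spikes.size, dtype=base_peak_dtype)
peaks["sample_index"] = spikes["sample_index"]
peaks["segment_index"] = spikes["segment_index"]
unit_extremum = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids])
peaks["channel_index"] = unit_extremum[spikes["unit_index"]]
# the "amplitude" field is left at 0.0, it is not used by the nodes below

peak_retriever = PeakRetriever(recording, peaks)
waveform_node = ExtractDenseWaveforms(recording, parents=[peak_retriever], ms_before=1.0, ms_after=2.0, return_output=True)

# skip_after_n_peaks is a soft budget: once a worker has seen more than its share
# of peaks it stops loading traces, so slightly more than 100 peaks may be returned.
job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
some_waveforms = run_node_pipeline(
    recording,
    [peak_retriever, waveform_node],
    job_kwargs,
    gather_mode="memory",
    skip_after_n_peaks=100,
)
print(some_waveforms.shape)

With skip_after_n_peaks=None (the default) the behavior is unchanged and the whole recording is processed.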