diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index a5279247f5..9d29fa41a0 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -385,11 +385,14 @@ def __init__(
             f"chunk_duration={chunk_duration_str}",
         )
 
-    def run(self):
+    def run(self, seed=None):
         """
         Runs the defined jobs.
         """
         all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
+        if seed is not None:
+            rng = np.random.default_rng(seed)
+            all_chunks = rng.permutation(all_chunks)
 
         if self.handle_returns:
             returns = []
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index ceff8577d3..ece8d09769 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -124,6 +124,12 @@ def __init__(self, recording, peaks):
     def get_trace_margin(self):
         return 0
 
+    def has_peaks(self, start_frame, end_frame, segment_index):
+        sl = self.segment_slices[segment_index]
+        peaks_in_segment = self.peaks[sl]
+        i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        return i0 < i1
+
     def get_dtype(self):
         return base_peak_dtype
 
@@ -209,6 +215,12 @@ def __init__(
     def get_trace_margin(self):
         return 0
 
+    def has_peaks(self, start_frame, end_frame, segment_index):
+        sl = self.segment_slices[segment_index]
+        peaks_in_segment = self.peaks[sl]
+        i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        return i0 < i1
+
     def get_dtype(self):
         return self._dtype
 
@@ -467,12 +479,14 @@ def run_node_pipeline(
     nodes,
     job_kwargs,
     job_name="pipeline",
+    max_peaks=None,
     mp_context=None,
     gather_mode="memory",
     gather_kwargs={},
     squeeze_output=True,
     folder=None,
     names=None,
+    seed=None,
     verbose=False,
 ):
     """
@@ -491,7 +505,7 @@ def run_node_pipeline(
     else:
         raise ValueError(f"wrong gather_mode : {gather_mode}")
 
-    init_args = (recording, nodes)
+    init_args = (recording, nodes, max_peaks)
 
     processor = ChunkRecordingExecutor(
         recording,
@@ -504,17 +518,19 @@ def run_node_pipeline(
         **job_kwargs,
     )
 
-    processor.run()
+    processor.run(seed=seed)
 
     outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
     return outs
 
 
-def _init_peak_pipeline(recording, nodes):
+def _init_peak_pipeline(recording, nodes, max_peaks):
     # create a local dict per worker
     worker_ctx = {}
     worker_ctx["recording"] = recording
     worker_ctx["nodes"] = nodes
+    worker_ctx["num_peaks"] = 0
+    worker_ctx["max_peaks"] = max_peaks
     worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
 
     return worker_ctx
 
@@ -522,12 +538,31 @@ def _init_peak_pipeline(recording, nodes):
 def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
     recording = worker_ctx["recording"]
     max_margin = worker_ctx["max_margin"]
+    max_peaks = worker_ctx["max_peaks"]
     nodes = worker_ctx["nodes"]
+    num_peaks = worker_ctx["num_peaks"]
 
     recording_segment = recording._recording_segments[segment_index]
-    traces_chunk, left_margin, right_margin = get_chunk_with_margin(
-        recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
-    )
+
+    if max_peaks is not None:
+        search_for_peaks = num_peaks < max_peaks
+    else:
+        search_for_peaks = True
+
+    if isinstance(nodes[0], PeakRetriever) or isinstance(nodes[0], SpikeRetriever):
+        chunk_has_peaks = nodes[0].has_peaks(start_frame, end_frame, segment_index)
+    else:
+        chunk_has_peaks = True
+
+    if search_for_peaks and chunk_has_peaks:
+        traces_chunk, left_margin, right_margin = get_chunk_with_margin(
+            recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
+        )
+        traces_loaded = True
+    else:
+        peak_output = np.zeros(0, base_peak_dtype)
+        waveform_output, left_margin, right_margin = np.zeros((0, 0), dtype=recording.dtype), 0, 0
+        traces_loaded = False
 
     # compute the graph
     pipeline_outputs = {}
@@ -538,23 +573,32 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
             parent_output = pipeline_outputs[parent]
             parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
             node_input_args += parent_outputs_tuple
-        if isinstance(node, PeakDetector):
-            # to handle compatibility peak detector is a special case
-            # with specific margin
-            # TODO later when in master: change this later
-            extra_margin = max_margin - node.get_trace_margin()
-            if extra_margin:
-                trace_detection = traces_chunk[extra_margin:-extra_margin]
+        if traces_loaded:
+            if isinstance(node, PeakDetector):
+                # to handle compatibility peak detector is a special case
+                # with specific margin
+                # TODO later when in master: change this later
+                extra_margin = max_margin - node.get_trace_margin()
+                if extra_margin:
+                    trace_detection = traces_chunk[extra_margin:-extra_margin]
+                else:
+                    trace_detection = traces_chunk
+                node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
+                # set sample index to local
+                node_output[0]["sample_index"] += extra_margin
+                worker_ctx["num_peaks"] += len(node_output[0])
+            elif isinstance(node, PeakSource):
+                node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
+                worker_ctx["num_peaks"] += len(node_output[0])
             else:
-                trace_detection = traces_chunk
-            node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
-            # set sample index to local
-            node_output[0]["sample_index"] += extra_margin
-        elif isinstance(node, PeakSource):
-            node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
+                # TODO later when in master: change the signature of all nodes (or maybe not!)
+                node_output = node.compute(traces_chunk, *node_input_args)
         else:
-            # TODO later when in master: change the signature of all nodes (or maybe not!)
-            node_output = node.compute(traces_chunk, *node_input_args)
+            if isinstance(node, PeakSource):
+                node_output = peak_output
+            else:
+                node_output = waveform_output
+
         pipeline_outputs[node] = node_output
 
         # propagate the output
@@ -582,6 +626,9 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
         # we need to go back to absolut sample index
         pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin
 
+    else:
+        pipeline_outputs_tuple = tuple()
+
     return pipeline_outputs_tuple
 
 
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index 0d5c92ff28..7c09f0ef04 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -50,7 +50,15 @@ def detect_peaks(
-    recording, method="by_channel", pipeline_nodes=None, gather_mode="memory", folder=None, names=None, **kwargs
+    recording,
+    method="by_channel",
+    pipeline_nodes=None,
+    gather_mode="memory",
+    folder=None,
+    names=None,
+    seed=None,
+    max_peaks=None,
+    **kwargs,
 ):
     """Peak detection based on threshold crossing in term of k x MAD.
@@ -123,6 +131,8 @@ def detect_peaks(
         gather_mode=gather_mode,
         squeeze_output=squeeze_output,
         folder=folder,
+        seed=seed,
+        max_peaks=max_peaks,
         names=names,
     )
     return outs
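
For context, a minimal usage sketch of what this diff enables (illustrative, not part of the patch; `recording` is an assumed pre-existing BaseRecording and the numeric values are arbitrary). `max_peaks` makes each worker stop loading new chunks once it has collected that many peaks, and `seed` shuffles the chunk order in `ChunkRecordingExecutor.run()` so the truncated sample is drawn from across the whole recording rather than only its start:

    from spikeinterface.sortingcomponents.peak_detection import detect_peaks

    # Sub-sample peaks for a quick first pass; any detection method and
    # job kwargs combine with the new parameters in the same way.
    peaks = detect_peaks(
        recording,
        method="by_channel",
        max_peaks=10_000,   # each worker stops loading chunks once it holds ~10k peaks
        seed=42,            # randomize chunk order so early stopping is not biased
        n_jobs=4,
        chunk_duration="1s",
    )

Note that `num_peaks` lives in the per-worker context, so with `n_jobs > 1` the cap is approximate: each worker enforces it independently, so up to roughly `n_jobs * max_peaks` peaks can be returned, and a chunk already dispatched is still processed in full.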