diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ceff8577d3..a72808e176 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -1,22 +1,6 @@ """ -Pipeline on spikes/peaks/detected peaks - -Functions that can be chained: - * after peak detection - * already detected peaks - * spikes (labeled peaks) -to compute some additional features on-the-fly: - * peak localization - * peak-to-peak - * pca - * amplitude - * amplitude scaling - * ... - -There are two ways for using theses "plugin nodes": - * during `peak_detect()` - * when peaks are already detected and reduced with `select_peaks()` - * on a sorting object + + """ from __future__ import annotations @@ -103,6 +87,15 @@ def get_trace_margin(self): def get_dtype(self): return base_peak_dtype + def get_peak_slice( + self, + segment_index, + start_frame, + end_frame, + ): + # not needed for PeakDetector + raise NotImplementedError + # this is used in sorting components class PeakDetector(PeakSource): @@ -127,11 +120,18 @@ def get_trace_margin(self): def get_dtype(self): return base_peak_dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # get local peaks + def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + return i0, i1 + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + i0, i1 = peak_slice local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -212,8 +212,7 @@ def get_trace_margin(self): def get_dtype(self): return self._dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # get local peaks + def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] if self.include_spikes_in_margin: @@ -222,6 +221,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): ) else: i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + return i0, i1 + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + # if self.include_spikes_in_margin: + # i0, i1 = np.searchsorted( + # peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin] + # ) + # else: + # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + i0, i1 = peak_slice + local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -467,16 +480,71 @@ def run_node_pipeline( nodes, job_kwargs, job_name="pipeline", - mp_context=None, + # mp_context=None, gather_mode="memory", gather_kwargs={}, squeeze_output=True, folder=None, names=None, verbose=False, + skip_after_n_peaks=None, ): """ - Common function to run pipeline with peak detector or already detected peak. + Machinery to compute in parallel operations on peaks and traces. 
+ + This is useful in several use cases: + * in sortingcomponents : detect peaks and run some computation on them (localize, pca, ...) + * in sortingcomponents : replay already detected peaks and run some computation on them (localize, pca, ...) + * postprocessing : replay some spikes and run some computation on them (localize, pca, ...) + + Here a "peak" is a detected spike without any label. + Here a "spike" is a peak with a label, i.e. already sorted. + + The main idea is to have a graph of nodes. + Every node performs a computation on some peaks and the related traces. + The first node is a PeakSource, so either a peak detector (PeakDetector) or a peak/spike replay (PeakRetriever/SpikeRetriever). + + Every node can have one or several outputs that can be directed to other nodes (aka nodes have parents). + + Every node can optionally have a global output that will be gathered by the main process. + This is controlled by return_output=True. + + The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector. + These vectors can be kept in "memory" or written to files ("npy"). + + + Parameters + ---------- + recording : Recording + The recording on which the pipeline is run. + nodes : list of PipelineNode + The graph of nodes to run. + job_kwargs : dict + The classical job_kwargs + job_name : str + The name of the pipeline used for the progress_bar + gather_mode : "memory" | "npy" + How the node outputs are gathered: in memory or in npy files. + gather_kwargs : dict + Options to control the "gather engine". See GatherToMemory or GatherToNpy. + squeeze_output : bool, default True + If there is only one output node then squeeze the tuple + folder : str | Path | None + Used for gather_mode="npy" + names : list of str + Names of outputs. + verbose : bool, default False + Verbosity. + skip_after_n_peaks : None | int + Skip the computation after n_peaks. + This is not exact because, internally, the skip is done per worker on average. + + Returns + ------- + outputs: tuple of np.array | np.array + A tuple of vectors, one per node having return_output=True. + If squeeze_output=True and there is only one output then directly np.array.
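+
+ Examples
+ --------
+ A minimal sketch, assuming an existing recording, a peaks vector and job_kwargs.
+ AmplitudeExtractionNode is the toy node defined in the tests; any PipelineNode
+ subclass with return_output=True can be chained in the same way.
+
+ >>> node0 = PeakRetriever(recording, peaks)
+ >>> node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True)
+ >>> amplitudes = run_node_pipeline(recording, [node0, node1], job_kwargs, gather_mode="memory")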
""" check_graph(nodes) @@ -484,6 +552,11 @@ def run_node_pipeline( job_kwargs = fix_job_kwargs(job_kwargs) assert all(isinstance(node, PipelineNode) for node in nodes) + if skip_after_n_peaks is not None: + skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"] + else: + skip_after_n_peaks_per_worker = None + if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": @@ -491,7 +564,7 @@ def run_node_pipeline( else: raise ValueError(f"wrong gather_mode : {gather_mode}") - init_args = (recording, nodes) + init_args = (recording, nodes, skip_after_n_peaks_per_worker) processor = ChunkRecordingExecutor( recording, @@ -510,12 +583,14 @@ def run_node_pipeline( return outs -def _init_peak_pipeline(recording, nodes): +def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker): # create a local dict per worker worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["nodes"] = nodes worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes) + worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker + worker_ctx["num_peaks"] = 0 return worker_ctx @@ -523,66 +598,88 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c recording = worker_ctx["recording"] max_margin = worker_ctx["max_margin"] nodes = worker_ctx["nodes"] + skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] recording_segment = recording._recording_segments[segment_index] - traces_chunk, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True - ) + node0 = nodes[0] - # compute the graph - pipeline_outputs = {} - for node in nodes: - node_parents = node.parents if node.parents else list() - node_input_args = tuple() - for parent in node_parents: - parent_output = pipeline_outputs[parent] - parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,) - node_input_args += parent_outputs_tuple - if isinstance(node, PeakDetector): - # to handle compatibility peak detector is a special case - # with specific margin - # TODO later when in master: change this later - extra_margin = max_margin - node.get_trace_margin() - if extra_margin: - trace_detection = traces_chunk[extra_margin:-extra_margin] + if isinstance(node0, (SpikeRetriever, PeakRetriever)): + # in this case PeakSource could have no peaks and so no need to load traces just skip + peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin) + load_trace_and_compute = i0 < i1 + else: + # PeakDetector always need traces + load_trace_and_compute = True + + if skip_after_n_peaks_per_worker is not None: + if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker: + load_trace_and_compute = False + + if load_trace_and_compute: + traces_chunk, left_margin, right_margin = get_chunk_with_margin( + recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True + ) + # compute the graph + pipeline_outputs = {} + for node in nodes: + node_parents = node.parents if node.parents else list() + node_input_args = tuple() + for parent in node_parents: + parent_output = pipeline_outputs[parent] + parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,) + node_input_args += parent_outputs_tuple + if isinstance(node, PeakDetector): + # to handle compatibility peak detector is a special case + # with specific margin + # TODO later when in master: change this later + 
extra_margin = max_margin - node.get_trace_margin() + if extra_margin: + trace_detection = traces_chunk[extra_margin:-extra_margin] + else: + trace_detection = traces_chunk + node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) + # set sample index to local + node_output[0]["sample_index"] += extra_margin + elif isinstance(node, PeakSource): + node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice) else: - trace_detection = traces_chunk - node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) - # set sample index to local - node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakSource): - node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) - else: - # TODO later when in master: change the signature of all nodes (or maybe not!) - node_output = node.compute(traces_chunk, *node_input_args) - pipeline_outputs[node] = node_output - - # propagate the output - pipeline_outputs_tuple = tuple() - for node in nodes: - # handle which buffer are given to the output - # this is controlled by node.return_output being a bool or tuple of bool - out = pipeline_outputs[node] - if isinstance(out, tuple): - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += out - elif isinstance(node.return_output, tuple): - for flag, e in zip(node.return_output, out): - if flag: - pipeline_outputs_tuple += (e,) - else: - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += (out,) - elif isinstance(node.return_output, tuple): - # this should not apppend : maybe a checker somewhere before ? - pass + # TODO later when in master: change the signature of all nodes (or maybe not!) + node_output = node.compute(traces_chunk, *node_input_args) + pipeline_outputs[node] = node_output + + if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource): + worker_ctx["num_peaks"] += node_output[0].size + + # propagate the output + pipeline_outputs_tuple = tuple() + for node in nodes: + # handle which buffers are given to the output + # this is controlled by node.return_output being a bool or tuple of bool + out = pipeline_outputs[node] + if isinstance(out, tuple): + if isinstance(node.return_output, bool) and node.return_output: + pipeline_outputs_tuple += out + elif isinstance(node.return_output, tuple): + for flag, e in zip(node.return_output, out): + if flag: + pipeline_outputs_tuple += (e,) + else: + if isinstance(node.return_output, bool) and node.return_output: + pipeline_outputs_tuple += (out,) + elif isinstance(node.return_output, tuple): + # this should not happen : maybe a checker somewhere before ?
+ pass + + if isinstance(nodes[0], PeakDetector): + # the first out element is the peak vector + # we need to go back to absolute sample index + pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin - if isinstance(nodes[0], PeakDetector): - # the first out element is the peak vector - # we need to go back to absolut sample index - pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin + return pipeline_outputs_tuple - return pipeline_outputs_tuple + else: + # the gather will skip this output and not concatenate it + return class GatherToMemory: @@ -595,6 +692,9 @@ def __init__(self): self.tuple_mode = None def __call__(self, res): + if res is None: + return + if self.tuple_mode is None: # first loop only self.tuple_mode = isinstance(res, tuple) @@ -655,6 +755,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.final_shapes.append(None) def __call__(self, res): + if res is None: + return + if self.tuple_mode is None: # first loop only self.tuple_mode = isinstance(res, tuple) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 8d788acbad..deef2291c6 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -83,8 +83,12 @@ def test_run_node_pipeline(cache_folder_creation): extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + # print(peaks.size) peak_retriever = PeakRetriever(recording, peaks) + # this tests the case where there are no spikes in the last chunks + peak_retriever_few = PeakRetriever(recording, peaks[: peaks.size // 2]) + # channel index is from template spike_retriever_T = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) @@ -100,7 +104,7 @@ ) - # test with 3 differents first nodes + # test with 4 different first nodes - for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): + for loop, peak_source in enumerate((peak_retriever, peak_retriever_few, spike_retriever_T, spike_retriever_S)): # one step only : squeeze output nodes = [ peak_source, @@ -139,10 +143,12 @@ num_peaks = peaks.shape[0] num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels # gather npy mode @@ -185,5 +191,38 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) + + +def test_skip_after_n_peaks(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) + + # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) + + spikes = sorting.to_spike_vector() + + # create peaks from spikes + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + +
peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + # print(peaks.size) + + node0 = PeakRetriever(recording, peaks) + node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) + nodes = [node0, node1] + + skip_after_n_peaks = 30 + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks + ) + + assert some_amplitudes.size >= skip_after_n_peaks + assert some_amplitudes.size < spikes.size + + +# the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": - test_run_node_pipeline() + # folder = Path("./cache_folder/core") + # test_run_node_pipeline(folder) + + test_skip_after_n_peaks()
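
Usage sketch for the new two-step PeakSource contract introduced above (it mirrors the check done in _compute_peak_pipeline_chunk). This assumes an existing recording and a peaks vector with base_peak_dtype; the segment index and frame values are only illustrative:

    from spikeinterface.core.node_pipeline import PeakRetriever

    node0 = PeakRetriever(recording, peaks)
    # cheap step: locate the peaks that fall into this chunk without touching the traces
    i0, i1 = node0.get_peak_slice(segment_index=0, start_frame=0, end_frame=30000, max_margin=0)
    if i0 < i1:
        # only load traces and run compute() when the chunk actually contains peaks,
        # passing the precomputed slice back in
        traces = recording.get_traces(segment_index=0, start_frame=0, end_frame=30000)
        node_output = node0.compute(traces, 0, 30000, 0, 0, (i0, i1))  # the peaks restricted to this chunk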