nodepipeline : skip chunks when no peaks inside and skip_after_n_peaks #3356

Merged: yger merged 9 commits into SpikeInterface:main from samuelgarcia:node_pipeline_skip_no_peaks on Oct 7, 2024 (+226 −84)
Commits (9), changes shown from all commits:
f3fbd5d  nodepipeline : skip chunks when no peaks inside  (samuelgarcia)
b1d726a  Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …  (samuelgarcia)
6590e0f  nodepipeline add skip_after_n_peaks option  (samuelgarcia)
9111c13  make Zach happy  (samuelgarcia)
1f52715  Merci Zach  (samuelgarcia)
56975a2  Merge branch 'main' into node_pipeline_skip_no_peaks  (samuelgarcia)
28476cc  Merge branch 'main' into node_pipeline_skip_no_peaks  (samuelgarcia)
c4eb8a5  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
5d84f6c  Merge branch 'main' into node_pipeline_skip_no_peaks  (samuelgarcia)
@@ -1,22 +1,6 @@
 """
 Pipeline on spikes/peaks/detected peaks
 
-Functions that can be chained:
-  * after peak detection
-  * already detected peaks
-  * spikes (labeled peaks)
-to compute some additional features on-the-fly:
-  * peak localization
-  * peak-to-peak
-  * pca
-  * amplitude
-  * amplitude scaling
-  * ...
-
-There are two ways for using theses "plugin nodes":
-  * during `peak_detect()`
-  * when peaks are already detected and reduced with `select_peaks()`
-  * on a sorting object
 
 
 """
 
 from __future__ import annotations
@@ -103,6 +87,15 @@ def get_trace_margin(self):
     def get_dtype(self):
         return base_peak_dtype
 
+    def get_peak_slice(
+        self,
+        segment_index,
+        start_frame,
+        end_frame,
+    ):
+        # not needed for PeakDetector
+        raise NotImplementedError
+
 
 # this is used in sorting components
 class PeakDetector(PeakSource):
@@ -127,11 +120,18 @@ def get_trace_margin(self):
     def get_dtype(self):
         return base_peak_dtype
 
-    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
-        # get local peaks
+    def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
         sl = self.segment_slices[segment_index]
         peaks_in_segment = self.peaks[sl]
         i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        return i0, i1
+
+    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
+        # get local peaks
+        sl = self.segment_slices[segment_index]
+        peaks_in_segment = self.peaks[sl]
+        # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        i0, i1 = peak_slice
         local_peaks = peaks_in_segment[i0:i1]
 
         # make sample index local to traces
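Editor's note: the slice that get_peak_slice returns comes from a single np.searchsorted call. A minimal standalone sketch (toy values, not SpikeInterface code) shows why an empty slice (i0 == i1) is a cheap "no peaks in this chunk" test:

import numpy as np

# peaks are sorted by "sample_index", so the peaks inside a chunk
# [start_frame, end_frame) form one contiguous slice
peaks = np.zeros(5, dtype=[("sample_index", "int64"), ("channel_index", "int64")])
peaks["sample_index"] = [10, 250, 260, 900, 1500]

i0, i1 = np.searchsorted(peaks["sample_index"], [200, 1000])
print(i0, i1)             # 1 4 -> peaks at 250, 260, 900 fall in the chunk
print(peaks[i0:i1].size)  # 3

i0, i1 = np.searchsorted(peaks["sample_index"], [1600, 2000])
print(i0 < i1)            # False -> empty chunk, traces need not be loaded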
@@ -212,8 +212,7 @@ def get_trace_margin(self):
     def get_dtype(self):
         return self._dtype
 
-    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
-        # get local peaks
+    def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
         sl = self.segment_slices[segment_index]
         peaks_in_segment = self.peaks[sl]
         if self.include_spikes_in_margin:
@@ -222,6 +221,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
             )
         else:
             i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        return i0, i1
+
+    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
+        # get local peaks
+        sl = self.segment_slices[segment_index]
+        peaks_in_segment = self.peaks[sl]
+        # if self.include_spikes_in_margin:
+        #     i0, i1 = np.searchsorted(
+        #         peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
+        #     )
+        # else:
+        #     i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
+        i0, i1 = peak_slice
 
         local_peaks = peaks_in_segment[i0:i1]
 
         # make sample index local to traces
@@ -467,31 +480,91 @@ def run_node_pipeline(
     nodes,
     job_kwargs,
     job_name="pipeline",
-    mp_context=None,
+    # mp_context=None,
     gather_mode="memory",
     gather_kwargs={},
     squeeze_output=True,
     folder=None,
     names=None,
+    verbose=False,
+    skip_after_n_peaks=None,
 ):
     """
-    Common function to run pipeline with peak detector or already detected peak.
+    Machinery to compute in parallel operations on peaks and traces.
+
+    This is useful in several use cases:
+      * in sortingcomponents : detect peaks and make some computations on them (localize, pca, ...)
+      * in sortingcomponents : replay some peaks and make some computations on them (localize, pca, ...)
+      * postprocessing : replay some spikes and make some computations on them (localize, pca, ...)
+
+    Here a "peak" is a spike without a label, just "detected".
+    Here a "spike" is a spike with a label, so already sorted.

    [Inline review comment on lines +500 to +501]
    Reviewer: I think this is super helpful. Nice comment!

+
+    The main idea is to have a graph of nodes.
+    Every node is doing a computation on some peaks and the related traces.
+    The first node is a PeakSource, so either a peak detector (PeakDetector) or a peak/spike replay (PeakRetriever/SpikeRetriever).
+
+    Every node can have one or several outputs that can be directed to other nodes (aka nodes have parents).
+
+    Every node can optionally have a global output that will be gathered by the main process.
+    This is controlled by return_output=True.
+
+    The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector.
+    These vectors can be in "memory" or in files ("npy").
+
     Parameters
     ----------
     recording: Recording
     nodes: a list of PipelineNode
     job_kwargs: dict
         The classical job_kwargs
     job_name : str
         The name of the pipeline used for the progress_bar
     gather_mode : "memory" | "npy"
     gather_kwargs : dict
         Options to control the "gather engine". See GatherToMemory or GatherToNpy.
     squeeze_output : bool, default True
         If only one output node then squeeze the tuple
     folder : str | Path | None
         Used for gather_mode="npy"
     names : list of str
         Names of outputs.
+    verbose : bool, default False
+        Verbosity.
+    skip_after_n_peaks : None | int
+        Skip the computation after n_peaks.
+        This is not exact because internally the skip is done per worker, so it only holds on average.
 
     Returns
     -------
     outputs: tuple of np.array | np.array
         a tuple of vectors for the outputs of nodes having return_output=True.
         If squeeze_output=True and there is only one output then directly np.array.
     """
 
     check_graph(nodes)
 
     job_kwargs = fix_job_kwargs(job_kwargs)
     assert all(isinstance(node, PipelineNode) for node in nodes)
 
+    if skip_after_n_peaks is not None:
+        skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"]
+    else:
+        skip_after_n_peaks_per_worker = None
+
     if gather_mode == "memory":
         gather_func = GatherToMemory()
     elif gather_mode == "npy":
         gather_func = GatherToNpy(folder, names, **gather_kwargs)
     else:
         raise ValueError(f"wrong gather_mode : {gather_mode}")
 
-    init_args = (recording, nodes)
+    init_args = (recording, nodes, skip_after_n_peaks_per_worker)
 
     processor = ChunkRecordingExecutor(
         recording,
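Editor's note: a hedged usage sketch of the API described in the docstring above. It assumes a `recording` and a `peaks` vector already exist; PeakRetriever and run_node_pipeline are the names used in this diff, while ExtractDenseWaveforms and its exact signature are assumptions taken from the same module in SpikeInterface and may differ in your version.

from spikeinterface.core.node_pipeline import (
    PeakRetriever,
    ExtractDenseWaveforms,
    run_node_pipeline,
)

# first node: "replay" peaks that were already detected
peak_source = PeakRetriever(recording, peaks)

# second node: compute a feature (dense waveforms) for each replayed peak
waveforms = ExtractDenseWaveforms(
    recording, parents=[peak_source], ms_before=1.0, ms_after=1.5, return_output=True
)

job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)

wfs = run_node_pipeline(
    recording,
    [peak_source, waveforms],
    job_kwargs,
    job_name="waveforms on replayed peaks",
    gather_mode="memory",
    # stop (approximately) after 10_000 peaks; the bound is enforced
    # per worker as skip_after_n_peaks / n_jobs, so it is not exact
    skip_after_n_peaks=10_000,
)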
@@ -510,79 +583,103 @@ def run_node_pipeline(
     return outs
 
 
-def _init_peak_pipeline(recording, nodes):
+def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
     # create a local dict per worker
     worker_ctx = {}
     worker_ctx["recording"] = recording
     worker_ctx["nodes"] = nodes
     worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
+    worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
+    worker_ctx["num_peaks"] = 0
     return worker_ctx
 
 
 def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
     recording = worker_ctx["recording"]
     max_margin = worker_ctx["max_margin"]
     nodes = worker_ctx["nodes"]
+    skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]
 
     recording_segment = recording._recording_segments[segment_index]
-    traces_chunk, left_margin, right_margin = get_chunk_with_margin(
-        recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
-    )
+    node0 = nodes[0]
 
-    # compute the graph
-    pipeline_outputs = {}
-    for node in nodes:
-        node_parents = node.parents if node.parents else list()
-        node_input_args = tuple()
-        for parent in node_parents:
-            parent_output = pipeline_outputs[parent]
-            parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
-            node_input_args += parent_outputs_tuple
-        if isinstance(node, PeakDetector):
-            # to handle compatibility peak detector is a special case
-            # with specific margin
-            # TODO later when in master: change this later
-            extra_margin = max_margin - node.get_trace_margin()
-            if extra_margin:
-                trace_detection = traces_chunk[extra_margin:-extra_margin]
-            else:
-                trace_detection = traces_chunk
-            node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
-            # set sample index to local
-            node_output[0]["sample_index"] += extra_margin
-        elif isinstance(node, PeakSource):
-            node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
-        else:
-            # TODO later when in master: change the signature of all nodes (or maybe not!)
-            node_output = node.compute(traces_chunk, *node_input_args)
-        pipeline_outputs[node] = node_output
+    if isinstance(node0, (SpikeRetriever, PeakRetriever)):
+        # in this case the PeakSource can have no peaks in the chunk, so there is no need to load traces: just skip
+        peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
+        load_trace_and_compute = i0 < i1
+    else:
+        # a PeakDetector always needs traces
+        load_trace_and_compute = True
+
+    if skip_after_n_peaks_per_worker is not None:
+        if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker:
+            load_trace_and_compute = False
 
-    # propagate the output
-    pipeline_outputs_tuple = tuple()
-    for node in nodes:
-        # handle which buffer are given to the output
-        # this is controlled by node.return_output being a bool or tuple of bool
-        out = pipeline_outputs[node]
-        if isinstance(out, tuple):
-            if isinstance(node.return_output, bool) and node.return_output:
-                pipeline_outputs_tuple += out
-            elif isinstance(node.return_output, tuple):
-                for flag, e in zip(node.return_output, out):
-                    if flag:
-                        pipeline_outputs_tuple += (e,)
-        else:
-            if isinstance(node.return_output, bool) and node.return_output:
-                pipeline_outputs_tuple += (out,)
-            elif isinstance(node.return_output, tuple):
-                # this should not apppend : maybe a checker somewhere before ?
-                pass
+    if load_trace_and_compute:
+        traces_chunk, left_margin, right_margin = get_chunk_with_margin(
+            recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
+        )
+        # compute the graph
+        pipeline_outputs = {}
+        for node in nodes:
+            node_parents = node.parents if node.parents else list()
+            node_input_args = tuple()
+            for parent in node_parents:
+                parent_output = pipeline_outputs[parent]
+                parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
+                node_input_args += parent_outputs_tuple
+            if isinstance(node, PeakDetector):
+                # to handle compatibility peak detector is a special case
+                # with specific margin
+                # TODO later when in master: change this later
+                extra_margin = max_margin - node.get_trace_margin()
+                if extra_margin:
+                    trace_detection = traces_chunk[extra_margin:-extra_margin]
+                else:
+                    trace_detection = traces_chunk
+                node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
+                # set sample index to local
+                node_output[0]["sample_index"] += extra_margin
+            elif isinstance(node, PeakSource):
+                node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
+            else:
+                # TODO later when in master: change the signature of all nodes (or maybe not!)
+                node_output = node.compute(traces_chunk, *node_input_args)
+            pipeline_outputs[node] = node_output
 
-    if isinstance(nodes[0], PeakDetector):
-        # the first out element is the peak vector
-        # we need to go back to absolut sample index
-        pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin
+            if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
+                worker_ctx["num_peaks"] += node_output[0].size
 
-    return pipeline_outputs_tuple
+        # propagate the output
+        pipeline_outputs_tuple = tuple()
+        for node in nodes:
+            # handle which buffers are given to the output
+            # this is controlled by node.return_output being a bool or tuple of bool
+            out = pipeline_outputs[node]
+            if isinstance(out, tuple):
+                if isinstance(node.return_output, bool) and node.return_output:
+                    pipeline_outputs_tuple += out
+                elif isinstance(node.return_output, tuple):
+                    for flag, e in zip(node.return_output, out):
+                        if flag:
+                            pipeline_outputs_tuple += (e,)
+            else:
+                if isinstance(node.return_output, bool) and node.return_output:
+                    pipeline_outputs_tuple += (out,)
+                elif isinstance(node.return_output, tuple):
+                    # this should not happen : maybe add a checker somewhere before ?
+                    pass
+
+        if isinstance(nodes[0], PeakDetector):
+            # the first out element is the peak vector
+            # we need to go back to absolute sample index
+            pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin
+
+        return pipeline_outputs_tuple
+    else:
+        # the gather will skip this output and not concatenate it
+        return
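Editor's note: a toy rendering (plain Python, illustrative numbers) of the per-worker bookkeeping above. Each worker compares its own counter to skip_after_n_peaks / n_jobs before deciding to process a chunk, which is why the global cutoff only holds on average.

skip_after_n_peaks = 10_000
n_jobs = 4
skip_after_n_peaks_per_worker = skip_after_n_peaks / n_jobs  # 2500.0

num_peaks = 0  # one worker's counter, like worker_ctx["num_peaks"]
for peaks_in_chunk in [800, 900, 1000, 700]:
    if num_peaks > skip_after_n_peaks_per_worker:
        break  # subsequent chunks are skipped: no traces loaded, no compute
    num_peaks += peaks_in_chunk

print(num_peaks)  # 2700: the worker overshoots 2500 before it stops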
class GatherToMemory:

@@ -595,6 +692,9 @@ def __init__(self):
         self.tuple_mode = None
 
     def __call__(self, res):
+        if res is None:
+            return
+
         if self.tuple_mode is None:
             # first loop only
             self.tuple_mode = isinstance(res, tuple)
@@ -655,6 +755,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
         self.final_shapes.append(None)
 
     def __call__(self, res):
+        if res is None:
+            return
+
         if self.tuple_mode is None:
             # first loop only
             self.tuple_mode = isinstance(res, tuple)
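Editor's note: a minimal sketch of the gather contract added in the two hunks above. Skipped chunks return None, and the gather callable must ignore them instead of trying to concatenate. This is a toy stand-in, not the real GatherToMemory.

import numpy as np

class ToyGather:
    def __init__(self):
        self.pieces = []

    def __call__(self, res):
        if res is None:
            return  # skipped chunk: nothing to accumulate
        self.pieces.append(res)

    def finalize(self):
        return np.concatenate(self.pieces)

gather = ToyGather()
for res in [np.arange(3), None, np.arange(2)]:  # one chunk had no peaks
    gather(res)
print(gather.finalize())  # [0 1 2 0 1]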
Review conversation

Reviewer: replay?

Author: It is like a replay of the spikes, buffer by buffer. Maybe another term would be better; I found this image easier.

Reviewer: Unless this is extremely technical I don't think we could quite say this. It sounds more like you are pulling some sound engineering terminology into the spike world, like you're remixing your data analysis. I think I get the vibe, but I'm not sure of a better word. We can leave it for now and if I think of something I can put in a PR patch later.