nodepipeline : skip chunks when no peaks inside and skip_after_n_peaks #3356
@@ -1,22 +1,6 @@
"""
Pipeline on spikes/peaks/detected peaks

Functions that can be chained:
  * after peak detection
  * already detected peaks
  * spikes (labeled peaks)
to compute some additional features on-the-fly:
  * peak localization
  * peak-to-peak
  * pca
  * amplitude
  * amplitude scaling
  * ...

There are three ways of using these "plugin nodes":
  * during `peak_detect()`
  * when peaks are already detected and reduced with `select_peaks()`
  * on a sorting object
"""

from __future__ import annotations
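For context (not part of this diff): a feature node in this pipeline only has to implement `get_trace_margin()`, `get_dtype()` and `compute()`. Below is a minimal sketch, assuming the `PipelineNode` base class and the `spikeinterface.core.node_pipeline` import path; the class name, the `ms_window` parameter and the peak-to-peak logic are purely illustrative, not part of the PR.

```python
import numpy as np
from spikeinterface.core.node_pipeline import PipelineNode

class PeakToPeakNode(PipelineNode):
    """Hypothetical node: peak-to-peak amplitude in a small window around each peak."""

    def __init__(self, recording, parents=None, return_output=True, ms_window=1.0):
        super().__init__(recording, parents=parents, return_output=return_output)
        # half-window in samples around each peak
        self.nbefore = int(ms_window / 2.0 * recording.sampling_frequency / 1000.0)

    def get_trace_margin(self):
        # extra samples needed on each side of a chunk
        return self.nbefore

    def get_dtype(self):
        return np.dtype("float32")

    def compute(self, traces, peaks):
        # 'peaks' is the output of the parent PeakSource node
        ptp = np.zeros(peaks.size, dtype="float32")
        for i, peak in enumerate(peaks):
            s = peak["sample_index"]
            sub = traces[s - self.nbefore : s + self.nbefore + 1, peak["channel_index"]]
            ptp[i] = np.ptp(sub)
        return ptp
```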
@@ -103,6 +87,9 @@ def get_trace_margin(self):
    def get_dtype(self):
        return base_peak_dtype

    def get_peak_slice(self, segment_index, start_frame, end_frame):
        # not needed for PeakDetector
        raise NotImplementedError


# this is used in sorting components
class PeakDetector(PeakSource):
@@ -127,11 +114,18 @@ def get_trace_margin(self):
    def get_dtype(self):
        return base_peak_dtype

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
        # get local peaks
    def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
        sl = self.segment_slices[segment_index]
        peaks_in_segment = self.peaks[sl]
        i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
        return i0, i1

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
        # get local peaks
        sl = self.segment_slices[segment_index]
        peaks_in_segment = self.peaks[sl]
        # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
        i0, i1 = peak_slice
        local_peaks = peaks_in_segment[i0:i1]

        # make sample index local to traces
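The `get_peak_slice()` trick relies on peaks being sorted by sample index within a segment, so `np.searchsorted` finds the half-open range of peaks inside a chunk in O(log n). A self-contained toy example of the same pattern (values made up, not from the PR):

```python
import numpy as np

# toy peak vector, sorted by "sample_index" (same invariant as in the pipeline)
peaks = np.zeros(6, dtype=[("sample_index", "int64"), ("channel_index", "int64")])
peaks["sample_index"] = [10, 50, 120, 300, 450, 900]

start_frame, end_frame = 100, 500
i0, i1 = np.searchsorted(peaks["sample_index"], [start_frame, end_frame])
local_peaks = peaks[i0:i1]

# an empty slice (i0 == i1) means the chunk holds no peaks, so the worker
# can skip loading traces for that chunk entirely
print(i0, i1, local_peaks["sample_index"])  # -> 2 5 [120 300 450]
```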
@@ -212,8 +206,7 @@ def get_trace_margin(self):
    def get_dtype(self):
        return self._dtype

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
        # get local peaks
    def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
        sl = self.segment_slices[segment_index]
        peaks_in_segment = self.peaks[sl]
        if self.include_spikes_in_margin:

@@ -222,6 +215,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
            )
        else:
            i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
        return i0, i1

    def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
        # get local peaks
        sl = self.segment_slices[segment_index]
        peaks_in_segment = self.peaks[sl]
        # if self.include_spikes_in_margin:
        #     i0, i1 = np.searchsorted(
        #         peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
        #     )
        # else:
        #     i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
        i0, i1 = peak_slice

        local_peaks = peaks_in_segment[i0:i1]

        # make sample index local to traces
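For the `include_spikes_in_margin` branch, the same search is simply widened by `max_margin` on both sides, so spikes whose waveforms straddle the chunk edges are retrieved too. A toy illustration with made-up values:

```python
import numpy as np

spikes = np.zeros(4, dtype=[("sample_index", "int64")])
spikes["sample_index"] = [90, 150, 480, 520]

start_frame, end_frame, max_margin = 100, 500, 20
i0, i1 = np.searchsorted(
    spikes["sample_index"], [start_frame - max_margin, end_frame + max_margin]
)
# the widened window [80, 520) also catches the spike at sample 90,
# which a strict [100, 500) window would have missed
print(spikes["sample_index"][i0:i1])  # -> [ 90 150 480]
```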
@@ -467,31 +474,91 @@ def run_node_pipeline(
    nodes,
    job_kwargs,
    job_name="pipeline",
    mp_context=None,
    # mp_context=None,
    gather_mode="memory",
    gather_kwargs={},
    squeeze_output=True,
    folder=None,
    names=None,
    verbose=False,
    skip_after_n_peaks=None,
):
    """
    Common function to run a pipeline with a peak detector or already-detected peaks.
    Machinery to compute, in parallel, operations on peaks and traces.

    This is useful in several use cases:
      * in sortingcomponents : detect peaks and make some computations on them (localize, pca, ...)
      * in sortingcomponents : replay some peaks and make some computations on them (localize, pca, ...)
[Review comment] replay?
[Reply] It is like a replay of spikes by buffer. Maybe another term would be better; I found this image easier.
[Reply] Unless this is extremely technical, I don't think we can quite say this. It sounds like you are pulling sound-engineering terminology into the spike world, like you're remixing your data analysis. I think I get the vibe, but I'm not sure of a better word. We can leave it for now, and if I think of something I can put in a PR patch later.
      * in postprocessing : replay some spikes and make some computations on them (localize, pca, ...)

    Here a "peak" is a spike without a label, just "detected".
    Here a "spike" is a peak with a label, so already sorted.
[Review comment on lines +500 to +501] I think this is super helpful. Nice comment!
    The main idea is to have a graph of nodes.
    Every node does a computation on some peaks and the related traces.
    The first node is a PeakSource, so either a peak detector (PeakDetector) or a peak/spike replay (PeakRetriever/SpikeRetriever).

    Every node can have one or several outputs that can be directed to other nodes (aka nodes have parents).

    Every node can optionally have a global output that will be gathered by the main process.
    This is controlled by return_output=True.

    The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector.
    These vectors can be in "memory" or in files ("npy").

    Parameters
    ----------
    recording: Recording

    nodes: a list of PipelineNode

    job_kwargs: dict
        The classical job_kwargs
    job_name : str
        The name of the pipeline used for the progress_bar
    gather_mode : "memory" | "npy"

    gather_kwargs : dict
        Options to control the "gather engine". See GatherToMemory or GatherToNpy.
    squeeze_output : bool, default True
        If only one output node then squeeze the tuple
    folder : str | Path | None
        Used for gather_mode="npy"
    names : list of str
        Names of outputs.
    verbose : bool, default False
        Verbosity.
    skip_after_n_peaks : None | int
        Skip the computation after n_peaks.
        This is not exact, because internally the skip is done per worker, on average.

    Returns
    -------
    outputs: tuple of np.array | np.array
        A tuple of vectors, one for each node having return_output=True.
        If squeeze_output=True and there is only one output, then directly an np.array.
    """

    check_graph(nodes)

    job_kwargs = fix_job_kwargs(job_kwargs)
    assert all(isinstance(node, PipelineNode) for node in nodes)

    if skip_after_n_peaks is not None:
        skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"]
    else:
        skip_after_n_peaks_per_worker = None

    if gather_mode == "memory":
        gather_func = GatherToMemory()
    elif gather_mode == "npy":
        gather_func = GatherToNpy(folder, names, **gather_kwargs)
    else:
        raise ValueError(f"wrong gather_mode : {gather_mode}")

    init_args = (recording, nodes)
    init_args = (recording, nodes, skip_after_n_peaks_per_worker)

    processor = ChunkRecordingExecutor(
        recording,
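To make the new option concrete, here is a hedged usage sketch. The `PeakRetriever` construction and the `job_kwargs` values are illustrative, and `my_feature_node` stands in for any downstream `PipelineNode`; check the SpikeInterface API for exact signatures.

```python
from spikeinterface.core.node_pipeline import run_node_pipeline, PeakRetriever

job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)

peak_source = PeakRetriever(recording, peaks)  # 'peaks' detected earlier
nodes = [peak_source, my_feature_node]         # my_feature_node: hypothetical PipelineNode

features = run_node_pipeline(
    recording,
    nodes,
    job_kwargs,
    gather_mode="memory",
    # stop gathering new peaks after roughly 10_000 peaks; each of the 4
    # workers stops after ~2_500, so the cutoff is approximate by design
    skip_after_n_peaks=10_000,
)
```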
@@ -510,79 +577,104 @@ def run_node_pipeline(
    return outs


def _init_peak_pipeline(recording, nodes):
def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
    # create a local dict per worker
    worker_ctx = {}
    worker_ctx["recording"] = recording
    worker_ctx["nodes"] = nodes
    worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
    worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
    worker_ctx["num_peaks"] = 0
    return worker_ctx


def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
    recording = worker_ctx["recording"]
    max_margin = worker_ctx["max_margin"]
    nodes = worker_ctx["nodes"]
    skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

    recording_segment = recording._recording_segments[segment_index]
    traces_chunk, left_margin, right_margin = get_chunk_with_margin(
        recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
    )
    node0 = nodes[0]

    if isinstance(node0, (SpikeRetriever, PeakRetriever)):
        # in this case the PeakSource can have no peaks, so there is no need to load traces: just skip
        peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
        load_trace_and_compute = i0 < i1
    else:
        # PeakDetector always needs traces
        load_trace_and_compute = True
    # compute the graph
    pipeline_outputs = {}
    for node in nodes:
        node_parents = node.parents if node.parents else list()
        node_input_args = tuple()
        for parent in node_parents:
            parent_output = pipeline_outputs[parent]
            parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
            node_input_args += parent_outputs_tuple
        if isinstance(node, PeakDetector):
            # to handle compatibility, the peak detector is a special case
            # with a specific margin
            # TODO later when in master: change this later
            extra_margin = max_margin - node.get_trace_margin()
            if extra_margin:
                trace_detection = traces_chunk[extra_margin:-extra_margin]
    if skip_after_n_peaks_per_worker is not None:
        if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker:
            load_trace_and_compute = False
    if load_trace_and_compute:
        traces_chunk, left_margin, right_margin = get_chunk_with_margin(
            recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
        )
        # compute the graph
        pipeline_outputs = {}
        for node in nodes:
            node_parents = node.parents if node.parents else list()
            node_input_args = tuple()
            for parent in node_parents:
                parent_output = pipeline_outputs[parent]
                parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
                node_input_args += parent_outputs_tuple
            if isinstance(node, PeakDetector):
                # to handle compatibility, the peak detector is a special case
                # with a specific margin
                # TODO later when in master: change this later
                extra_margin = max_margin - node.get_trace_margin()
                if extra_margin:
                    trace_detection = traces_chunk[extra_margin:-extra_margin]
                else:
                    trace_detection = traces_chunk
                node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
                # set sample index to local
                node_output[0]["sample_index"] += extra_margin
            elif isinstance(node, PeakSource):
                node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
        else:
            trace_detection = traces_chunk
            node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
            # set sample index to local
            node_output[0]["sample_index"] += extra_margin
        elif isinstance(node, PeakSource):
            node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
        else:
            # TODO later when in master: change the signature of all nodes (or maybe not!)
            node_output = node.compute(traces_chunk, *node_input_args)
        pipeline_outputs[node] = node_output
    # propagate the output
    pipeline_outputs_tuple = tuple()
    for node in nodes:
        # handle which buffers are given to the output
        # this is controlled by node.return_output being a bool or a tuple of bools
        out = pipeline_outputs[node]
        if isinstance(out, tuple):
            if isinstance(node.return_output, bool) and node.return_output:
                pipeline_outputs_tuple += out
            elif isinstance(node.return_output, tuple):
                for flag, e in zip(node.return_output, out):
                    if flag:
                        pipeline_outputs_tuple += (e,)
        else:
            if isinstance(node.return_output, bool) and node.return_output:
                pipeline_outputs_tuple += (out,)
            elif isinstance(node.return_output, tuple):
                # this should not happen: maybe add a check somewhere before?
                pass
            else:
                # TODO later when in master: change the signature of all nodes (or maybe not!)
                node_output = node.compute(traces_chunk, *node_input_args)
            pipeline_outputs[node] = node_output

            if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
                worker_ctx["num_peaks"] += node_output[0].size
        # propagate the output
        pipeline_outputs_tuple = tuple()
        for node in nodes:
            # handle which buffers are given to the output
            # this is controlled by node.return_output being a bool or a tuple of bools
            out = pipeline_outputs[node]
            if isinstance(out, tuple):
                if isinstance(node.return_output, bool) and node.return_output:
                    pipeline_outputs_tuple += out
                elif isinstance(node.return_output, tuple):
                    for flag, e in zip(node.return_output, out):
                        if flag:
                            pipeline_outputs_tuple += (e,)
            else:
                if isinstance(node.return_output, bool) and node.return_output:
                    pipeline_outputs_tuple += (out,)
                elif isinstance(node.return_output, tuple):
                    # this should not happen: maybe add a check somewhere before?
                    pass

    if isinstance(nodes[0], PeakDetector):
        # the first out element is the peak vector
        # we need to go back to absolute sample index
        pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin
        if isinstance(nodes[0], PeakDetector):
            # the first out element is the peak vector
            # we need to go back to absolute sample index
            pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin

        return pipeline_outputs_tuple

    else:
        # the gather will skip this output and not concatenate it
        return

    return pipeline_outputs_tuple
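The propagation loop accepts `return_output` as a single bool, or as a tuple of bools for nodes whose `compute()` returns several buffers. A stand-alone toy illustration of that convention (stand-in arrays, not real pipeline objects):

```python
import numpy as np

node_output = (np.array([1.0, 2.0]), np.array([3.0, 4.0]))  # e.g. (amplitudes, locations)
return_output = (False, True)  # expose only the second buffer to the gather

pipeline_outputs_tuple = tuple()
if isinstance(node_output, tuple) and isinstance(return_output, tuple):
    for flag, buf in zip(return_output, node_output):
        if flag:
            pipeline_outputs_tuple += (buf,)

print(len(pipeline_outputs_tuple))  # -> 1
```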
class GatherToMemory:

@@ -595,6 +687,9 @@ def __init__(self):
        self.tuple_mode = None

    def __call__(self, res):
        if res is None:
            return

        if self.tuple_mode is None:
            # first loop only
            self.tuple_mode = isinstance(res, tuple)

@@ -655,6 +750,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
        self.final_shapes.append(None)

    def __call__(self, res):
        if res is None:
            return

        if self.tuple_mode is None:
            # first loop only
            self.tuple_mode = isinstance(res, tuple)
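The `if res is None` early return in both gather callables is the other half of the skip mechanism: a worker returns None for a chunk it skipped (no peaks inside, or past `skip_after_n_peaks`), and the gather simply ignores it instead of concatenating. A minimal sketch of that contract with a toy gather class (not the real `GatherToMemory`/`GatherToNpy`):

```python
import numpy as np

class ToyGather:
    def __init__(self):
        self.pieces = []

    def __call__(self, res):
        if res is None:
            # chunk was skipped by the worker: nothing to concatenate
            return
        self.pieces.append(res)

gather = ToyGather()
for chunk_result in [np.arange(3), None, np.arange(2), None]:
    gather(chunk_result)

print(np.concatenate(gather.pieces))  # -> [0 1 2 0 1]
```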
[Review comment on mp_context] Shouldn't we leave this for people that use positional arguments?
[Reply] I think it was a mistake, because it should be in job_kwargs, no? Alessio, is it you that added this mp_context directly here?
[Reply] I think it was you @samuelgarcia