nodepipeline : skip chunks when no peaks inside and skip_after_n_peaks #3356

Merged: 9 commits, Oct 7, 2024
258 changes: 178 additions & 80 deletions src/spikeinterface/core/node_pipeline.py
@@ -1,22 +1,6 @@
"""
Pipeline on spikes/peaks/detected peaks

Functions that can be chained:
* after peak detection
* already detected peaks
* spikes (labeled peaks)
to compute some additional features on-the-fly:
* peak localization
* peak-to-peak
* pca
* amplitude
* amplitude scaling
* ...

There are several ways to use these "plugin nodes":
* during `peak_detect()`
* when peaks are already detected and reduced with `select_peaks()`
* on a sorting object


"""

from __future__ import annotations
@@ -103,6 +87,9 @@ def get_trace_margin(self):
def get_dtype(self):
return base_peak_dtype

def get_peak_slice(self, segment_index, start_frame, end_frame):
# not needed for PeakDetector
raise NotImplementedError

# this is used in sorting components
class PeakDetector(PeakSource):
@@ -127,11 +114,18 @@ def get_trace_margin(self):
def get_dtype(self):
return base_peak_dtype

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
return i0, i1
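
This slicing is plain np.searchsorted on the sorted "sample_index" field; a minimal standalone sketch, with a hypothetical peaks array, behaves like this:

import numpy as np

# hypothetical peaks array with the same sorted "sample_index" field
peaks = np.zeros(5, dtype=[("sample_index", "int64")])
peaks["sample_index"] = [10, 40, 55, 90, 130]

start_frame, end_frame = 50, 100
i0, i1 = np.searchsorted(peaks["sample_index"], [start_frame, end_frame])
local_peaks = peaks[i0:i1]  # the peaks at samples 55 and 90
# an empty slice (i0 == i1) means the chunk contains no peaks at all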

def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
# i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
i0, i1 = peak_slice
local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
@@ -212,8 +206,7 @@ def get_trace_margin(self):
def get_dtype(self):
return self._dtype

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# get local peaks
def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin):
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
if self.include_spikes_in_margin:
@@ -222,6 +215,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
)
else:
i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
return i0, i1

def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice):
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
# if self.include_spikes_in_margin:
# i0, i1 = np.searchsorted(
# peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
# )
# else:
# i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
i0, i1 = peak_slice

local_peaks = peaks_in_segment[i0:i1]

# make sample index local to traces
@@ -467,31 +474,91 @@ def run_node_pipeline(
nodes,
job_kwargs,
job_name="pipeline",
mp_context=None,
#mp_context=None,
Collaborator:

Shouldn't we leave this for people that use positional arguments?

Member Author:

I think it was a mistake, because it should be in job_kwargs, no?
Alessio, is it you who added this mp_context directly here?

Member:

I think it was you @samuelgarcia

gather_mode="memory",
gather_kwargs={},
squeeze_output=True,
folder=None,
names=None,
verbose=False,
skip_after_n_peaks=None,
):
"""
Common function to run a pipeline with a peak detector or already-detected peaks.
Machinery to compute operations on peaks and traces in parallel.

This is useful in several use cases:
* in sortingcomponents : detect peaks and make some computations on them (localize, pca, ...)
* in sortingcomponents : replay some peaks and make some computations on them (localize, pca, ...)
Collaborator:

replay?

Member Author:

It is like a replay of spikes by buffer. Maybe another term would be better; I found this image easier.

Collaborator:

Unless this is extremely technical I don't think we could quite say this. It sounds more like you are pulling some sound engineering terminology into the spike world. Like you're remixing your data analysis. I think I get the vibe, but I'm not sure of a better word. We can leave it for now and if I think of something I can put in a PR patch later.

* postprocessing : replay some spikes and make some computations on them (localize, pca, ...)

Here a "peak" is a spike without any labels just a "detected".
Here a "spike" is a spike with any a label so already sorted.
Comment on lines +500 to +501

Collaborator:

I think this is super helpful. Nice comment!


The main idea is to have a graph of nodes.
Every node performs a computation on some peaks and the related traces.
The first node is a PeakSource: either a peak detector (PeakDetector) or a peak/spike replay (PeakRetriever/SpikeRetriever).

Every node can have one or several outputs that can be directed to other nodes (aka nodes have parents).

Every node can optionally have a global output that will be gathered by the main process.
This is controlled by return_output = True.

The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector.
These vectors can be in "memory" or in files ("npy")


Parameters
----------

recording: Recording

nodes: a list of PipelineNode

job_kwargs: dict
The classical job_kwargs
job_name : str
The name of the pipeline used for the progress_bar
gather_mode : "memory" | "npz"

gather_kwargs : dict
Options to control the "gather engine". See GatherToMemory or GatherToNpy.
squeeze_output : bool, default True
If True and there is only one output node, squeeze the tuple.
folder : str | Path | None
Used for gather_mode="npz"
names : list of str
Names of outputs.
verbose : bool, default False
Verbosity.
skip_after_n_peaks : None | int
Skip the computation once approximately this many peaks have been processed.
This is not exact, because internally the skip is applied per worker, on average.

Returns
-------
outputs: tuple of np.array | np.array
a tuple of vectors for the outputs of nodes having return_output=True.
If squeeze_output=True and there is only one output, the np.array is returned directly.
"""

check_graph(nodes)

job_kwargs = fix_job_kwargs(job_kwargs)
assert all(isinstance(node, PipelineNode) for node in nodes)

if skip_after_n_peaks is not None:
skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"]
else:
skip_after_n_peaks_per_worker = None
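# e.g. with skip_after_n_peaks=1000 and n_jobs=4, each worker gets a budget
# of 250.0 peaks; a worker keeps processing whole chunks until its own counter
# exceeds that budget, so the total number of gathered peaks is approximate.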

if gather_mode == "memory":
gather_func = GatherToMemory()
elif gather_mode == "npy":
gather_func = GatherToNpy(folder, names, **gather_kwargs)
else:
raise ValueError(f"wrong gather_mode : {gather_mode}")

init_args = (recording, nodes)
init_args = (recording, nodes, skip_after_n_peaks_per_worker)

processor = ChunkRecordingExecutor(
recording,
@@ -510,79 +577,104 @@ def run_node_pipeline(
return outs
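
For context, a hypothetical call using the new argument could look like this; `recording` and the `peaks` vector are assumed to already exist, and a real pipeline would append feature nodes (localization, pca, ...) after the retriever:

from spikeinterface.core.node_pipeline import run_node_pipeline, PeakRetriever

peak_retriever = PeakRetriever(recording, peaks)
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)

outputs = run_node_pipeline(
    recording,
    [peak_retriever],          # feature nodes would be appended here
    job_kwargs,
    gather_mode="memory",
    skip_after_n_peaks=10_000,  # stop early after roughly 10_000 peaks
)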


def _init_peak_pipeline(recording, nodes):
def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker):
# create a local dict per worker
worker_ctx = {}
worker_ctx["recording"] = recording
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)
worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
worker_ctx["num_peaks"] = 0
return worker_ctx


def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
recording = worker_ctx["recording"]
max_margin = worker_ctx["max_margin"]
nodes = worker_ctx["nodes"]
skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

recording_segment = recording._recording_segments[segment_index]
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
)
node0 = nodes[0]

if isinstance(node0, (SpikeRetriever, PeakRetriever)):
# in this case the PeakSource may have no peaks in this chunk, so we can skip loading the traces entirely
peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
load_trace_and_compute = i0 < i1
else:
# PeakDetector always needs traces
load_trace_and_compute = True
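# e.g. for a chunk covering frames [50, 100) and retriever peaks at samples
# [10, 40, 130], get_peak_slice returns (2, 2): the empty slice means this
# chunk is skipped without reading any traces from disk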

# compute the graph
pipeline_outputs = {}
for node in nodes:
node_parents = node.parents if node.parents else list()
node_input_args = tuple()
for parent in node_parents:
parent_output = pipeline_outputs[parent]
parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
node_input_args += parent_outputs_tuple
if isinstance(node, PeakDetector):
# to handle backward compatibility, the peak detector is a special case
# with a specific margin
# TODO when in master: change this later
extra_margin = max_margin - node.get_trace_margin()
if extra_margin:
trace_detection = traces_chunk[extra_margin:-extra_margin]
if skip_after_n_peaks_per_worker is not None:
if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker:
load_trace_and_compute = False

if load_trace_and_compute:
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True
)
# compute the graph
pipeline_outputs = {}
for node in nodes:
node_parents = node.parents if node.parents else list()
node_input_args = tuple()
for parent in node_parents:
parent_output = pipeline_outputs[parent]
parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,)
node_input_args += parent_outputs_tuple
if isinstance(node, PeakDetector):
# to handle backward compatibility, the peak detector is a special case
# with a specific margin
# TODO when in master: change this later
extra_margin = max_margin - node.get_trace_margin()
if extra_margin:
trace_detection = traces_chunk[extra_margin:-extra_margin]
else:
trace_detection = traces_chunk
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
else:
trace_detection = traces_chunk
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
node_output = node.compute(traces_chunk, *node_input_args)
pipeline_outputs[node] = node_output

# propagate the output
pipeline_outputs_tuple = tuple()
for node in nodes:
# handle which buffer are given to the output
# this is controlled by node.return_output being a bool or tuple of bool
out = pipeline_outputs[node]
if isinstance(out, tuple):
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += out
elif isinstance(node.return_output, tuple):
for flag, e in zip(node.return_output, out):
if flag:
pipeline_outputs_tuple += (e,)
else:
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += (out,)
elif isinstance(node.return_output, tuple):
# this should not happen: maybe add a check somewhere earlier?
pass
# TODO later when in master: change the signature of all nodes (or maybe not!)
node_output = node.compute(traces_chunk, *node_input_args)
pipeline_outputs[node] = node_output

if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
worker_ctx["num_peaks"] += node_output[0].size

# propagate the output
pipeline_outputs_tuple = tuple()
for node in nodes:
# handle which buffer are given to the output
# this is controlled by node.return_output being a bool or tuple of bool
out = pipeline_outputs[node]
if isinstance(out, tuple):
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += out
elif isinstance(node.return_output, tuple):
for flag, e in zip(node.return_output, out):
if flag:
pipeline_outputs_tuple += (e,)
else:
if isinstance(node.return_output, bool) and node.return_output:
pipeline_outputs_tuple += (out,)
elif isinstance(node.return_output, tuple):
# this should not happen: maybe add a check somewhere earlier?
pass

if isinstance(nodes[0], PeakDetector):
# the first out element is the peak vector
# we need to go back to absolute sample index
pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin
if isinstance(nodes[0], PeakDetector):
# the first out element is the peak vector
# we need to go back to absolute sample index
pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin

return pipeline_outputs_tuple

else:
# the gather will skip this output and not concatenate it
return

return pipeline_outputs_tuple
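
The early return above, which yields None for a skipped chunk, relies on the gather side tolerating missing results. A simplified stand-in, not the actual GatherToMemory that follows, illustrates the contract (assuming each chunk returns a single array rather than a tuple):

import numpy as np

class ToyGather:
    # collects per-chunk results, silently ignoring skipped chunks
    def __init__(self):
        self.chunks = []

    def __call__(self, res):
        if res is None:
            return  # chunk was skipped: no peaks inside, or budget reached
        self.chunks.append(res)

    def finalize(self):
        return np.concatenate(self.chunks)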


class GatherToMemory:
Expand All @@ -595,6 +687,9 @@ def __init__(self):
self.tuple_mode = None

def __call__(self, res):
if res is None:
return

if self.tuple_mode is None:
# first loop only
self.tuple_mode = isinstance(res, tuple)
@@ -655,6 +750,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False):
self.final_shapes.append(None)

def __call__(self, res):
if res is None:
return

if self.tuple_mode is None:
# first loop only
self.tuple_mode = isinstance(res, tuple)