Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative implementation of causal filter in filter.py #3133

Closed
Closed
Changes from 9 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 112 additions & 8 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class FilterRecording(BasePreprocessor):
Generic filter class based on:

* scipy.signal.iirfilter
* scipy.signal.filtfilt or scipy.signal.sosfilt
* scipy.signal.filtfilt or scipy.signal.sosfiltfilt
* scipy.signal.lfilt or scipy.signal.sosfilt when causal_mode = True

BandpassFilterRecording is built on top of it.

Expand Down Expand Up @@ -56,6 +57,10 @@ class FilterRecording(BasePreprocessor):
- numerator/denominator : ("ba")
ftype : str, default: "butter"
Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".
causal_mode : bool, default: False
If true, filtering is applied in just one direction.
direction : "forward" | "backward", default: "forward"
when causal_mode = True, defines the direction of the filtering
JuanPimientoCaicedo marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand All @@ -77,6 +82,8 @@ def __init__(
add_reflect_padding=False,
coeff=None,
dtype=None,
causal_mode=False,
direction="forward",
):
import scipy.signal

Expand Down Expand Up @@ -108,7 +115,14 @@ def __init__(
for parent_segment in recording._recording_segments:
self.add_recording_segment(
FilterRecordingSegment(
parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding
parent_segment,
filter_coeff,
filter_mode,
causal_mode,
direction,
margin,
dtype,
add_reflect_padding=add_reflect_padding,
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
)
)

Expand All @@ -123,14 +137,28 @@ def __init__(
margin_ms=margin_ms,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
causal_mode=causal_mode,
direction=direction,
)


class FilterRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False):
def __init__(
self,
parent_recording_segment,
coeff,
filter_mode,
causal_mode,
direction,
margin,
dtype,
add_reflect_padding=False,
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
self.causal_mode = causal_mode
self.direction = direction
self.margin = margin
self.add_reflect_padding = add_reflect_padding
self.dtype = dtype
Expand All @@ -152,11 +180,25 @@ def get_traces(self, start_frame, end_frame, channel_indices):

import scipy.signal

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
if self.causal_mode:
if self.direction == "backward":
traces_chunk = np.flip(traces_chunk, axis=0)

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)

if self.direction == "backward":
filtered_traces = np.flip(filtered_traces, axis=0)

else:
if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)

if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
Expand Down Expand Up @@ -291,11 +333,73 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)


class CausalFilter(FilterRecording):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why having this class if the FilterRecording already handle it with specific options ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it is for the same reason highpass_filter and bandpass_filter exist. Although it is true that in causal filtering, the users might need to provide some extra parameters compared with these other subclasses...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @samuelgarcia we don't need the class in this case, since the CasualFilter class is really just instantiating the FilterRecording class (the fix_dtype and kwargs will be handled by the FilterRecording class).

My suggestion is to:

  • remove this class
  • explicittly define the causal_filter function (see my comment)

"""
Performs causal filtering using:
* scipy.signal.lfilt or scipy.signal.sosfilt

Parameters
----------
recording : Recording
The recording extractor to be re-referenced
band : float or list, default: [300.0, 6000.0]
If float, cutoff frequency in Hz for "highpass" filter type
If list. band (low, high) in Hz for "bandpass" filter type
margin_ms : float
Margin in ms on border to avoid border effect
dtype : dtype or None
The dtype of the returned traces. If None, the dtype of the parent recording is used
direction : "forward" | "backward", default: "forward"
when causal_mode = True, defines the direction of the filtering

Returns
-------
filter_recording : CausalFilterRecording
The causal-filtered recording extractor object

{}

"""

name = "causal_filter"

def __init__(
self,
recording,
band=[300.0, 6000.0],
margin_ms=5.0,
dtype=None,
direction="forward",
**filter_kwargs,
):
FilterRecording.__init__(
self,
recording,
band=band,
margin_ms=margin_ms,
dtype=dtype,
causal_mode=True,
direction=direction,
**filter_kwargs,
)
dtype = fix_dtype(recording, dtype)
self._kwargs = dict(
recording=recording,
band=band,
margin_ms=margin_ms,
dtype=dtype.str,
causal_mode=True,
direction=direction,
)
self._kwargs.update(filter_kwargs)


# functions for API
filter = define_function_from_class(source_class=FilterRecording, name="filter")
bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")
notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter")
highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter")
causal_filter = define_function_from_class(source_class=CausalFilter, name="causal_filter")
JuanPimientoCaicedo marked this conversation as resolved.
Show resolved Hide resolved

bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs)
highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading