Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Preprocessing] add causal backward filtering to correct hardware induced phase shift #2942

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 167 additions & 1 deletion src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FilterRecording(BasePreprocessor):
Generic filter class based on:

* scipy.signal.iirfilter
* scipy.signal.filtfilt or scipy.signal.sosfilt
* scipy.signal.filtfilt or scipy.signal.sosfiltfilt

BandpassFilterRecording is built on top of it.

Expand Down Expand Up @@ -291,11 +291,177 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)


class Causalfilter(BasePreprocessor):
"""
filter class based on:

* scipy.signal.iirfilter
* scipy.signal.lfilt or scipy.signal.sosfilt

Produces forward or backward filtering with causal filters

Parameters
----------
recording : Recording
The recording extractor to be re-referenced
band : float or list, default: [300.0, 6000.0]
If float, cutoff frequency in Hz for "highpass" filter type
If list. band (low, high) in Hz for "bandpass" filter type
btype : "bandpass" | "highpass", default: "bandpass"
Type of the filter
margin_ms : float, default: 5.0
Margin in ms on border to avoid border effect
coeff : array | None, default: None
Filter coefficients in the filter_mode form.
dtype : dtype or None, default: None
The dtype of the returned traces. If None, the dtype of the parent recording is used
add_reflect_padding : Bool, default False
If True, uses a left and right margin during calculation.
filter_order : order
The order of the filter for `scipy.signal.iirfilter`
filter_mode : "sos" | "ba", default: "sos"
Filter form of the filter coefficients for `scipy.signal.iirfilter`:
- second-order sections ("sos")
- numerator/denominator : ("ba")
ftype : str, default: "butter"
Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".
direction : "forward" | "backward", default: "forward"
Returns
-------
filter_recording : FilterRecording
The filtered recording extractor object
"""

name = "causal_filter"

def __init__(
self,
recording,
band=[300.0, 6000.0],
btype="bandpass",
filter_order=1,
ftype="butter",
filter_mode="sos",
direction="forward",
margin_ms=5.0,
add_reflect_padding=False,
coeff=None,
dtype=None,
):
import scipy.signal

assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'"
fs = recording.get_sampling_frequency()
if coeff is None:
assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'"
# coefficient
# self.coeff is 'sos' or 'ab' style
filter_coeff = scipy.signal.iirfilter(
filter_order, band, fs=fs, analog=False, btype=btype, ftype=ftype, output=filter_mode
)
else:
filter_coeff = coeff
if not isinstance(coeff, list):
if filter_mode == "ba":
coeff = [c.tolist() for c in coeff]
else:
coeff = coeff.tolist()
dtype = fix_dtype(recording, dtype)

BasePreprocessor.__init__(self, recording, dtype=dtype)
self.annotate(is_filtered=True)

if "offset_to_uV" in self.get_property_keys():
self.set_channel_offsets(0)

margin = int(margin_ms * fs / 1000.0)
for parent_segment in recording._recording_segments:
self.add_recording_segment(
CausalFilterRecordingSegment(
parent_segment,
filter_coeff,
filter_mode,
margin,
dtype,
direction,
add_reflect_padding=add_reflect_padding,
)
)

self._kwargs = dict(
recording=recording,
band=band,
btype=btype,
filter_order=filter_order,
ftype=ftype,
filter_mode=filter_mode,
coeff=coeff,
margin_ms=margin_ms,
direction=direction,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
)


class CausalFilterRecordingSegment(BasePreprocessorSegment):
def __init__(
self, parent_recording_segment, coeff, filter_mode, margin, dtype, direction, add_reflect_padding=False
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
self.margin = margin
self.add_reflect_padding = add_reflect_padding
self.dtype = dtype
self.direction = direction

def get_traces(self, start_frame, end_frame, channel_indices):
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
self.parent_recording_segment,
start_frame,
end_frame,
channel_indices,
self.margin,
add_reflect_padding=self.add_reflect_padding,
)

traces_dtype = traces_chunk.dtype
# if uint --> force int
if traces_dtype.kind == "u":
traces_chunk = traces_chunk.astype("float32")

import scipy.signal

if self.filter_mode == "sos":
if self.direction == "forward":
filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0)
elif self.direction == "backward":
filtered_traces = np.flip(
scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk, axis=0), axis=0), axis=0
)
elif self.filter_mode == "ba":
b, a = self.coeff
if self.direction == "forward":
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)
elif self.direction == "backward":
filtered_traces = np.flip(scipy.signal.lfilt(b, a, np.flip(traces_chunk, axis=0), axis=0), axis=0)
if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
else:
filtered_traces = filtered_traces[left_margin:, :]

if np.issubdtype(self.dtype, np.integer):
filtered_traces = filtered_traces.round()

return filtered_traces.astype(self.dtype)


# functions for API
filter = define_function_from_class(source_class=FilterRecording, name="filter")
bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")
notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter")
highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter")
causal_filter = define_function_from_class(source_class=Causalfilter, name="causal_filter")

bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs)
highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)
Expand Down
Loading