diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 93462ac5d8..507c099b02 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -27,7 +27,8 @@ class FilterRecording(BasePreprocessor): Generic filter class based on: * scipy.signal.iirfilter - * scipy.signal.filtfilt or scipy.signal.sosfilt + * scipy.signal.filtfilt or scipy.signal.sosfiltfilt when direction = "forward-backward" + * scipy.signal.lfilt or scipy.signal.sosfilt BandpassFilterRecording is built on top of it. @@ -56,6 +57,10 @@ class FilterRecording(BasePreprocessor): - numerator/denominator : ("ba") ftype : str, default: "butter" Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + direction : "forward" | "backward" | "forward-backward", default: "forward-backward" + Direction of filtering: + - forward and backward filter in just one direction, creating phase shifts in the signal. + - forward-backward filters in both directions, a zero-phase filtering. Returns ------- @@ -77,6 +82,7 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, + direction="forward-backward", ): import scipy.signal @@ -108,7 +114,13 @@ def __init__( for parent_segment in recording._recording_segments: self.add_recording_segment( FilterRecordingSegment( - parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding + parent_segment, + filter_coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=add_reflect_padding, + direction=direction, ) ) @@ -123,14 +135,25 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, + direction=direction, ) class FilterRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False): + def __init__( + self, + parent_recording_segment, + coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=False, + direction="forward-backward", + ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff self.filter_mode = filter_mode + self.direction = direction self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype @@ -152,11 +175,24 @@ def get_traces(self, start_frame, end_frame, channel_indices): import scipy.signal - if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + if self.direction == "forward-backward": + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + else: + if self.direction == "backward": + traces_chunk = np.flip(traces_chunk, axis=0) + + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0) + + if self.direction == "backward": + filtered_traces = np.flip(filtered_traces, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -297,8 +333,74 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") +def causal_filter( + recording, + direction="forward-backward", + band=[300.0, 6000.0], + btype="bandpass", + filter_order=5, + ftype="butter", + filter_mode="sos", + margin_ms=5.0, + add_reflect_padding=False, + coeff=None, + dtype=None, +): + """ + Generic causal filter built on top of the filter function. + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + direction : "forward" | "backward", default: "forward" + Direction of causal filter. The "backward" option flips the traces in time before applying the filter + and then flips them back. + band : float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + btype : "bandpass" | "highpass", default: "bandpass" + Type of the filter + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + coeff : array | None, default: None + Filter coefficients in the filter_mode form. + dtype : dtype or None, default: None + The dtype of the returned traces. If None, the dtype of the parent recording is used + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. + filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + + Returns + ------- + filter_recording : FilterRecording + The causal-filtered recording extractor object + """ + assert direction in ["forward", "backward"], "Direction can be either 'forward' or 'backward'" + return filter( + recording=recording, + direction=direction, + band=band + btype=btype, + filter_order=filter_order, + ftype=ftype, + filter_mode=filter_mode, + margin_ms=margin_ms, + add_reflect_padding=add_reflect_padding, + coeff=coeff, + dtype=dtype, + ) + bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) +causal_filter.__doc__ = causal_filter.__doc__.format(_common_filter_docs) def fix_dtype(recording, dtype):