diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 93462ac5d8..b654b965ff 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -27,7 +27,7 @@ class FilterRecording(BasePreprocessor): Generic filter class based on: * scipy.signal.iirfilter - * scipy.signal.filtfilt or scipy.signal.sosfilt + * scipy.signal.filtfilt or scipy.signal.sosfiltfilt BandpassFilterRecording is built on top of it. @@ -290,12 +290,168 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str) +class Causalfilter(BasePreprocessor): + """ + filter class based on: + + * scipy.signal.iirfilter + * scipy.signal.lfilt or scipy.signal.sosfilt + + Produces forward or backward filtering with causal filters + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + band : float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + btype : "bandpass" | "highpass", default: "bandpass" + Type of the filter + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + coeff : array | None, default: None + Filter coefficients in the filter_mode form. + dtype : dtype or None, default: None + The dtype of the returned traces. If None, the dtype of the parent recording is used + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. + filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + direction : "forward" | "backward", default: "forward" + Returns + ------- + filter_recording : FilterRecording + The filtered recording extractor object + """ + + name = "causal_filter" + + def __init__( + self, + recording, + band=[300.0, 6000.0], + btype="bandpass", + filter_order=1, + ftype="butter", + filter_mode="sos", + direction="forward", + margin_ms=5.0, + add_reflect_padding=False, + coeff=None, + dtype=None, + + ): + import scipy.signal + + assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" + fs = recording.get_sampling_frequency() + if coeff is None: + assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" + # coefficient + # self.coeff is 'sos' or 'ab' style + filter_coeff = scipy.signal.iirfilter( + filter_order, band, fs=fs, analog=False, btype=btype, ftype=ftype, output=filter_mode + ) + else: + filter_coeff = coeff + if not isinstance(coeff, list): + if filter_mode == "ba": + coeff = [c.tolist() for c in coeff] + else: + coeff = coeff.tolist() + dtype = fix_dtype(recording, dtype) + + BasePreprocessor.__init__(self, recording, dtype=dtype) + self.annotate(is_filtered=True) + + if "offset_to_uV" in self.get_property_keys(): + self.set_channel_offsets(0) + + margin = int(margin_ms * fs / 1000.0) + for parent_segment in recording._recording_segments: + self.add_recording_segment( + CausalFilterRecordingSegment( + parent_segment, filter_coeff, filter_mode, margin, dtype, + direction, add_reflect_padding=add_reflect_padding + ) + ) + + self._kwargs = dict( + recording=recording, + band=band, + btype=btype, + filter_order=filter_order, + ftype=ftype, + filter_mode=filter_mode, + coeff=coeff, + margin_ms=margin_ms, + direction=direction, + add_reflect_padding=add_reflect_padding, + dtype=dtype.str, + ) + + +class CausalFilterRecordingSegment(BasePreprocessorSegment): + def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, direction, add_reflect_padding=False): + BasePreprocessorSegment.__init__(self, parent_recording_segment) + self.coeff = coeff + self.filter_mode = filter_mode + self.margin = margin + self.add_reflect_padding = add_reflect_padding + self.dtype = dtype + self.direction = direction + + def get_traces(self, start_frame, end_frame, channel_indices): + traces_chunk, left_margin, right_margin = get_chunk_with_margin( + self.parent_recording_segment, + start_frame, + end_frame, + channel_indices, + self.margin, + add_reflect_padding=self.add_reflect_padding, + ) + + traces_dtype = traces_chunk.dtype + # if uint --> force int + if traces_dtype.kind == "u": + traces_chunk = traces_chunk.astype("float32") + + import scipy.signal + + if self.filter_mode == "sos": + if self.direction == "forward": + filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + elif self.direction == "backward": + filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk, axis = 0), axis = 0), axis = 0) + elif self.filter_mode == "ba": + b, a = self.coeff + if self.direction == "forward": + filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0) + elif self.direction == "backward": + filtered_traces = np.flip(scipy.signal.lfilt(b, a, np.flip(traces_chunk, axis = 0), axis=0), axis = 0) + if right_margin > 0: + filtered_traces = filtered_traces[left_margin:-right_margin, :] + else: + filtered_traces = filtered_traces[left_margin:, :] + + if np.issubdtype(self.dtype, np.integer): + filtered_traces = filtered_traces.round() + + return filtered_traces.astype(self.dtype) # functions for API filter = define_function_from_class(source_class=FilterRecording, name="filter") bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter") notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") +causal_filter = define_function_from_class(source_class=Causalfilter, name="causal_filter") bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)