Skip to content

Commit

Permalink
add causal_filter to filter.py
Browse files Browse the repository at this point in the history
1. solved a typo in filter class.
2. added a new causal_filter class and function.
  • Loading branch information
JuanPimientoCaicedo authored Jun 30, 2024
1 parent 4539550 commit 8f3ecef
Showing 1 changed file with 157 additions and 1 deletion.
158 changes: 157 additions & 1 deletion src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FilterRecording(BasePreprocessor):
Generic filter class based on:
* scipy.signal.iirfilter
* scipy.signal.filtfilt or scipy.signal.sosfilt
* scipy.signal.filtfilt or scipy.signal.sosfiltfilt
BandpassFilterRecording is built on top of it.
Expand Down Expand Up @@ -290,12 +290,168 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)

class Causalfilter(BasePreprocessor):
"""
filter class based on:
* scipy.signal.iirfilter
* scipy.signal.lfilt or scipy.signal.sosfilt
Produces forward or backward filtering with causal filters
Parameters
----------
recording : Recording
The recording extractor to be re-referenced
band : float or list, default: [300.0, 6000.0]
If float, cutoff frequency in Hz for "highpass" filter type
If list. band (low, high) in Hz for "bandpass" filter type
btype : "bandpass" | "highpass", default: "bandpass"
Type of the filter
margin_ms : float, default: 5.0
Margin in ms on border to avoid border effect
coeff : array | None, default: None
Filter coefficients in the filter_mode form.
dtype : dtype or None, default: None
The dtype of the returned traces. If None, the dtype of the parent recording is used
add_reflect_padding : Bool, default False
If True, uses a left and right margin during calculation.
filter_order : order
The order of the filter for `scipy.signal.iirfilter`
filter_mode : "sos" | "ba", default: "sos"
Filter form of the filter coefficients for `scipy.signal.iirfilter`:
- second-order sections ("sos")
- numerator/denominator : ("ba")
ftype : str, default: "butter"
Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".
direction : "forward" | "backward", default: "forward"
Returns
-------
filter_recording : FilterRecording
The filtered recording extractor object
"""

name = "causal_filter"

def __init__(
self,
recording,
band=[300.0, 6000.0],
btype="bandpass",
filter_order=1,
ftype="butter",
filter_mode="sos",
direction="forward",
margin_ms=5.0,
add_reflect_padding=False,
coeff=None,
dtype=None,

):
import scipy.signal

assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'"
fs = recording.get_sampling_frequency()
if coeff is None:
assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'"
# coefficient
# self.coeff is 'sos' or 'ab' style
filter_coeff = scipy.signal.iirfilter(
filter_order, band, fs=fs, analog=False, btype=btype, ftype=ftype, output=filter_mode
)
else:
filter_coeff = coeff
if not isinstance(coeff, list):
if filter_mode == "ba":
coeff = [c.tolist() for c in coeff]
else:
coeff = coeff.tolist()
dtype = fix_dtype(recording, dtype)

BasePreprocessor.__init__(self, recording, dtype=dtype)
self.annotate(is_filtered=True)

if "offset_to_uV" in self.get_property_keys():
self.set_channel_offsets(0)

margin = int(margin_ms * fs / 1000.0)
for parent_segment in recording._recording_segments:
self.add_recording_segment(
CausalFilterRecordingSegment(
parent_segment, filter_coeff, filter_mode, margin, dtype,
direction, add_reflect_padding=add_reflect_padding
)
)

self._kwargs = dict(
recording=recording,
band=band,
btype=btype,
filter_order=filter_order,
ftype=ftype,
filter_mode=filter_mode,
coeff=coeff,
margin_ms=margin_ms,
direction=direction,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
)


class CausalFilterRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, direction, add_reflect_padding=False):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
self.margin = margin
self.add_reflect_padding = add_reflect_padding
self.dtype = dtype
self.direction = direction

def get_traces(self, start_frame, end_frame, channel_indices):
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
self.parent_recording_segment,
start_frame,
end_frame,
channel_indices,
self.margin,
add_reflect_padding=self.add_reflect_padding,
)

traces_dtype = traces_chunk.dtype
# if uint --> force int
if traces_dtype.kind == "u":
traces_chunk = traces_chunk.astype("float32")

import scipy.signal

if self.filter_mode == "sos":
if self.direction == "forward":
filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0)
elif self.direction == "backward":
filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk, axis = 0), axis = 0), axis = 0)
elif self.filter_mode == "ba":
b, a = self.coeff
if self.direction == "forward":
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)
elif self.direction == "backward":
filtered_traces = np.flip(scipy.signal.lfilt(b, a, np.flip(traces_chunk, axis = 0), axis=0), axis = 0)
if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
else:
filtered_traces = filtered_traces[left_margin:, :]

if np.issubdtype(self.dtype, np.integer):
filtered_traces = filtered_traces.round()

return filtered_traces.astype(self.dtype)

# functions for API
filter = define_function_from_class(source_class=FilterRecording, name="filter")
bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")
notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter")
highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter")
causal_filter = define_function_from_class(source_class=Causalfilter, name="causal_filter")

bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs)
highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)
Expand Down

0 comments on commit 8f3ecef

Please sign in to comment.