Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 30, 2024
1 parent 8f3ecef commit b25e683
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)


class Causalfilter(BasePreprocessor):
"""
filter class based on:
Expand Down Expand Up @@ -346,7 +347,6 @@ def __init__(
add_reflect_padding=False,
coeff=None,
dtype=None,

):
import scipy.signal

Expand Down Expand Up @@ -378,9 +378,14 @@ def __init__(
for parent_segment in recording._recording_segments:
self.add_recording_segment(
CausalFilterRecordingSegment(
parent_segment, filter_coeff, filter_mode, margin, dtype,
direction, add_reflect_padding=add_reflect_padding
)
parent_segment,
filter_coeff,
filter_mode,
margin,
dtype,
direction,
add_reflect_padding=add_reflect_padding,
)
)

self._kwargs = dict(
Expand All @@ -399,7 +404,9 @@ def __init__(


class CausalFilterRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, direction, add_reflect_padding=False):
def __init__(
self, parent_recording_segment, coeff, filter_mode, margin, dtype, direction, add_reflect_padding=False
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
Expand Down Expand Up @@ -429,13 +436,15 @@ def get_traces(self, start_frame, end_frame, channel_indices):
if self.direction == "forward":
filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0)
elif self.direction == "backward":
filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk, axis = 0), axis = 0), axis = 0)
filtered_traces = np.flip(
scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk, axis=0), axis=0), axis=0
)
elif self.filter_mode == "ba":
b, a = self.coeff
if self.direction == "forward":
filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0)
elif self.direction == "backward":
filtered_traces = np.flip(scipy.signal.lfilt(b, a, np.flip(traces_chunk, axis = 0), axis=0), axis = 0)
filtered_traces = np.flip(scipy.signal.lfilt(b, a, np.flip(traces_chunk, axis=0), axis=0), axis=0)
if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
else:
Expand All @@ -446,6 +455,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):

return filtered_traces.astype(self.dtype)


# functions for API
filter = define_function_from_class(source_class=FilterRecording, name="filter")
bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")
Expand Down

0 comments on commit b25e683

Please sign in to comment.