Skip to content

Commit

Permalink
refactoring/bugfixing
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Sep 28, 2023
1 parent b091cbc commit 12aec80
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
TypeVar,
Union,
cast,
Sequence,
)

import numpy as np
Expand Down Expand Up @@ -109,7 +108,7 @@ def __init__(
workers: int = 1,
index_column: str = "event_no",
icetray_verbose: int = 0,
I3_Filters: Union[I3Filter, List[Callable]] = [NullSplitI3Filter],
I3_Filters: List[I3Filter] = [],
):
"""Construct DataConverter.
Expand Down Expand Up @@ -169,9 +168,10 @@ def __init__(
self._sequential_batch_pattern = sequential_batch_pattern
self._input_file_batch_pattern = input_file_batch_pattern
self._workers = workers
if isinstance(I3_Filters, I3Filter):
I3_Filters = [I3_Filters]
self._I3Filters = I3_Filters

# I3Filters (NullSplitI3Filter is always included)
self._I3Filters = [NullSplitI3Filter()] + I3_Filters

for filter in self._I3Filters:
assert isinstance(
filter, I3Filter
Expand Down
8 changes: 4 additions & 4 deletions src/graphnet/data/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ class I3Filter(Logger):
"""A generic filter for I3-frames."""

@abstractmethod
def _pass_frame(self, frame: "icetray.I3Frame") -> bool:
"""Return True if the frame passes the filter, False otherwise.
def _keep_frame(self, frame: "icetray.I3Frame") -> bool:
"""Return True if the frame is kept, False otherwise.
Args:
frame: I3-frame
The I3-frame to check.
Returns:
bool: True if the frame passes the filter, False otherwise.
bool: True if the frame is kept, False otherwise.
"""
raise NotImplementedError

Expand All @@ -35,7 +35,7 @@ def __call__(self, frame: "icetray.I3Frame") -> bool:
Returns:
bool: True if the frame passes the filter, False otherwise.
"""
pass_flag = self._pass_frame(frame)
pass_flag = self._keep_frame(frame)
try:
assert isinstance(pass_flag, bool)
except AssertionError:
Expand Down

0 comments on commit 12aec80

Please sign in to comment.