-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #132 from EIT-ALIVE/129a_slicing_mixin_dbodor
refactor: slicing mixin class
- Loading branch information
Showing
6 changed files
with
534 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
from __future__ import annotations | ||
import bisect | ||
import warnings | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
import numpy as np | ||
from numpy.typing import NDArray | ||
from typing_extensions import Self | ||
|
||
|
||
class SelectByIndex(ABC): | ||
"""Adds slicing functionality to subclass by implementing `__getitem__`. | ||
Subclasses must implement a `_sliced_copy` function that defines what should happen | ||
when the object is sliced. This class ensures that when calling a slice between | ||
square brackets (as e.g. done for lists) then return the expected sliced object. | ||
""" | ||
|
||
def __getitem__(self, key: slice | int): | ||
if isinstance(key, slice): | ||
if key.step and key.step != 1: | ||
raise ValueError( | ||
f"Can't slice {self.__class__} object with steps other than 1." | ||
) | ||
start_index = key.start | ||
end_index = key.stop | ||
return self.select_by_index(start_index, end_index) | ||
|
||
if isinstance(key, int): | ||
return self.select_by_index(start=key, end=key + 1) | ||
|
||
raise TypeError( | ||
f"Invalid slicing input. Should be `slice` or `int`, not {type(key)}." | ||
) | ||
|
||
def select_by_index( | ||
self, | ||
start: int | None = None, | ||
end: int | None = None, | ||
label: str | None = None, | ||
) -> Self: | ||
"""De facto implementation of the `__getitem__ function.""" | ||
|
||
if start is None and end is None: | ||
warnings.warn("No starting or end timepoint was selected.") | ||
return self | ||
|
||
start = start or 0 | ||
if end is None: | ||
end = len(self) | ||
|
||
if label is None: | ||
if start > end: | ||
label = f"No frames selected from <{self.label}>" | ||
elif start < end - 1: | ||
label = f"Frames ({start}-{end-1}) of <{self.label}>" | ||
else: | ||
label = f"Frame ({start}) of <{self.label}>" | ||
|
||
return self._sliced_copy(start_index=start, end_index=end, label=label) | ||
|
||
@abstractmethod | ||
def _sliced_copy( | ||
self, | ||
start_index: int, | ||
end_index: int, | ||
label: str, | ||
) -> Self: | ||
"""Slicing method that must be implemented by all subclasses. | ||
Must return a copy of self object with all attached data within selected | ||
indices. | ||
""" | ||
... | ||
|
||
|
||
class SelectByTime(SelectByIndex): | ||
time: NDArray | ||
|
||
def select_by_time( # pylint: disable=too-many-arguments | ||
self, | ||
start_time: float | int | None = None, | ||
end_time: float | int | None = None, | ||
start_inclusive: bool = True, | ||
end_inclusive: bool = False, | ||
label: str | None = None, | ||
) -> Self: | ||
"""Get a slice from start to end time stamps. | ||
Given a start and end time stamp (i.e. its value, not its index), | ||
return a slice of the original object, which must contain a time axis. | ||
Args: | ||
start_time: first time point to include. Defaults to first frame of sequence. | ||
end_time: last time point. Defaults to last frame of sequence. | ||
start_inclusive (default: `True`), end_inclusive (default `False`): | ||
these arguments control the behavior if the given time stamp does not | ||
match exactly with an existing time stamp of the input. | ||
if `True`: the given time stamp will be inside the sliced object. | ||
if `False`: the given time stamp will be outside the sliced object. | ||
label: Description. Defaults to None, which will create a label based on the original | ||
object label and the frames by which it is sliced. | ||
Raises: | ||
TypeError: if `self` does not contain a `time` attribute. | ||
ValueError: if time stamps are not sorted. | ||
Returns: | ||
Self: _description_ | ||
""" | ||
|
||
if not "time" in vars(self): | ||
raise TypeError(f"Object {self} has no time axis.") | ||
|
||
if start_time is None and end_time is None: | ||
warnings.warn("No starting or end timepoint was selected.") | ||
return self | ||
|
||
if not np.all(np.sort(self.time) == self.time): | ||
raise ValueError( | ||
f"Time stamps for {self} are not sorted and therefore data" | ||
"cannot be selected by time." | ||
) | ||
|
||
if start_time is None or start_time < self.time[0]: | ||
start_index = 0 | ||
elif start_inclusive: | ||
start_index = bisect.bisect_right(self.time, start_time) - 1 | ||
else: | ||
start_index = bisect.bisect_left(self.time, start_time) | ||
|
||
if end_time is None: | ||
end_index = len(self.time) | ||
elif end_inclusive: | ||
end_index = bisect.bisect_left(self.time, end_time) + 1 | ||
else: | ||
end_index = bisect.bisect_left(self.time, end_time) | ||
|
||
return self.select_by_index( | ||
start=start_index, | ||
end=end_index, | ||
label=label, | ||
) | ||
|
||
@property | ||
def t(self) -> TimeIndexer: | ||
return TimeIndexer(self) | ||
|
||
|
||
@dataclass | ||
class TimeIndexer: | ||
"""Helper class allowing for slicing an object using the time axis rather than indices | ||
Example: | ||
``` | ||
>>> data = DraegerEITData.from_path(<path>) | ||
>>> tp_start = data.time[1] | ||
>>> tp_end = data.time[4] | ||
>>> time_slice = data.t[tp_start:tp_end] | ||
>>> index_slice = data[1:4] | ||
>>> time_slice == index_slice | ||
True | ||
``` | ||
""" | ||
|
||
obj: SelectByTime | ||
|
||
def __getitem__(self, key: slice | int | float): | ||
if isinstance(key, slice): | ||
if key.step: | ||
raise ValueError("Can't slice by time using specific step sizes.") | ||
return self.obj.select_by_time(key.start, key.stop) | ||
|
||
if isinstance(key, (int, float)): | ||
return self.obj.select_by_time(start=key, end=key, end_inclusive=True) | ||
|
||
raise TypeError( | ||
f"Invalid slicing input. Should be `slice` or `int` or `float`, not {type(key)}." | ||
) |
Oops, something went wrong.