Skip to content

Commit

Permalink
Merge pull request #132 from EIT-ALIVE/129a_slicing_mixin_dbodor
Browse files Browse the repository at this point in the history
refactor: slicing mixin class
  • Loading branch information
DaniBodor authored Dec 8, 2023
2 parents c984d58 + 6dc006b commit 5ae7437
Show file tree
Hide file tree
Showing 6 changed files with 534 additions and 136 deletions.
11 changes: 7 additions & 4 deletions eitprocessing/eit_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
ContinuousDataCollection,
)
from eitprocessing.eit_data.eit_data_variant import EITDataVariant
from eitprocessing.mixins import SelectByTime
from eitprocessing.eit_data.vendor import Vendor
from eitprocessing.helper import NotEquivalent
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection
from eitprocessing.variants.variant_collection import VariantCollection
from ..helper import NotEquivalent
from .vendor import Vendor


PathLike: TypeAlias = str | Path
Expand Down Expand Up @@ -227,7 +227,10 @@ def check_equivalence(cls, a: T, b: T, raise_=False) -> bool:
return False

def _sliced_copy(
self: Self, start_index: int, end_index: int, label: str | None = None
self,
start_index: int,
end_index: int,
label: str,
) -> Self:
cls = self._get_vendor_class(self.vendor)
time = self.time[start_index:end_index]
Expand Down
5 changes: 2 additions & 3 deletions eitprocessing/eit_data/eit_data_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from eitprocessing.mixins import SelectByTime
from ..variants import Variant
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.variants import Variant


@dataclass
Expand Down Expand Up @@ -72,7 +72,6 @@ def concatenate(cls, a: Self, b: Self) -> Self:
def _sliced_copy(
self, start_index: int, end_index: int, label: str | None = None
) -> Self:
label = label or f"Slice ({start_index}-{end_index}) of <{self.label}>"
pixel_impedance = self.pixel_impedance[start_index:end_index, :, :]

return self.__class__(
Expand Down
129 changes: 0 additions & 129 deletions eitprocessing/mixins.py

This file was deleted.

180 changes: 180 additions & 0 deletions eitprocessing/mixins/slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations
import bisect
import warnings
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self


class SelectByIndex(ABC):
"""Adds slicing functionality to subclass by implementing `__getitem__`.
Subclasses must implement a `_sliced_copy` function that defines what should happen
when the object is sliced. This class ensures that when calling a slice between
square brackets (as e.g. done for lists) then return the expected sliced object.
"""

def __getitem__(self, key: slice | int):
if isinstance(key, slice):
if key.step and key.step != 1:
raise ValueError(
f"Can't slice {self.__class__} object with steps other than 1."
)
start_index = key.start
end_index = key.stop
return self.select_by_index(start_index, end_index)

if isinstance(key, int):
return self.select_by_index(start=key, end=key + 1)

raise TypeError(
f"Invalid slicing input. Should be `slice` or `int`, not {type(key)}."
)

def select_by_index(
self,
start: int | None = None,
end: int | None = None,
label: str | None = None,
) -> Self:
"""De facto implementation of the `__getitem__ function."""

if start is None and end is None:
warnings.warn("No starting or end timepoint was selected.")
return self

start = start or 0
if end is None:
end = len(self)

if label is None:
if start > end:
label = f"No frames selected from <{self.label}>"
elif start < end - 1:
label = f"Frames ({start}-{end-1}) of <{self.label}>"
else:
label = f"Frame ({start}) of <{self.label}>"

return self._sliced_copy(start_index=start, end_index=end, label=label)

@abstractmethod
def _sliced_copy(
self,
start_index: int,
end_index: int,
label: str,
) -> Self:
"""Slicing method that must be implemented by all subclasses.
Must return a copy of self object with all attached data within selected
indices.
"""
...


class SelectByTime(SelectByIndex):
time: NDArray

def select_by_time( # pylint: disable=too-many-arguments
self,
start_time: float | int | None = None,
end_time: float | int | None = None,
start_inclusive: bool = True,
end_inclusive: bool = False,
label: str | None = None,
) -> Self:
"""Get a slice from start to end time stamps.
Given a start and end time stamp (i.e. its value, not its index),
return a slice of the original object, which must contain a time axis.
Args:
start_time: first time point to include. Defaults to first frame of sequence.
end_time: last time point. Defaults to last frame of sequence.
start_inclusive (default: `True`), end_inclusive (default `False`):
these arguments control the behavior if the given time stamp does not
match exactly with an existing time stamp of the input.
if `True`: the given time stamp will be inside the sliced object.
if `False`: the given time stamp will be outside the sliced object.
label: Description. Defaults to None, which will create a label based on the original
object label and the frames by which it is sliced.
Raises:
TypeError: if `self` does not contain a `time` attribute.
ValueError: if time stamps are not sorted.
Returns:
Self: _description_
"""

if not "time" in vars(self):
raise TypeError(f"Object {self} has no time axis.")

if start_time is None and end_time is None:
warnings.warn("No starting or end timepoint was selected.")
return self

if not np.all(np.sort(self.time) == self.time):
raise ValueError(
f"Time stamps for {self} are not sorted and therefore data"
"cannot be selected by time."
)

if start_time is None or start_time < self.time[0]:
start_index = 0
elif start_inclusive:
start_index = bisect.bisect_right(self.time, start_time) - 1
else:
start_index = bisect.bisect_left(self.time, start_time)

if end_time is None:
end_index = len(self.time)
elif end_inclusive:
end_index = bisect.bisect_left(self.time, end_time) + 1
else:
end_index = bisect.bisect_left(self.time, end_time)

return self.select_by_index(
start=start_index,
end=end_index,
label=label,
)

@property
def t(self) -> TimeIndexer:
return TimeIndexer(self)


@dataclass
class TimeIndexer:
"""Helper class allowing for slicing an object using the time axis rather than indices
Example:
```
>>> data = DraegerEITData.from_path(<path>)
>>> tp_start = data.time[1]
>>> tp_end = data.time[4]
>>> time_slice = data.t[tp_start:tp_end]
>>> index_slice = data[1:4]
>>> time_slice == index_slice
True
```
"""

obj: SelectByTime

def __getitem__(self, key: slice | int | float):
if isinstance(key, slice):
if key.step:
raise ValueError("Can't slice by time using specific step sizes.")
return self.obj.select_by_time(key.start, key.stop)

if isinstance(key, (int, float)):
return self.obj.select_by_time(start=key, end=key, end_inclusive=True)

raise TypeError(
f"Invalid slicing input. Should be `slice` or `int` or `float`, not {type(key)}."
)
Loading

0 comments on commit 5ae7437

Please sign in to comment.