Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: slicing mixin class #132

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions eitprocessing/eit_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
ContinuousDataCollection,
)
from eitprocessing.eit_data.eit_data_variant import EITDataVariant
from eitprocessing.mixins import SelectByTime
from eitprocessing.eit_data.vendor import Vendor
from eitprocessing.helper import NotEquivalent
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection
from eitprocessing.variants.variant_collection import VariantCollection
from ..helper import NotEquivalent
from .vendor import Vendor


PathLike: TypeAlias = str | Path
Expand Down Expand Up @@ -227,7 +227,10 @@ def check_equivalence(cls, a: T, b: T, raise_=False) -> bool:
return False

def _sliced_copy(
self: Self, start_index: int, end_index: int, label: str | None = None
self,
start_index: int,
end_index: int,
label: str,
) -> Self:
cls = self._get_vendor_class(self.vendor)
time = self.time[start_index:end_index]
Expand Down
5 changes: 2 additions & 3 deletions eitprocessing/eit_data/eit_data_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from eitprocessing.mixins import SelectByTime
from ..variants import Variant
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.variants import Variant


@dataclass
Expand Down Expand Up @@ -72,7 +72,6 @@ def concatenate(cls, a: Self, b: Self) -> Self:
def _sliced_copy(
self, start_index: int, end_index: int, label: str | None = None
) -> Self:
label = label or f"Slice ({start_index}-{end_index}) of <{self.label}>"
pixel_impedance = self.pixel_impedance[start_index:end_index, :, :]

return self.__class__(
Expand Down
129 changes: 0 additions & 129 deletions eitprocessing/mixins.py

This file was deleted.

180 changes: 180 additions & 0 deletions eitprocessing/mixins/slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations
import bisect
import warnings
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self


class SelectByIndex(ABC):
    """Mixin that adds index-based slicing to a subclass through `__getitem__`.

    Subclasses must implement `_sliced_copy`, which defines what is returned when
    the object is sliced. `select_by_index` also reads `self.label` and calls
    `len(self)`, so subclasses are expected to provide a `label` attribute and
    `__len__`.
    """

    def __getitem__(self, key: slice | int):
        """Return the part of the object selected by a slice or a single index.

        Args:
            key: a slice with step `None` or 1, or a single integer index.
                In-range negative integers count from the end, as for lists.

        Raises:
            ValueError: if a slice with a step other than 1 is given.
            TypeError: if `key` is neither a `slice` nor an `int`.
        """
        if isinstance(key, slice):
            if key.step and key.step != 1:
                raise ValueError(
                    f"Can't slice {self.__class__} object with steps other than 1."
                )
            return self.select_by_index(key.start, key.stop)

        if isinstance(key, int):
            if key < 0:
                # Translate an in-range negative index to its positive
                # counterpart. Without this, `obj[-1]` would turn into the
                # empty selection `[-1:0]`.
                length = len(self)
                if key >= -length:
                    key += length
            return self.select_by_index(start=key, end=key + 1)

        raise TypeError(
            f"Invalid slicing input. Should be `slice` or `int`, not {type(key)}."
        )

    def select_by_index(
        self,
        start: int | None = None,
        end: int | None = None,
        label: str | None = None,
    ) -> Self:
        """De facto implementation of the `__getitem__` function.

        Args:
            start: first index to include. Defaults to 0.
            end: first index to exclude. Defaults to `len(self)`.
            label: label passed on to the sliced copy. If `None`, a label is
                generated from `self.label` and the selected frame range.

        Returns:
            The result of `_sliced_copy` for the requested range, or `self`
            (after a warning) when neither `start` nor `end` is given.
        """
        if start is None and end is None:
            warnings.warn("No starting or end timepoint was selected.")
            return self

        start = start or 0
        if end is None:
            end = len(self)

        if label is None:
            # Resolve negative and out-of-range bounds the way Python slicing
            # does, so the label describes the frames that are really selected.
            # The bounds handed to `_sliced_copy` are left untouched.
            first, stop, _ = slice(start, end).indices(len(self))
            n_frames = stop - first
            if n_frames <= 0:
                label = f"No frames selected from <{self.label}>"
            elif n_frames == 1:
                label = f"Frame ({first}) of <{self.label}>"
            else:
                label = f"Frames ({first}-{stop - 1}) of <{self.label}>"

        return self._sliced_copy(start_index=start, end_index=end, label=label)

    @abstractmethod
    def _sliced_copy(
        self,
        start_index: int,
        end_index: int,
        label: str,
    ) -> Self:
        """Slicing method that must be implemented by all subclasses.

        Must return a copy of self object with all attached data within selected
        indices.
        """
        ...


class SelectByTime(SelectByIndex):
    """Mixin that adds time-based selection on top of index-based slicing.

    Time stamps are read from `self.time`. A requested time range is translated
    into frame indices, which are then handed to `select_by_index`.
    """

    time: NDArray

    def select_by_time(  # pylint: disable=too-many-arguments
        self,
        start_time: float | int | None = None,
        end_time: float | int | None = None,
        start_inclusive: bool = True,
        end_inclusive: bool = False,
        label: str | None = None,
    ) -> Self:
        """Get a slice from start to end time stamps.

        Given a start and end time stamp (i.e. its value, not its index),
        return a slice of the original object, which must contain a time axis.

        Args:
            start_time: first time point to include. Defaults to first frame of
                sequence. Values before the first time stamp also select from
                the first frame.
            end_time: last time point. Defaults to last frame of sequence.
            start_inclusive: controls the behavior when `start_time` falls
                between two time stamps. If `True` (default), the selection
                starts at the preceding frame, so `start_time` lies inside the
                selection. If `False`, it starts at the following frame.
            end_inclusive: if `True`, the first frame at or after `end_time` is
                included, so `end_time` lies inside the selection. If `False`
                (default), the selection stops before that frame.
            label: label for the returned object. Defaults to None, which will
                create a label based on the original object label and the
                frames by which it is sliced.

        Raises:
            TypeError: if `self` does not have a `time` attribute.
            ValueError: if time stamps are not sorted.

        Returns:
            The result of `select_by_index` for the matching frame range, or
            `self` (after a warning) when neither `start_time` nor `end_time`
            is given.
        """
        # `hasattr` also accepts a `time` that is provided by a property or by
        # slots, which a lookup in `vars(self)` would reject.
        if not hasattr(self, "time"):
            raise TypeError(f"Object {self} has no time axis.")

        if start_time is None and end_time is None:
            warnings.warn("No starting or end timepoint was selected.")
            return self

        if not np.all(np.sort(self.time) == self.time):
            raise ValueError(
                f"Time stamps for {self} are not sorted and therefore data "
                "cannot be selected by time."
            )

        # The explicit check against the first time stamp prevents
        # `bisect_right(...) - 1` from producing -1; an empty time axis is
        # handled here as well, because it has no first element to compare to.
        if start_time is None or len(self.time) == 0 or start_time < self.time[0]:
            start_index = 0
        elif start_inclusive:
            start_index = bisect.bisect_right(self.time, start_time) - 1
        else:
            start_index = bisect.bisect_left(self.time, start_time)

        if end_time is None:
            end_index = len(self.time)
        elif end_inclusive:
            end_index = bisect.bisect_left(self.time, end_time) + 1
        else:
            end_index = bisect.bisect_left(self.time, end_time)

        return self.select_by_index(
            start=start_index,
            end=end_index,
            label=label,
        )

    @property
    def t(self) -> TimeIndexer:
        """Return a `TimeIndexer` for this object, enabling `obj.t[start:end]`."""
        return TimeIndexer(self)


@dataclass
class TimeIndexer:
    """Helper class allowing for slicing an object using the time axis rather than indices

    Example:
    ```
    >>> data = DraegerEITData.from_path(<path>)
    >>> tp_start = data.time[1]
    >>> tp_end = data.time[4]
    >>> time_slice = data.t[tp_start:tp_end]
    >>> index_slice = data[1:4]
    >>> time_slice == index_slice
    True
    ```
    """

    # Object whose `select_by_time` method performs the actual selection.
    obj: SelectByTime

    def __getitem__(self, key: slice | int | float):
        """Select from the wrapped object by time stamp(s).

        Args:
            key: a slice of time stamps without a step, or a single time stamp.
                A slice is forwarded as `select_by_time(key.start, key.stop)`.
                A single value is used as both start and end time with
                `end_inclusive=True`.

        Raises:
            ValueError: if a slice with a step is given.
            TypeError: if `key` is not a `slice`, `int` or `float`.
        """
        if isinstance(key, slice):
            if key.step:
                raise ValueError("Can't slice by time using specific step sizes.")
            return self.obj.select_by_time(key.start, key.stop)

        if isinstance(key, (int, float)):
            # The keywords must match the parameter names of `select_by_time`
            # (`start_time`/`end_time`); `start`/`end` raise a TypeError there.
            return self.obj.select_by_time(
                start_time=key, end_time=key, end_inclusive=True
            )

        raise TypeError(
            f"Invalid slicing input. Should be `slice` or `int` or `float`, not {type(key)}."
        )
Loading