Merge pull request #133 from EIT-ALIVE/129b_eq_dbodor
mixin class for equality and equivalence
DaniBodor authored Feb 5, 2024
2 parents 0c0dacd + fee1ddf commit 2617164
Showing 15 changed files with 420 additions and 276 deletions.
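The new equality mixin itself (eitprocessing/mixins/equality.py) is among the 15 changed files, but its diff is not rendered on this page. From the way it is used in the hunks below — classes switch to @dataclass(eq=False) and inherit Equivalence, __post_init__ sets self._check_equivalence to a list of attribute names, and concatenate calls cls.isequivalent(a, b, raise_=True) — its interface is roughly as follows. This is a sketch inferred from usage only; everything beyond the names Equivalence, EquivalenceError, isequivalent and _check_equivalence is an assumption, not the actual implementation.

class EquivalenceError(Exception):
    """Raised when two objects are not equivalent."""
    # Appears to take over the role of helper.NotEquivalent, which is removed below.


class Equivalence:
    # Subclasses list the attributes that must match for two objects to be equivalent.
    _check_equivalence: list[str] = []

    def isequivalent(self, other: "Equivalence", raise_: bool = False) -> bool:
        # An instance method supports both a.isequivalent(b) and the unbound call
        # cls.isequivalent(a, b, raise_=True) seen in the diffs below.
        try:
            if type(self) is not type(other):
                raise EquivalenceError(f"Classes do not match: {type(self)}, {type(other)}")
            for attr in self._check_equivalence:
                if getattr(self, attr) != getattr(other, attr):
                    raise EquivalenceError(
                        f"{attr} does not match: {getattr(self, attr)}, {getattr(other, attr)}"
                    )
        except EquivalenceError:
            if raise_:
                raise
            return False
        return True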
43 changes: 11 additions & 32 deletions eitprocessing/continuous_data/__init__.py
@@ -1,20 +1,21 @@
import contextlib
from dataclasses import dataclass
from dataclasses import field
from dataclasses import dataclass, field
from typing import Any

import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from ..helper import NotEquivalent
from ..variants.variant_collection import VariantCollection
from .continuous_data_variant import ContinuousDataVariant

from eitprocessing.continuous_data.continuous_data_variant import ContinuousDataVariant
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.variants.variant_collection import VariantCollection


@dataclass
class ContinuousData:
@dataclass(eq=False)
class ContinuousData(Equivalence):
name: str
unit: str
description: str
category: str
time: NDArray
loaded: bool
calculated_from: Any | list[Any] | None = None
@@ -27,10 +28,11 @@ def __post_init__(self):
raise DataSourceUnknown(
"Data must be loaded or calculated form another dataset."
)
self._check_equivalence = ["unit", "category"]

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

calculcated_from = None if a.loaded else [a.calculated_from, b.calculated_from]

@@ -47,29 +49,6 @@ def concatenate(cls, a: Self, b: Self) -> Self:
)
return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_: bool = False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.name != b.name:
raise NotEquivalent(f"Names do not match: {a.name}, {b.name}")
if a.unit != b.unit:
raise NotEquivalent(f"Units do not match: {a.unit}, {b.unit}")
if a.description != b.description:
raise NotEquivalent(
f"Descriptions do not match: {a.description}, {b.description}"
)
if a.loaded != b.loaded:
raise NotEquivalent(
f"Only one of the datasets is loaded: {a.loaded=}, {b.loaded=}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False


class DataSourceUnknown(Exception):
"""Raised when the source of data is unknown."""
33 changes: 10 additions & 23 deletions eitprocessing/continuous_data/continuous_data_collection.py
@@ -1,11 +1,12 @@
import contextlib
from typing import Any

from typing_extensions import Self
from ..helper import NotEquivalent
from . import ContinuousData

from eitprocessing.continuous_data import ContinuousData
from eitprocessing.mixins.equality import Equivalence, EquivalenceError


class ContinuousDataCollection(dict):
class ContinuousDataCollection(dict, Equivalence):
def __setitem__(self, __key: Any, __value: Any) -> None:
self._check_data(__value, key=__key)
return super().__setitem__(__key, __value)
@@ -31,32 +32,18 @@ def _check_data(
@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
try:
cls.check_equivalence(a, b, raise_=True)
except NotEquivalent as e:
raise ValueError("VariantCollections could not be concatenated") from e
cls.isequivalent(a, b, raise_=True)
except EquivalenceError as e:
raise EquivalenceError(
"ContinuousDataCollections could not be concatenated"
) from e

obj = ContinuousDataCollection()
for key in a.keys() & b.keys():
obj.add(ContinuousData.concatenate(a[key], b[key]))

return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if set(a.keys()) != set(b.keys()):
raise NotEquivalent(
f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}"
)

for key in a.keys():
ContinuousData.check_equivalence(a[key], b[key], raise_=True)

return True

return False


class DuplicateContinuousDataName(Exception):
"""Raised when a variant with the same name already exists in the collection."""
9 changes: 5 additions & 4 deletions eitprocessing/continuous_data/continuous_data_variant.py
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from dataclasses import field
from dataclasses import dataclass, field

from numpy.typing import NDArray
from typing_extensions import Self
from ..variants import Variant

from eitprocessing.variants import Variant


@dataclass
@dataclass(eq=False)
class ContinuousDataVariant(Variant):
values: NDArray = field(kw_only=True)

46 changes: 13 additions & 33 deletions eitprocessing/eit_data/__init__.py
@@ -1,35 +1,32 @@
from __future__ import annotations
import contextlib
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import reduce
from pathlib import Path
from typing import TypeAlias
from typing import TypeVar
from typing import TypeAlias, TypeVar

import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from typing_extensions import override
from typing_extensions import Self, override

from eitprocessing.continuous_data.continuous_data_collection import (
ContinuousDataCollection,
)
from eitprocessing.eit_data.eit_data_variant import EITDataVariant
from eitprocessing.eit_data.vendor import Vendor
from eitprocessing.helper import NotEquivalent
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection
from eitprocessing.variants.variant_collection import VariantCollection


PathLike: TypeAlias = str | Path
PathArg: TypeAlias = PathLike | list[PathLike]
T = TypeVar("T", bound="EITData")


@dataclass
class EITData(SelectByTime, ABC):
@dataclass(eq=False)
class EITData(SelectByTime, Equivalence, ABC):
path: Path | list[Path]
nframes: int
time: NDArray
@@ -45,6 +42,7 @@ class EITData(SelectByTime, ABC):
def __post_init__(self):
if not self.label:
self.label = f"{self.__class__.__name__}_{id(self)}"
self._check_equivalence = ["vendor", "framerate"]

@classmethod
def from_path( # pylint: disable=too-many-arguments,too-many-locals
@@ -182,7 +180,7 @@ def _ensure_vendor(vendor: Vendor | str) -> Vendor:

@classmethod
def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

a_path = cls._ensure_path_list(a.path)
b_path = cls._ensure_path_list(b.path)
@@ -208,24 +206,6 @@ def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
variants=variants,
)

@classmethod
def check_equivalence(cls, a: T, b: T, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.__class__ != b.__class__:
raise NotEquivalent(f"Classes don't match: {type(a)}, {type(b)}")

if a.framerate != b.framerate:
raise NotEquivalent(
f"Framerates do not match: {a.framerate}, {b.framerate}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False

def _sliced_copy(
self,
start_index: int,
@@ -272,7 +252,7 @@ def _from_path( # pylint: disable=too-many-arguments
...


@dataclass
@dataclass(eq=False)
class EITData_(EITData):
vendor: Vendor = field(init=False)

30 changes: 6 additions & 24 deletions eitprocessing/eit_data/eit_data_variant.py
@@ -1,10 +1,3 @@
"""
Copyright 2023 Netherlands eScience Center and Erasmus University Medical Center.
Licensed under the Apache License, version 2.0. See LICENSE for details.
This file contains methods related to when electrical impedance tomographs are read.
"""

from dataclasses import dataclass
from dataclasses import field
import numpy as np
@@ -14,28 +7,14 @@
from eitprocessing.variants import Variant


@dataclass
@dataclass(eq=False)
class EITDataVariant(Variant, SelectByTime):
_data_field_name: str = "pixel_impedance"
pixel_impedance: NDArray = field(repr=False, kw_only=True)

def __len__(self):
return self.pixel_impedance.shape[0]

def __eq__(self, other):
for attr in ["name", "description", "params"]:
if getattr(self, attr) != getattr(other, attr):
return False

for attr in ["pixel_impedance"]:
# NaN values are not equal. Check whether values are equal or both NaN.
s = getattr(self, attr)
o = getattr(other, attr)
if not np.all((s == o) | (np.isnan(s) & np.isnan(o))):
return False

return True

@property
def global_baseline(self):
return np.nanmin(self.pixel_impedance)
@@ -58,7 +37,7 @@ def global_impedance(self):

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

return cls(
label=a.label,
@@ -70,7 +49,10 @@ def concatenate(cls, a: Self, b: Self) -> Self:
)

def _sliced_copy(
self, start_index: int, end_index: int, label: str | None = None
self,
start_index: int,
end_index: int,
label: str | None = None,
) -> Self:
pixel_impedance = self.pixel_impedance[start_index:end_index, :, :]

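The hand-written __eq__ removed above compared pixel_impedance arrays NaN-aware: positions where both values are NaN count as equal. With eq=False, that behaviour presumably moves into the shared Equivalence mixin; for reference, a NaN-aware comparison along the lines of the removed code (the helper name is hypothetical):

import numpy as np

def arrays_match(s: np.ndarray, o: np.ndarray) -> bool:
    # Plain == treats NaN as unequal to itself, so also accept positions where both sides are NaN.
    if s.shape != o.shape:
        return False
    return bool(np.all((s == o) | (np.isnan(s) & np.isnan(o))))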
19 changes: 9 additions & 10 deletions eitprocessing/eit_data/timpel.py
@@ -1,24 +1,23 @@
import warnings
from dataclasses import dataclass
from dataclasses import field
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self

from eitprocessing.continuous_data.continuous_data_collection import (
ContinuousDataCollection,
)
from eitprocessing.eit_data import EITData_
from eitprocessing.eit_data.eit_data_variant import EITDataVariant
from eitprocessing.eit_data.phases import MaxValue, MinValue, QRSMark
from eitprocessing.eit_data.vendor import Vendor
from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection
from eitprocessing.variants.variant_collection import VariantCollection
from ..eit_data.eit_data_variant import EITDataVariant
from ..eit_data.phases import MaxValue
from ..eit_data.phases import MinValue
from ..eit_data.phases import QRSMark
from ..eit_data.vendor import Vendor
from . import EITData_


@dataclass
@dataclass(eq=False)
class TimpelEITData(EITData_):
framerate: float = 50
vendor: Vendor = field(default=Vendor.TIMPEL, init=False)
@@ -93,7 +92,7 @@ def _from_path( # pylint: disable=too-many-arguments,too-many-locals

# extract waveform data
# TODO: properly export waveform data
waveform_data = { # noqa;
waveform_data = { # noqa
"airway_pressure": data[:, 1024],
"flow": data[:, 1025],
"volume": data[:, 1026],
4 changes: 0 additions & 4 deletions eitprocessing/helper.py
@@ -1,6 +1,2 @@
class NotEquivalent(Exception):
"""Raised when two objects are not equivalent."""


class NotConsecutive(Exception):
"""Raised when trying to concatenate non-consecutive datasets."""