diff --git a/eitprocessing/continuous_data/__init__.py b/eitprocessing/continuous_data/__init__.py index d03c63456..3ccbee31a 100644 --- a/eitprocessing/continuous_data/__init__.py +++ b/eitprocessing/continuous_data/__init__.py @@ -1,20 +1,21 @@ -import contextlib -from dataclasses import dataclass -from dataclasses import field +from dataclasses import dataclass, field from typing import Any + import numpy as np from numpy.typing import NDArray from typing_extensions import Self -from ..helper import NotEquivalent -from ..variants.variant_collection import VariantCollection -from .continuous_data_variant import ContinuousDataVariant + +from eitprocessing.continuous_data.continuous_data_variant import ContinuousDataVariant +from eitprocessing.mixins.equality import Equivalence +from eitprocessing.variants.variant_collection import VariantCollection -@dataclass -class ContinuousData: +@dataclass(eq=False) +class ContinuousData(Equivalence): name: str unit: str description: str + category: str time: NDArray loaded: bool calculated_from: Any | list[Any] | None = None @@ -27,10 +28,11 @@ def __post_init__(self): raise DataSourceUnknown( "Data must be loaded or calculated form another dataset." ) + self._check_equivalence = ["unit", "category"] @classmethod def concatenate(cls, a: Self, b: Self) -> Self: - cls.check_equivalence(a, b, raise_=True) + cls.isequivalent(a, b, raise_=True) calculcated_from = None if a.loaded else [a.calculated_from, b.calculated_from] @@ -47,29 +49,6 @@ def concatenate(cls, a: Self, b: Self) -> Self: ) return obj - @classmethod - def check_equivalence(cls, a: Self, b: Self, raise_: bool = False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.name != b.name: - raise NotEquivalent(f"Names do not match: {a.name}, {b.name}") - if a.unit != b.unit: - raise NotEquivalent(f"Units do not match: {a.unit}, {b.unit}") - if a.description != b.description: - raise NotEquivalent( - f"Descriptions do not match: {a.description}, {b.description}" - ) - if a.loaded != b.loaded: - raise NotEquivalent( - f"Only one of the datasets is loaded: {a.loaded=}, {b.loaded=}" - ) - - VariantCollection.check_equivalence(a.variants, b.variants, raise_=True) - - return True - - return False - class DataSourceUnknown(Exception): """Raised when the source of data is unknown.""" diff --git a/eitprocessing/continuous_data/continuous_data_collection.py b/eitprocessing/continuous_data/continuous_data_collection.py index 68b84f590..f0948c7ca 100644 --- a/eitprocessing/continuous_data/continuous_data_collection.py +++ b/eitprocessing/continuous_data/continuous_data_collection.py @@ -1,11 +1,12 @@ -import contextlib from typing import Any + from typing_extensions import Self -from ..helper import NotEquivalent -from . 
import ContinuousData + +from eitprocessing.continuous_data import ContinuousData +from eitprocessing.mixins.equality import Equivalence, EquivalenceError -class ContinuousDataCollection(dict): +class ContinuousDataCollection(dict, Equivalence): def __setitem__(self, __key: Any, __value: Any) -> None: self._check_data(__value, key=__key) return super().__setitem__(__key, __value) @@ -31,9 +32,11 @@ def _check_data( @classmethod def concatenate(cls, a: Self, b: Self) -> Self: try: - cls.check_equivalence(a, b, raise_=True) - except NotEquivalent as e: - raise ValueError("VariantCollections could not be concatenated") from e + cls.isequivalent(a, b, raise_=True) + except EquivalenceError as e: + raise EquivalenceError( + "ContinuousDataCollections could not be concatenated" + ) from e obj = ContinuousDataCollection() for key in a.keys() & b.keys(): @@ -41,22 +44,6 @@ def concatenate(cls, a: Self, b: Self) -> Self: return obj - @classmethod - def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if set(a.keys()) != set(b.keys()): - raise NotEquivalent( - f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}" - ) - - for key in a.keys(): - ContinuousData.check_equivalence(a[key], b[key], raise_=True) - - return True - - return False - class DuplicateContinuousDataName(Exception): """Raised when a variant with the same name already exists in the collection.""" diff --git a/eitprocessing/continuous_data/continuous_data_variant.py b/eitprocessing/continuous_data/continuous_data_variant.py index 840f84f1f..d78014faf 100644 --- a/eitprocessing/continuous_data/continuous_data_variant.py +++ b/eitprocessing/continuous_data/continuous_data_variant.py @@ -1,11 +1,12 @@ -from dataclasses import dataclass -from dataclasses import field +from dataclasses import dataclass, field + from numpy.typing import NDArray from typing_extensions import Self -from ..variants import Variant + +from eitprocessing.variants import Variant -@dataclass +@dataclass(eq=False) class ContinuousDataVariant(Variant): values: NDArray = field(kw_only=True) diff --git a/eitprocessing/eit_data/__init__.py b/eitprocessing/eit_data/__init__.py index 668246c2a..d98b8b59e 100644 --- a/eitprocessing/eit_data/__init__.py +++ b/eitprocessing/eit_data/__init__.py @@ -1,35 +1,32 @@ from __future__ import annotations -import contextlib -from abc import ABC -from abc import abstractmethod -from dataclasses import dataclass -from dataclasses import field + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field from functools import reduce from pathlib import Path -from typing import TypeAlias -from typing import TypeVar +from typing import TypeAlias, TypeVar + import numpy as np from numpy.typing import NDArray -from typing_extensions import Self -from typing_extensions import override +from typing_extensions import Self, override + from eitprocessing.continuous_data.continuous_data_collection import ( ContinuousDataCollection, ) from eitprocessing.eit_data.eit_data_variant import EITDataVariant from eitprocessing.eit_data.vendor import Vendor -from eitprocessing.helper import NotEquivalent +from eitprocessing.mixins.equality import Equivalence from eitprocessing.mixins.slicing import SelectByTime from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection from eitprocessing.variants.variant_collection import VariantCollection - PathLike: TypeAlias = str | Path 
PathArg: TypeAlias = PathLike | list[PathLike] T = TypeVar("T", bound="EITData") -@dataclass -class EITData(SelectByTime, ABC): +@dataclass(eq=False) +class EITData(SelectByTime, Equivalence, ABC): path: Path | list[Path] nframes: int time: NDArray @@ -45,6 +42,7 @@ class EITData(SelectByTime, ABC): def __post_init__(self): if not self.label: self.label = f"{self.__class__.__name__}_{id(self)}" + self._check_equivalence = ["vendor", "framerate"] @classmethod def from_path( # pylint: disable=too-many-arguments,too-many-locals @@ -182,7 +180,7 @@ def _ensure_vendor(vendor: Vendor | str) -> Vendor: @classmethod def concatenate(cls, a: T, b: T, label: str | None = None) -> T: - cls.check_equivalence(a, b, raise_=True) + cls.isequivalent(a, b, raise_=True) a_path = cls._ensure_path_list(a.path) b_path = cls._ensure_path_list(b.path) @@ -208,24 +206,6 @@ def concatenate(cls, a: T, b: T, label: str | None = None) -> T: variants=variants, ) - @classmethod - def check_equivalence(cls, a: T, b: T, raise_=False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.__class__ != b.__class__: - raise NotEquivalent(f"Classes don't match: {type(a)}, {type(b)}") - - if a.framerate != b.framerate: - raise NotEquivalent( - f"Framerates do not match: {a.framerate}, {b.framerate}" - ) - - VariantCollection.check_equivalence(a.variants, b.variants, raise_=True) - - return True - - return False - def _sliced_copy( self, start_index: int, @@ -272,7 +252,7 @@ def _from_path( # pylint: disable=too-many-arguments ... -@dataclass +@dataclass(eq=False) class EITData_(EITData): vendor: Vendor = field(init=False) diff --git a/eitprocessing/eit_data/eit_data_variant.py b/eitprocessing/eit_data/eit_data_variant.py index 0a03d66af..a17382f69 100644 --- a/eitprocessing/eit_data/eit_data_variant.py +++ b/eitprocessing/eit_data/eit_data_variant.py @@ -1,10 +1,3 @@ -""" -Copyright 2023 Netherlands eScience Center and Erasmus University Medical Center. -Licensed under the Apache License, version 2.0. See LICENSE for details. - -This file contains methods related to when electrical impedance tomographs are read. -""" - from dataclasses import dataclass from dataclasses import field import numpy as np @@ -14,7 +7,7 @@ from eitprocessing.variants import Variant -@dataclass +@dataclass(eq=False) class EITDataVariant(Variant, SelectByTime): _data_field_name: str = "pixel_impedance" pixel_impedance: NDArray = field(repr=False, kw_only=True) @@ -22,20 +15,6 @@ class EITDataVariant(Variant, SelectByTime): def __len__(self): return self.pixel_impedance.shape[0] - def __eq__(self, other): - for attr in ["name", "description", "params"]: - if getattr(self, attr) != getattr(other, attr): - return False - - for attr in ["pixel_impedance"]: - # NaN values are not equal. Check whether values are equal or both NaN. 
- s = getattr(self, attr) - o = getattr(other, attr) - if not np.all((s == o) | (np.isnan(s) & np.isnan(o))): - return False - - return True - @property def global_baseline(self): return np.nanmin(self.pixel_impedance) @@ -58,7 +37,7 @@ def global_impedance(self): @classmethod def concatenate(cls, a: Self, b: Self) -> Self: - cls.check_equivalence(a, b, raise_=True) + cls.isequivalent(a, b, raise_=True) return cls( label=a.label, @@ -70,7 +49,10 @@ def concatenate(cls, a: Self, b: Self) -> Self: ) def _sliced_copy( - self, start_index: int, end_index: int, label: str | None = None + self, + start_index: int, + end_index: int, + label: str | None = None, ) -> Self: pixel_impedance = self.pixel_impedance[start_index:end_index, :, :] diff --git a/eitprocessing/eit_data/timpel.py b/eitprocessing/eit_data/timpel.py index 1ad8d0995..933453547 100644 --- a/eitprocessing/eit_data/timpel.py +++ b/eitprocessing/eit_data/timpel.py @@ -1,24 +1,23 @@ import warnings -from dataclasses import dataclass -from dataclasses import field +from dataclasses import dataclass, field from pathlib import Path + import numpy as np from numpy.typing import NDArray from typing_extensions import Self + from eitprocessing.continuous_data.continuous_data_collection import ( ContinuousDataCollection, ) +from eitprocessing.eit_data import EITData_ +from eitprocessing.eit_data.eit_data_variant import EITDataVariant +from eitprocessing.eit_data.phases import MaxValue, MinValue, QRSMark +from eitprocessing.eit_data.vendor import Vendor from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection from eitprocessing.variants.variant_collection import VariantCollection -from ..eit_data.eit_data_variant import EITDataVariant -from ..eit_data.phases import MaxValue -from ..eit_data.phases import MinValue -from ..eit_data.phases import QRSMark -from ..eit_data.vendor import Vendor -from . 
import EITData_
-@dataclass
+@dataclass(eq=False)
 class TimpelEITData(EITData_):
     framerate: float = 50
     vendor: Vendor = field(default=Vendor.TIMPEL, init=False)
@@ -93,7 +92,7 @@ def _from_path(  # pylint: disable=too-many-arguments,too-many-locals
 
         # extract waveform data
         # TODO: properly export waveform data
-        waveform_data = {  # noqa;
+        waveform_data = {  # noqa
            "airway_pressure": data[:, 1024],
            "flow": data[:, 1025],
            "volume": data[:, 1026],
diff --git a/eitprocessing/helper.py b/eitprocessing/helper.py
index 6a25c8072..889df6eb1 100644
--- a/eitprocessing/helper.py
+++ b/eitprocessing/helper.py
@@ -1,6 +1,2 @@
-class NotEquivalent(Exception):
-    """Raised when two objects are not equivalent."""
-
-
 class NotConsecutive(Exception):
     """Raised when trying to concatenate non-consecutive datasets."""
diff --git a/eitprocessing/mixins/__init__.py b/eitprocessing/mixins/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/eitprocessing/mixins/equality.py b/eitprocessing/mixins/equality.py
new file mode 100644
index 000000000..2ed3714de
--- /dev/null
+++ b/eitprocessing/mixins/equality.py
@@ -0,0 +1,95 @@
+from abc import ABC
+from dataclasses import astuple, is_dataclass
+
+import numpy as np
+from typing_extensions import Self
+
+
+class Equivalence(ABC):
+    # inspired by: https://stackoverflow.com/a/51743960/5170442
+    def __eq__(self, other: Self):
+        if self is other:
+            return True
+        if is_dataclass(self):
+            if self.__class__ is not other.__class__:
+                return NotImplemented
+            t1 = astuple(self)
+            t2 = astuple(other)
+            return all(Equivalence._array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))
+        return Equivalence._array_safe_eq(self, other)
+
+    @staticmethod
+    def _array_safe_eq(a, b) -> bool:
+        """Check if a and b are equal, even if they are numpy arrays containing nans."""
+
+        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+            return a.shape == b.shape and np.array_equal(a, b, equal_nan=True)
+        try:
+            return object.__eq__(a, b)  # `a == b` could trigger an infinite loop
+        except TypeError:
+            return NotImplemented
+
+    def isequivalent(self, other: Self, raise_: bool = False) -> bool:
+        """Test whether the data structures of two objects are equivalent.
+
+        Equivalence, in this case, means that the objects are compatible, e.g.
+        to be merged. Data content can vary, but attributes such as the data
+        category (e.g. airway pressure, flow, tidal volume) and unit must match.
+
+
+        Args:
+            other: object that will be compared to self.
+            raise_: sets this method's behavior in case of non-equivalence. If
+                True, an `EquivalenceError` is raised, otherwise `False` is
+                returned.
+
+        Raises:
+            EquivalenceError: if `raise_ == True` and the objects are not
+                equivalent.
+
+        Returns:
+            bool describing result of equivalence comparison.
+ """ + + if self == other: + return True + + try: + # check whether types match + if type(self) is not type(other): + raise EquivalenceError( + f"Types don't match: {type(self)}, {type(other)}" + ) + + # check keys in collection + if isinstance(self, dict): + if set(self.keys()) != set(other.keys()): + raise EquivalenceError( + f"Keys don't match:\n\t{self.keys()},\n\t{other.keys()}" + ) + + for key in self: + if not self[key].isequivalent(other[key], False): + raise EquivalenceError( + f"Data in {key} doesn't match: {self[key]}, {other[key]}" + ) + + # check attributes of data + else: + self._check_equivalence: list[str] + for attr in self._check_equivalence: + if (s := getattr(self, attr)) != (o := getattr(other, attr)): + raise f"{attr.capitalize()}s don't match: {s}, {o}" + + # raise or return if a check fails + except EquivalenceError as e: + if raise_: + raise e + return False + + # if all checks pass + return True + + +class EquivalenceError(TypeError, ValueError): + """Raised if objects are not equivalent.""" diff --git a/eitprocessing/mixins/slicing.py b/eitprocessing/mixins/slicing.py index 2b244b1ee..542eeca81 100644 --- a/eitprocessing/mixins/slicing.py +++ b/eitprocessing/mixins/slicing.py @@ -12,11 +12,14 @@ class SelectByIndex(ABC): """Adds slicing functionality to subclass by implementing `__getitem__`. - Subclasses must implement a `_sliced_copy` function that defines what should happen - when the object is sliced. This class ensures that when calling a slice between - square brackets (as e.g. done for lists) then return the expected sliced object. + Subclasses must implement a `_sliced_copy` function that defines what should + happen when the object is sliced. This class ensures that when calling a + slice between square brackets (as e.g. done for lists) then return the + expected sliced object. """ + label: str + def __getitem__(self, key: slice | int): if isinstance(key, slice): if key.step and key.step != 1: @@ -40,7 +43,12 @@ def select_by_index( end: int | None = None, label: str | None = None, ) -> Self: - """De facto implementation of the `__getitem__ function.""" + """De facto implementation of the `__getitem__ function. + + This function can also be called directly to add a label to the sliced + object. Otherwise a default label describing the slice and original + object is attached. + """ if start is None and end is None: warnings.warn("No starting or end timepoint was selected.") @@ -95,22 +103,22 @@ def select_by_time( # pylint: disable=too-many-arguments start_time: first time point to include. Defaults to first frame of sequence. end_time: last time point. Defaults to last frame of sequence. start_inclusive (default: `True`), end_inclusive (default `False`): - these arguments control the behavior if the given time stamp does not - match exactly with an existing time stamp of the input. + these arguments control the behavior if the given time stamp + does not match exactly with an existing time stamp of the input. if `True`: the given time stamp will be inside the sliced object. if `False`: the given time stamp will be outside the sliced object. - label: Description. Defaults to None, which will create a label based on the original - object label and the frames by which it is sliced. + label: Description. Defaults to None, which will create a label based + on the original object label and the frames by which it is sliced. Raises: TypeError: if `self` does not contain a `time` attribute. ValueError: if time stamps are not sorted. 
Returns: - Self: _description_ + Slice of self. """ - if not "time" in vars(self): + if "time" not in vars(self): raise TypeError(f"Object {self} has no time axis.") if start_time is None and end_time is None: @@ -150,11 +158,11 @@ def t(self) -> TimeIndexer: @dataclass class TimeIndexer: - """Helper class allowing for slicing an object using the time axis rather than indices + """Helper class for slicing an object using the time axis instead of indices. Example: ``` - >>> data = DraegerEITData.from_path() + >>> data = EITData.from_path(, ...) >>> tp_start = data.time[1] >>> tp_end = data.time[4] >>> time_slice = data.t[tp_start:tp_end] diff --git a/eitprocessing/mixins/test_eq.ipynb b/eitprocessing/mixins/test_eq.ipynb new file mode 100644 index 000000000..95c3802da --- /dev/null +++ b/eitprocessing/mixins/test_eq.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from eitprocessing.mixins.slicing import SelectByIndex\n", + "from dataclasses import dataclass, is_dataclass\n", + "from eitprocessing.eit_data.vendor import Vendor\n", + "from eitprocessing.eit_data.draeger import DraegerEITData\n", + "from eitprocessing.eit_data.timpel import TimpelEITData\n", + "from eitprocessing.eit_data.eit_data_variant import EITDataVariant\n", + "from typing_extensions import Self\n", + "from eitprocessing.eit_data import EITData\n", + "from eitprocessing.mixins.equality import EquivalenceError\n", + "\n", + "import os\n", + "import pytest\n", + "from pprint import pprint\n", + "import bisect\n", + "import numpy as np\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([72758.105, 72758.155, 72758.205, ..., 73357.955, 73358.005,\n", + " 73358.055])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = DraegerEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test2.bin')\n", + "data2 = DraegerEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin')\n", + "timpel_data = TimpelEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Timpel_Test.txt')\n", + "\n", + "# pprint(data)\n", + "data.time\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "should be True\n", + "True\n", + "True\n", + "True\n", + "True\n", + "\n", + "should be False\n", + "False\n", + "False\n", + "False\n", + "False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dbodor/git/EIT-ALIVE/eitprocessing/eitprocessing/mixins/slicing.py:52: UserWarning: No starting or end timepoint was selected.\n", + " warnings.warn(\"No starting or end timepoint was selected.\")\n" + ] + } + ], + "source": [ + "print ('should be True')\n", + "print(data == data)\n", + "print(data[10] == data[10])\n", + "print(data[:10] == data[0:10])\n", + "print(data[:] == data)\n", + "\n", + "\n", + "print('\\nshould be False')\n", + "print(data == data2)\n", + "print(data[:10] == data[10])\n", + "print(data[:10] == data[2:10])\n", + "print(data[:10] == data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "should all be True\n", + "True\n", + "True\n", + "True\n", + "\n", + "should all 
be False\n", + "False\n", + "False\n", + "False\n", + "False\n", + "\n", + "error was correctly raised\n" + ] + } + ], + "source": [ + "print('should all be True')\n", + "print(data.isequivalent(data))\n", + "print(data.isequivalent(data2))\n", + "print(DraegerEITData.isequivalent(data, data2))\n", + "\n", + "print('\\nshould all be False')\n", + "print(data.isequivalent(timpel_data, False))\n", + "print(timpel_data.isequivalent(data))\n", + "print(EITData.isequivalent(timpel_data, data))\n", + "print(DraegerEITData.isequivalent(timpel_data, data))\n", + "\n", + "try:\n", + " _ = DraegerEITData.isequivalent(timpel_data, data, True)\n", + " print('\\nno error was raised, but it should have!')\n", + "except EquivalenceError:\n", + " print('\\nerror was correctly raised')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "should print False/True/False and then catch error\n", + "\n", + "False\n", + "True\n", + "False\n", + "\n", + "error was correctly raised\n" + ] + } + ], + "source": [ + "data_new = DraegerEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin')\n", + "\n", + "print('should print False/True/False and then catch error\\n')\n", + "\n", + "print(data_new == data)\n", + "print(data_new.isequivalent(data))\n", + "\n", + "data_new.framerate = 25\n", + "print(data_new.isequivalent(data))\n", + "\n", + "try:\n", + " _ = data_new.isequivalent(data, raise_ = True)\n", + " print('\\nno error was raised, but it should have!')\n", + "except EquivalenceError:\n", + " print('\\nerror was correctly raised')\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "alive", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/eitprocessing/sequence/__init__.py b/eitprocessing/sequence/__init__.py index 3bf839e17..c0193ad88 100644 --- a/eitprocessing/sequence/__init__.py +++ b/eitprocessing/sequence/__init__.py @@ -1,25 +1,17 @@ -""" -Copyright 2023 Netherlands eScience Center and Erasmus University Medical Center. -Licensed under the Apache License, version 2.0. See LICENSE for details. - -This file contains methods related to parts of electrical impedance tomographs -as they are read. -""" from __future__ import annotations import bisect -import contextlib import copy import warnings from dataclasses import dataclass import numpy as np -from ..continuous_data import ContinuousData -from ..eit_data import EITData -from ..sparse_data import SparseData -from ..helper import NotEquivalent +from eitprocessing.continuous_data import ContinuousData +from eitprocessing.eit_data import EITData +from eitprocessing.mixins.equality import Equivalence +from eitprocessing.sparse_data import SparseData @dataclass(eq=False) -class Sequence: +class Sequence(Equivalence): """Sequence of timepoints containing EIT and/or waveform data. 
A Sequence is a representation of a continuous set of data points, either EIT frames, @@ -48,41 +40,6 @@ def __post_init__(self): if self.label is None: self.label = f"Sequence_{id(self)}" - def __eq__(self, other) -> bool: - if not self.check_equivalence(self, other, raise_=False): - return False - - # TODO: check equality of object and all attached objects - - return True - - @staticmethod - def check_equivalence(a: Sequence, b: Sequence, raise_=False): - """Checks whether content of two Sequence objects is equivalent. - - It the two objects are equivalent, the method returns `True`. - - If the two objects are not equivalent and `raise_` is `True`, a - NotEquivalent exception is raised. If `raise_` is `False`, the - method returns `False` instead. - - Raises: - - NotEquivalent: when the objects are not equivalent and `raise_` is `True` - - """ - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.eit_data or b.eit_data: - if not a.eit_data or not b.eit_data: - raise NotEquivalent("Only one of the sequences contains EIT data") - - EITData.check_equivalence(a.eit_data, b.eit_data, raise_=raise_) - # TODO: add other attached objects for equivalence - - return True - - return False - def __add__(self, other: Sequence) -> Sequence: return self.concatenate(self, other) @@ -95,10 +52,6 @@ def concatenate( ) -> Sequence: """Create a merge of two Sequence objects.""" # TODO: rewrite - try: - Sequence.check_equivalence(a, b, raise_=True) - except NotEquivalent as e: - raise type(e)(f"Sequences could not be merged: {e}") from e if a.eit_data and b.eit_data: eit_data = EITData.concatenate(a.eit_data, b.eit_data) diff --git a/eitprocessing/variants/__init__.py b/eitprocessing/variants/__init__.py index ad452fa5e..aa2200e24 100644 --- a/eitprocessing/variants/__init__.py +++ b/eitprocessing/variants/__init__.py @@ -1,18 +1,13 @@ -import contextlib -from abc import ABC -from abc import abstractmethod -from dataclasses import dataclass -from dataclasses import field -from typing import TypeVar -from typing_extensions import Self -from ..helper import NotEquivalent +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing_extensions import Self -T = TypeVar("T", bound="Variant") +from eitprocessing.mixins.equality import Equivalence -@dataclass -class Variant(ABC): +@dataclass(eq=False) +class Variant(Equivalence, ABC): """Contains a single variant of a dataset. A variant of a dataset is defined as either the raw data, or an edited @@ -34,44 +29,6 @@ class Variant(ABC): description: str params: dict = field(default_factory=dict) - @staticmethod - def check_equivalence(a: T, b: T, raise_=False) -> bool: - """Check the equivalence of two variants - - For two variants to be equivalent, they need to have the same class, - the same label, the same description and the same parameters. Only the - actual data can differ between variants. 
- - Args: - - a (Variant) - - b (Variant) - - Raises: - - NotEquivalent (only if raise_ is `True`) when a and b are not - equivalent on one of the attributes - """ - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if not isinstance(a, b.__class__): - raise NotEquivalent( - f"Variant classes don't match: {a.__class__}, {b.__class__}" - ) - - if (a_ := a.label) != (b_ := b.label): - raise NotEquivalent(f"EITDataVariant names don't match: {a_}, {b_}") - - if (a_ := a.description) != (b_ := b.description): - raise NotEquivalent( - f"EITDataVariant descriptions don't match: {a_}, {b_}" - ) - - if (a_ := a.params) != (b_ := b.params): - raise NotEquivalent(f"EITDataVariant params don't match: {a_}, {b_}") - - return True - - return False - @classmethod @abstractmethod def concatenate(cls, a: Self, b: Self) -> Self: @@ -88,5 +45,5 @@ def concatenate(cls, a: Self, b: Self) -> Self: - b (Variant) Raises: - - NotEquivalent if a and b are not equivalent and can't be merged + - EquivalenceError if a and b are not equivalent and can't be merged """ diff --git a/eitprocessing/variants/variant_collection.py b/eitprocessing/variants/variant_collection.py index e2a478c14..ca60c249d 100644 --- a/eitprocessing/variants/variant_collection.py +++ b/eitprocessing/variants/variant_collection.py @@ -1,15 +1,14 @@ -import contextlib -from typing import Generic -from typing import TypeVar +from typing import Generic, TypeVar + from typing_extensions import Self -from ..helper import NotEquivalent -from . import Variant +from eitprocessing.mixins.equality import Equivalence, EquivalenceError +from eitprocessing.variants import Variant V = TypeVar("V", bound="Variant") -class VariantCollection(dict, Generic[V]): +class VariantCollection(dict, Equivalence, Generic[V]): """A collection of variants of a single type A VariantCollection is a dictionary with some added features. @@ -51,6 +50,9 @@ class VariantCollection(dict, Generic[V]): variant_type: type[V] + def __eq__(self, other: Self): # otherwise uses the dict.__eq__ method + return Equivalence.__eq__(self, other) + def __init__(self, variant_type: type[V], *args, **kwargs): self.variant_type = variant_type super().__init__(*args, **kwargs) @@ -69,11 +71,11 @@ def add(self, *variant: V, overwrite: bool = False) -> None: Args: - variant (Variant): the variant to be added. Multiple variants can be - added at once. + added at once. Raises: - DuplicateVariantLabel if one attempts to add a variant with a label - that already exists as key. + that already exists as key. 
""" for variant_ in variant: self._check_variant(variant_, overwrite=overwrite) @@ -96,37 +98,16 @@ def _check_variant(self, variant: V, key=None, overwrite: bool = False) -> None: @classmethod def concatenate(cls, a: Self, b: Self) -> Self: try: - cls.check_equivalence(a, b, raise_=True) - except NotEquivalent as e: + cls.isequivalent(a, b, raise_=True) + except EquivalenceError as e: raise ValueError("VariantCollections could not be concatenated") from e obj = VariantCollection(a.variant_type) - for key in a.keys(): + for key in a: obj.add(a.variant_type.concatenate(a[key], b[key])) return obj - @classmethod - def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.variant_type != b.variant_type: - raise NotEquivalent( - f"Variant types do not match: {a.variant_type}, {b.variant_type}" - ) - - if set(a.keys()) != set(b.keys()): - raise NotEquivalent( - f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}" - ) - - for key in a.keys(): - Variant.check_equivalence(a[key], b[key], raise_=True) - - return True - - return False - class InvalidVariantType(TypeError): """Raised when a variant that does not match the variant type is added.""" diff --git a/tests/mixins/test_eq.py b/tests/mixins/test_eq.py new file mode 100644 index 000000000..92c005ff8 --- /dev/null +++ b/tests/mixins/test_eq.py @@ -0,0 +1,26 @@ +# OLD FILE. TESTS NOT YET FUNCTIONAL + +import bisect +import os +from dataclasses import dataclass, is_dataclass +from pprint import pprint + +import numpy as np +import pytest +from typing_extensions import Self + +from eitprocessing.eit_data.draeger import DraegerEITData +from eitprocessing.eit_data.eit_data_variant import EITDataVariant +from eitprocessing.eit_data.vendor import Vendor +from eitprocessing.mixins.slicing import SelectByIndex + + +def test_eq(): + data = DraegerEITData.from_path( + "/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin" + ) + data2 = DraegerEITData.from_path( + "/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin" + ) + + data.isequivalent(data2)