diff --git a/eitprocessing/continuous_data/__init__.py b/eitprocessing/continuous_data/__init__.py index d03c63456..a87d14242 100644 --- a/eitprocessing/continuous_data/__init__.py +++ b/eitprocessing/continuous_data/__init__.py @@ -1,17 +1,17 @@ -import contextlib from dataclasses import dataclass from dataclasses import field from typing import Any import numpy as np from numpy.typing import NDArray from typing_extensions import Self -from ..helper import NotEquivalent -from ..variants.variant_collection import VariantCollection -from .continuous_data_variant import ContinuousDataVariant +from eitprocessing.continuous_data.continuous_data_variant import ContinuousDataVariant +from eitprocessing.mixins.equality import Equivalence +from eitprocessing.mixins.equality import EquivalenceError +from eitprocessing.variants.variant_collection import VariantCollection -@dataclass -class ContinuousData: +@dataclass(eq=False) +class ContinuousData(Equivalence): name: str unit: str description: str @@ -30,7 +30,7 @@ def __post_init__(self): @classmethod def concatenate(cls, a: Self, b: Self) -> Self: - cls.check_equivalence(a, b, raise_=True) + cls.isequivalent(a, b, raise_=True) calculcated_from = None if a.loaded else [a.calculated_from, b.calculated_from] @@ -47,28 +47,21 @@ def concatenate(cls, a: Self, b: Self) -> Self: ) return obj - @classmethod - def check_equivalence(cls, a: Self, b: Self, raise_: bool = False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.name != b.name: - raise NotEquivalent(f"Names do not match: {a.name}, {b.name}") - if a.unit != b.unit: - raise NotEquivalent(f"Units do not match: {a.unit}, {b.unit}") - if a.description != b.description: - raise NotEquivalent( - f"Descriptions do not match: {a.description}, {b.description}" - ) - if a.loaded != b.loaded: - raise NotEquivalent( - f"Only one of the datasets is loaded: {a.loaded=}, {b.loaded=}" - ) - - VariantCollection.check_equivalence(a.variants, b.variants, raise_=True) - - return True - - return False + def isequivalent( + self, + other: Self, + raise_: bool = False, + ) -> bool: + # fmt: off + checks = { + f"Names don't match: {self.name}, {other.name}.": self.name == other.name, + f"Units don't match: {self.unit}, {other.unit}.": self.unit == other.unit, + f"Descriptions don't match: {self.description}, {other.description}.": self.description == other.description, + f"Only one of the datasets is loaded: {self.loaded=}, {other.loaded=}.": self.loaded == other.loaded, + f"VariantCollections are not equivalent: {self.variants}, {other.variants}.": VariantCollection.isequivalent(self.variants,other.variants, raise_), + } + # fmt: on + return super()._isequivalent(other, raise_, checks) class DataSourceUnknown(Exception): diff --git a/eitprocessing/continuous_data/continuous_data_collection.py b/eitprocessing/continuous_data/continuous_data_collection.py index 68b84f590..a14964a59 100644 --- a/eitprocessing/continuous_data/continuous_data_collection.py +++ b/eitprocessing/continuous_data/continuous_data_collection.py @@ -1,11 +1,11 @@ -import contextlib from typing import Any from typing_extensions import Self -from ..helper import NotEquivalent -from . import ContinuousData +from eitprocessing.continuous_data import ContinuousData +from eitprocessing.mixins.equality import Equivalence +from eitprocessing.mixins.equality import EquivalenceError -class ContinuousDataCollection(dict): +class ContinuousDataCollection(dict, Equivalence): def __setitem__(self, __key: Any, __value: Any) -> None: self._check_data(__value, key=__key) return super().__setitem__(__key, __value) @@ -31,8 +31,8 @@ def _check_data( @classmethod def concatenate(cls, a: Self, b: Self) -> Self: try: - cls.check_equivalence(a, b, raise_=True) - except NotEquivalent as e: + cls.isequivalent(a, b, raise_=True) + except EquivalenceError as e: raise ValueError("VariantCollections could not be concatenated") from e obj = ContinuousDataCollection() @@ -41,21 +41,19 @@ def concatenate(cls, a: Self, b: Self) -> Self: return obj - @classmethod - def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if set(a.keys()) != set(b.keys()): - raise NotEquivalent( - f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}" - ) - - for key in a.keys(): - ContinuousData.check_equivalence(a[key], b[key], raise_=True) - - return True - - return False + def isequivalent( + self, + other: Self, + raise_: bool = False, + ) -> bool: + # fmt: off + checks = { + f"VariantCollections do not contain the same variants: {self.keys()=}, {other.keys()=}": set(self.keys()) == set(other.keys()), + } + for key in self.keys(): + checks[f"Continuous data ({key}) is not equivalent: {self[key]}, {other[key]}"] = ContinuousData.isequivalent(self[key], other[key], raise_) + # fmt: on + return super()._isequivalent(other, raise_, checks) class DuplicateContinuousDataName(Exception): diff --git a/eitprocessing/continuous_data/continuous_data_variant.py b/eitprocessing/continuous_data/continuous_data_variant.py index 840f84f1f..f77e64b32 100644 --- a/eitprocessing/continuous_data/continuous_data_variant.py +++ b/eitprocessing/continuous_data/continuous_data_variant.py @@ -5,7 +5,7 @@ from ..variants import Variant -@dataclass +@dataclass(eq=False) class ContinuousDataVariant(Variant): values: NDArray = field(kw_only=True) diff --git a/eitprocessing/eit_data/__init__.py b/eitprocessing/eit_data/__init__.py index e977dc57a..3c5015f94 100644 --- a/eitprocessing/eit_data/__init__.py +++ b/eitprocessing/eit_data/__init__.py @@ -1,5 +1,4 @@ from __future__ import annotations -import contextlib from abc import ABC from abc import abstractmethod from dataclasses import dataclass @@ -17,7 +16,6 @@ ) from eitprocessing.eit_data.eit_data_variant import EITDataVariant from eitprocessing.eit_data.vendor import Vendor -from eitprocessing.helper import NotEquivalent from eitprocessing.mixins.equality import Equivalence from eitprocessing.mixins.slicing import SelectByTime from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection @@ -183,7 +181,7 @@ def _ensure_vendor(vendor: Vendor | str) -> Vendor: @classmethod def concatenate(cls, a: T, b: T, label: str | None = None) -> T: - cls.check_equivalence(a, b, raise_=True) + cls.isequivalent(a, b, raise_=True) a_path = cls._ensure_path_list(a.path) b_path = cls._ensure_path_list(b.path) @@ -209,23 +207,18 @@ def concatenate(cls, a: T, b: T, label: str | None = None) -> T: variants=variants, ) - @classmethod - def check_equivalence(cls, a: T, b: T, raise_=False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.__class__ != b.__class__: - raise NotEquivalent(f"Classes don't match: {type(a)}, {type(b)}") - - if a.framerate != b.framerate: - raise NotEquivalent( - f"Framerates do not match: {a.framerate}, {b.framerate}" - ) - - VariantCollection.check_equivalence(a.variants, b.variants, raise_=True) - - return True - - return False + def isequivalent( + self, + other: Self, + raise_: bool = False, + ) -> bool: + # fmt: off + checks = { + f"Framerates don't match: {self.framerate}, {other.framerate}": self.framerate == other.framerate, + "VariantCollections are not equivalent": VariantCollection.isequivalent(self.variants, other.variants, raise_), + } + # fmt: on + return super()._isequivalent(other, raise_, checks) def _sliced_copy( self, diff --git a/eitprocessing/eit_data/eit_data_variant.py b/eitprocessing/eit_data/eit_data_variant.py index f858d1111..a17382f69 100644 --- a/eitprocessing/eit_data/eit_data_variant.py +++ b/eitprocessing/eit_data/eit_data_variant.py @@ -37,7 +37,7 @@ def global_impedance(self): @classmethod def concatenate(cls, a: Self, b: Self) -> Self: - cls.check_equivalence(a, b, raise_=True) + cls.isequivalent(a, b, raise_=True) return cls( label=a.label, diff --git a/eitprocessing/eit_data/timpel.py b/eitprocessing/eit_data/timpel.py index 1ad8d0995..8402b26bb 100644 --- a/eitprocessing/eit_data/timpel.py +++ b/eitprocessing/eit_data/timpel.py @@ -18,7 +18,7 @@ from . import EITData_ -@dataclass +@dataclass(eq=False) class TimpelEITData(EITData_): framerate: float = 50 vendor: Vendor = field(default=Vendor.TIMPEL, init=False) diff --git a/eitprocessing/helper.py b/eitprocessing/helper.py index 6a25c8072..889df6eb1 100644 --- a/eitprocessing/helper.py +++ b/eitprocessing/helper.py @@ -1,6 +1,2 @@ -class NotEquivalent(Exception): - """Raised when two objects are not equivalent.""" - - class NotConsecutive(Exception): """Raised when trying to concatenate non-consecutive datasets.""" diff --git a/eitprocessing/mixins/test_eq.ipynb b/eitprocessing/mixins/test_eq.ipynb new file mode 100644 index 000000000..95c3802da --- /dev/null +++ b/eitprocessing/mixins/test_eq.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from eitprocessing.mixins.slicing import SelectByIndex\n", + "from dataclasses import dataclass, is_dataclass\n", + "from eitprocessing.eit_data.vendor import Vendor\n", + "from eitprocessing.eit_data.draeger import DraegerEITData\n", + "from eitprocessing.eit_data.timpel import TimpelEITData\n", + "from eitprocessing.eit_data.eit_data_variant import EITDataVariant\n", + "from typing_extensions import Self\n", + "from eitprocessing.eit_data import EITData\n", + "from eitprocessing.mixins.equality import EquivalenceError\n", + "\n", + "import os\n", + "import pytest\n", + "from pprint import pprint\n", + "import bisect\n", + "import numpy as np\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([72758.105, 72758.155, 72758.205, ..., 73357.955, 73358.005,\n", + " 73358.055])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = DraegerEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test2.bin')\n", + "data2 = DraegerEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin')\n", + "timpel_data = TimpelEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Timpel_Test.txt')\n", + "\n", + "# pprint(data)\n", + "data.time\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "should be True\n", + "True\n", + "True\n", + "True\n", + "True\n", + "\n", + "should be False\n", + "False\n", + "False\n", + "False\n", + "False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dbodor/git/EIT-ALIVE/eitprocessing/eitprocessing/mixins/slicing.py:52: UserWarning: No starting or end timepoint was selected.\n", + " warnings.warn(\"No starting or end timepoint was selected.\")\n" + ] + } + ], + "source": [ + "print ('should be True')\n", + "print(data == data)\n", + "print(data[10] == data[10])\n", + "print(data[:10] == data[0:10])\n", + "print(data[:] == data)\n", + "\n", + "\n", + "print('\\nshould be False')\n", + "print(data == data2)\n", + "print(data[:10] == data[10])\n", + "print(data[:10] == data[2:10])\n", + "print(data[:10] == data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "should all be True\n", + "True\n", + "True\n", + "True\n", + "\n", + "should all be False\n", + "False\n", + "False\n", + "False\n", + "False\n", + "\n", + "error was correctly raised\n" + ] + } + ], + "source": [ + "print('should all be True')\n", + "print(data.isequivalent(data))\n", + "print(data.isequivalent(data2))\n", + "print(DraegerEITData.isequivalent(data, data2))\n", + "\n", + "print('\\nshould all be False')\n", + "print(data.isequivalent(timpel_data, False))\n", + "print(timpel_data.isequivalent(data))\n", + "print(EITData.isequivalent(timpel_data, data))\n", + "print(DraegerEITData.isequivalent(timpel_data, data))\n", + "\n", + "try:\n", + " _ = DraegerEITData.isequivalent(timpel_data, data, True)\n", + " print('\\nno error was raised, but it should have!')\n", + "except EquivalenceError:\n", + " print('\\nerror was correctly raised')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "should print False/True/False and then catch error\n", + "\n", + "False\n", + "True\n", + "False\n", + "\n", + "error was correctly raised\n" + ] + } + ], + "source": [ + "data_new = DraegerEITData.from_path('/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin')\n", + "\n", + "print('should print False/True/False and then catch error\\n')\n", + "\n", + "print(data_new == data)\n", + "print(data_new.isequivalent(data))\n", + "\n", + "data_new.framerate = 25\n", + "print(data_new.isequivalent(data))\n", + "\n", + "try:\n", + " _ = data_new.isequivalent(data, raise_ = True)\n", + " print('\\nno error was raised, but it should have!')\n", + "except EquivalenceError:\n", + " print('\\nerror was correctly raised')\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "alive", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/eitprocessing/sequence/__init__.py b/eitprocessing/sequence/__init__.py index 2a0cdd968..faa21189b 100644 --- a/eitprocessing/sequence/__init__.py +++ b/eitprocessing/sequence/__init__.py @@ -1,23 +1,17 @@ -""" -Copyright 2023 Netherlands eScience Center and Erasmus University Medical Center. -Licensed under the Apache License, version 2.0. See LICENSE for details. - -This file contains methods related to parts of electrical impedance tomographs -as they are read. -""" from __future__ import annotations import bisect -import contextlib import copy import warnings from dataclasses import dataclass import numpy as np -from ..eit_data import EITData -from ..helper import NotEquivalent +from typing_extensions import Self +from eitprocessing.eit_data import EITData +from eitprocessing.mixins.equality import Equivalence +from eitprocessing.mixins.equality import EquivalenceError @dataclass(eq=False) -class Sequence: +class Sequence(Equivalence): """Sequence of timepoints containing EIT and/or waveform data. A Sequence is a representation of a continuous set of data points, either EIT frames, @@ -44,40 +38,19 @@ def __post_init__(self): if self.label is None: self.label = f"Sequence_{id(self)}" - def __eq__(self, other) -> bool: - if not self.check_equivalence(self, other, raise_=False): - return False - - # TODO: check equality of object and all attached objects - - return True - - @staticmethod - def check_equivalence(a: Sequence, b: Sequence, raise_=False): - """Checks whether content of two Sequence objects is equivalent. - - It the two objects are equivalent, the method returns `True`. - - If the two objects are not equivalent and `raise_` is `True`, a - NotEquivalent exception is raised. If `raise_` is `False`, the - method returns `False` instead. - - Raises: - - NotEquivalent: when the objects are not equivalent and `raise_` is `True` - - """ - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.eit_data or b.eit_data: - if not a.eit_data or not b.eit_data: - raise NotEquivalent("Only one of the sequences contains EIT data") - - EITData.check_equivalence(a.eit_data, b.eit_data, raise_=raise_) - # TODO: add other attached objects for equivalence - - return True - - return False + def isequivalent( + self, + other: Self, + raise_: bool = False, + ) -> bool: + # fmt: off + checks = { + "Only one of the sequences contains EIT data.": bool(self.eit_data) is bool(other.eit_data), # both True or both False + "EITData is not equivalent.": EITData.isequivalent(self.eit_data, other.eit_data, raise_), + # TODO: add other attached objects for equivalence + } + # fmt: on + return super()._isequivalent(other, raise_, checks) def __add__(self, other: Sequence) -> Sequence: return self.concatenate(self, other) @@ -92,8 +65,8 @@ def concatenate( """Create a merge of two Sequence objects.""" # TODO: rewrite try: - Sequence.check_equivalence(a, b, raise_=True) - except NotEquivalent as e: + Sequence.isequivalent(a, b, raise_=True) + except EquivalenceError as e: raise type(e)(f"Sequences could not be merged: {e}") from e if a.eit_data and b.eit_data: diff --git a/eitprocessing/variants/__init__.py b/eitprocessing/variants/__init__.py index 3e0b2f122..36e1be356 100644 --- a/eitprocessing/variants/__init__.py +++ b/eitprocessing/variants/__init__.py @@ -1,14 +1,10 @@ -import contextlib from abc import ABC from abc import abstractmethod from dataclasses import dataclass from dataclasses import field -from typing import TypeVar from typing_extensions import Self -from ..helper import NotEquivalent - - -T = TypeVar("T", bound="Variant") +from eitprocessing.mixins.equality import Equivalence +from eitprocessing.mixins.equality import EquivalenceError @dataclass(eq=False) @@ -34,43 +30,20 @@ class Variant(Equivalence, ABC): description: str params: dict = field(default_factory=dict) - @staticmethod - def check_equivalence(a: T, b: T, raise_=False) -> bool: - """Check the equivalence of two variants - - For two variants to be equivalent, they need to have the same class, - the same label, the same description and the same parameters. Only the - actual data can differ between variants. - - Args: - - a (Variant) - - b (Variant) - - Raises: - - NotEquivalent (only if raise_ is `True`) when a and b are not - equivalent on one of the attributes - """ - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if not isinstance(a, b.__class__): - raise NotEquivalent( - f"Variant classes don't match: {a.__class__}, {b.__class__}" - ) - - if (a_ := a.label) != (b_ := b.label): - raise NotEquivalent(f"EITDataVariant names don't match: {a_}, {b_}") - - if (a_ := a.description) != (b_ := b.description): - raise NotEquivalent( - f"EITDataVariant descriptions don't match: {a_}, {b_}" - ) - - if (a_ := a.params) != (b_ := b.params): - raise NotEquivalent(f"EITDataVariant params don't match: {a_}, {b_}") - - return True - - return False + def isequivalent( + self, + other: Self, + raise_=True, + ) -> bool: + EDV = "EITDataVariant" + # fmt: off + checks = { + f"{EDV} labels don't match: {self.label}, {other.label}": self.label == other.label, + f"{EDV} descriptions don't match: {self.description}, {other.description}": self.description == other.description, + f"{EDV} params don't match: {self.params}, {other.params}": self.params == other.params, + } + # fmt: on + return super()._isequivalent(other, raise_, checks) @classmethod @abstractmethod @@ -88,5 +61,5 @@ def concatenate(cls, a: Self, b: Self) -> Self: - b (Variant) Raises: - - NotEquivalent if a and b are not equivalent and can't be merged + - EquivalenceError if a and b are not equivalent and can't be merged """ diff --git a/eitprocessing/variants/variant_collection.py b/eitprocessing/variants/variant_collection.py index e2a478c14..5f401441f 100644 --- a/eitprocessing/variants/variant_collection.py +++ b/eitprocessing/variants/variant_collection.py @@ -1,15 +1,15 @@ -import contextlib from typing import Generic from typing import TypeVar from typing_extensions import Self -from ..helper import NotEquivalent -from . import Variant +from eitprocessing.mixins.equality import Equivalence +from eitprocessing.mixins.equality import EquivalenceError +from eitprocessing.variants import Variant V = TypeVar("V", bound="Variant") -class VariantCollection(dict, Generic[V]): +class VariantCollection(dict, Equivalence, Generic[V]): """A collection of variants of a single type A VariantCollection is a dictionary with some added features. @@ -51,6 +51,9 @@ class VariantCollection(dict, Generic[V]): variant_type: type[V] + def __eq__(self, other): + return Equivalence.__eq__(self, other) + def __init__(self, variant_type: type[V], *args, **kwargs): self.variant_type = variant_type super().__init__(*args, **kwargs) @@ -96,8 +99,8 @@ def _check_variant(self, variant: V, key=None, overwrite: bool = False) -> None: @classmethod def concatenate(cls, a: Self, b: Self) -> Self: try: - cls.check_equivalence(a, b, raise_=True) - except NotEquivalent as e: + cls.isequivalent(a, b, raise_=True) + except EquivalenceError as e: raise ValueError("VariantCollections could not be concatenated") from e obj = VariantCollection(a.variant_type) @@ -106,26 +109,21 @@ def concatenate(cls, a: Self, b: Self) -> Self: return obj - @classmethod - def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool: - cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent) - with cm: - if a.variant_type != b.variant_type: - raise NotEquivalent( - f"Variant types do not match: {a.variant_type}, {b.variant_type}" - ) - - if set(a.keys()) != set(b.keys()): - raise NotEquivalent( - f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}" - ) - - for key in a.keys(): - Variant.check_equivalence(a[key], b[key], raise_=True) - - return True - - return False + def isequivalent( + self, + other: Self, + raise_=True, + ) -> bool: + # fmt: off + checks = { + f"Variant types don't match: {self.variant_type}, {other.variant_type}": self.variant_type == other.variant_type, + f"VariantCollections do not contain the same variants: {self.keys()=}, {other.keys()=}": set(self.keys()) == set(other.keys()), + } + for key in self.keys(): + checks[f"Variant data ({key}) is not equivalent: {self[key]}, {other[key]}"] = \ + Variant.isequivalent(self[key], other[key], raise_) + # fmt: on + return super()._isequivalent(other, raise_, checks) class InvalidVariantType(TypeError): diff --git a/tests/mixins/test_eq.py b/tests/mixins/test_eq.py new file mode 100644 index 000000000..62e159c6e --- /dev/null +++ b/tests/mixins/test_eq.py @@ -0,0 +1,23 @@ +import bisect +import os +from dataclasses import dataclass +from dataclasses import is_dataclass +from pprint import pprint +import numpy as np +import pytest +from typing_extensions import Self +from eitprocessing.eit_data.draeger import DraegerEITData +from eitprocessing.eit_data.eit_data_variant import EITDataVariant +from eitprocessing.eit_data.vendor import Vendor +from eitprocessing.mixins.slicing import SelectByIndex + + +def test_eq(): + data = DraegerEITData.from_path( + "/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin" + ) + data2 = DraegerEITData.from_path( + "/home/dbodor/git/EIT-ALIVE/eitprocessing/tests/test_data/Draeger_Test3.bin" + ) + + data.isequivalent(data2)