Skip to content

Commit

Permalink
use equivalence check in child classes
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Dec 13, 2023
1 parent 68cd951 commit 7ba4d40
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 194 deletions.
51 changes: 22 additions & 29 deletions eitprocessing/continuous_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import contextlib
from dataclasses import dataclass
from dataclasses import field
from typing import Any
import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from ..helper import NotEquivalent
from ..variants.variant_collection import VariantCollection
from .continuous_data_variant import ContinuousDataVariant
from eitprocessing.continuous_data.continuous_data_variant import ContinuousDataVariant
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.equality import EquivalenceError
from eitprocessing.variants.variant_collection import VariantCollection


@dataclass
class ContinuousData:
@dataclass(eq=False)
class ContinuousData(Equivalence):
name: str
unit: str
description: str
Expand All @@ -30,7 +30,7 @@ def __post_init__(self):

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

calculcated_from = None if a.loaded else [a.calculated_from, b.calculated_from]

Expand All @@ -47,28 +47,21 @@ def concatenate(cls, a: Self, b: Self) -> Self:
)
return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_: bool = False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.name != b.name:
raise NotEquivalent(f"Names do not match: {a.name}, {b.name}")
if a.unit != b.unit:
raise NotEquivalent(f"Units do not match: {a.unit}, {b.unit}")
if a.description != b.description:
raise NotEquivalent(
f"Descriptions do not match: {a.description}, {b.description}"
)
if a.loaded != b.loaded:
raise NotEquivalent(
f"Only one of the datasets is loaded: {a.loaded=}, {b.loaded=}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False
def isequivalent(
self,
other: Self,
raise_: bool = False,
) -> bool:
# fmt: off
checks = {
f"Names don't match: {self.name}, {other.name}.": self.name == other.name,
f"Units don't match: {self.unit}, {other.unit}.": self.unit == other.unit,
f"Descriptions don't match: {self.description}, {other.description}.": self.description == other.description,
f"Only one of the datasets is loaded: {self.loaded=}, {other.loaded=}.": self.loaded == other.loaded,
f"VariantCollections are not equivalent: {self.variants}, {other.variants}.": VariantCollection.isequivalent(self.variants,other.variants, raise_),
}
# fmt: on
return super()._isequivalent(other, raise_, checks)


class DataSourceUnknown(Exception):
Expand Down
40 changes: 19 additions & 21 deletions eitprocessing/continuous_data/continuous_data_collection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import contextlib
from typing import Any
from typing_extensions import Self
from ..helper import NotEquivalent
from . import ContinuousData
from eitprocessing.continuous_data import ContinuousData
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.equality import EquivalenceError


class ContinuousDataCollection(dict):
class ContinuousDataCollection(dict, Equivalence):
def __setitem__(self, __key: Any, __value: Any) -> None:
self._check_data(__value, key=__key)
return super().__setitem__(__key, __value)
Expand All @@ -31,8 +31,8 @@ def _check_data(
@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
try:
cls.check_equivalence(a, b, raise_=True)
except NotEquivalent as e:
cls.isequivalent(a, b, raise_=True)
except EquivalenceError as e:
raise ValueError("VariantCollections could not be concatenated") from e

obj = ContinuousDataCollection()
Expand All @@ -41,21 +41,19 @@ def concatenate(cls, a: Self, b: Self) -> Self:

return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if set(a.keys()) != set(b.keys()):
raise NotEquivalent(
f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}"
)

for key in a.keys():
ContinuousData.check_equivalence(a[key], b[key], raise_=True)

return True

return False
def isequivalent(
self,
other: Self,
raise_: bool = False,
) -> bool:
# fmt: off
checks = {
f"VariantCollections do not contain the same variants: {self.keys()=}, {other.keys()=}": set(self.keys()) == set(other.keys()),
}
for key in self.keys():
checks[f"Continuous data ({key}) is not equivalent: {self[key]}, {other[key]}"] = ContinuousData.isequivalent(self[key], other[key], raise_)
# fmt: on
return super()._isequivalent(other, raise_, checks)


class DuplicateContinuousDataName(Exception):
Expand Down
2 changes: 1 addition & 1 deletion eitprocessing/continuous_data/continuous_data_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..variants import Variant


@dataclass
@dataclass(eq=False)
class ContinuousDataVariant(Variant):
values: NDArray = field(kw_only=True)

Expand Down
33 changes: 13 additions & 20 deletions eitprocessing/eit_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
import contextlib
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
Expand All @@ -17,7 +16,6 @@
)
from eitprocessing.eit_data.eit_data_variant import EITDataVariant
from eitprocessing.eit_data.vendor import Vendor
from eitprocessing.helper import NotEquivalent
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection
Expand Down Expand Up @@ -183,7 +181,7 @@ def _ensure_vendor(vendor: Vendor | str) -> Vendor:

@classmethod
def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

a_path = cls._ensure_path_list(a.path)
b_path = cls._ensure_path_list(b.path)
Expand All @@ -209,23 +207,18 @@ def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
variants=variants,
)

@classmethod
def check_equivalence(cls, a: T, b: T, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.__class__ != b.__class__:
raise NotEquivalent(f"Classes don't match: {type(a)}, {type(b)}")

if a.framerate != b.framerate:
raise NotEquivalent(
f"Framerates do not match: {a.framerate}, {b.framerate}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False
def isequivalent(
self,
other: Self,
raise_: bool = False,
) -> bool:
# fmt: off
checks = {
f"Framerates don't match: {self.framerate}, {other.framerate}": self.framerate == other.framerate,
"VariantCollections are not equivalent": VariantCollection.isequivalent(self.variants, other.variants, raise_),
}
# fmt: on
return super()._isequivalent(other, raise_, checks)

def _sliced_copy(
self,
Expand Down
2 changes: 1 addition & 1 deletion eitprocessing/eit_data/eit_data_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def global_impedance(self):

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

return cls(
label=a.label,
Expand Down
2 changes: 1 addition & 1 deletion eitprocessing/eit_data/timpel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from . import EITData_


@dataclass
@dataclass(eq=False)
class TimpelEITData(EITData_):
framerate: float = 50
vendor: Vendor = field(default=Vendor.TIMPEL, init=False)
Expand Down
4 changes: 0 additions & 4 deletions eitprocessing/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
class NotEquivalent(Exception):
"""Raised when two objects are not equivalent."""


class NotConsecutive(Exception):
"""Raised when trying to concatenate non-consecutive datasets."""
Loading

0 comments on commit 7ba4d40

Please sign in to comment.