Skip to content

Commit

Permalink
use parent isequivalent method in child classes
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Dec 7, 2023
1 parent b16547d commit df95c03
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 192 deletions.
49 changes: 21 additions & 28 deletions eitprocessing/continuous_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import contextlib
from dataclasses import dataclass
from dataclasses import field
from typing import Any
import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from ..helper import NotEquivalent
from ..variants.variant_collection import VariantCollection
from .continuous_data_variant import ContinuousDataVariant
from eitprocessing.continuous_data.continuous_data_variant import ContinuousDataVariant
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.equality import EquivalenceError
from eitprocessing.variants.variant_collection import VariantCollection


@dataclass
class ContinuousData:
class ContinuousData(Equivalence):
name: str
unit: str
description: str
Expand All @@ -30,7 +30,7 @@ def __post_init__(self):

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

calculcated_from = None if a.loaded else [a.calculated_from, b.calculated_from]

Expand All @@ -47,28 +47,21 @@ def concatenate(cls, a: Self, b: Self) -> Self:
)
return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_: bool = False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.name != b.name:
raise NotEquivalent(f"Names do not match: {a.name}, {b.name}")
if a.unit != b.unit:
raise NotEquivalent(f"Units do not match: {a.unit}, {b.unit}")
if a.description != b.description:
raise NotEquivalent(
f"Descriptions do not match: {a.description}, {b.description}"
)
if a.loaded != b.loaded:
raise NotEquivalent(
f"Only one of the datasets is loaded: {a.loaded=}, {b.loaded=}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False
def isequivalent(
self,
other: Self,
raise_: bool = False,
) -> bool:
# fmt: off
checks = {
f"Names don't match: {self.name}, {other.name}.": self.name == other.name,
f"Units don't match: {self.unit}, {other.unit}.": self.unit == other.unit,
f"Descriptions don't match: {self.description}, {other.description}.": self.description == other.description,
f"Only one of the datasets is loaded: {self.loaded=}, {other.loaded=}.": self.loaded == other.loaded,
f"VariantCollections are not equivalent: {self.variants}, {other.variants}.": VariantCollection.isequivalent(self.variants,other.variants),
}
# fmt: on
return super().isequivalent(other, raise_, checks)


class DataSourceUnknown(Exception):
Expand Down
41 changes: 20 additions & 21 deletions eitprocessing/continuous_data/continuous_data_collection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import contextlib
from typing import Any
from typing_extensions import Self
from ..helper import NotEquivalent
from . import ContinuousData
from eitprocessing.continuous_data import ContinuousData
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.equality import EquivalenceError


class ContinuousDataCollection(dict):
class ContinuousDataCollection(dict, Equivalence):
def __setitem__(self, __key: Any, __value: Any) -> None:
self._check_data(__value, key=__key)
return super().__setitem__(__key, __value)
Expand All @@ -31,8 +31,8 @@ def _check_data(
@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
try:
cls.check_equivalence(a, b, raise_=True)
except NotEquivalent as e:
cls.isequivalent(a, b, raise_=True)
except EquivalenceError as e:
raise ValueError("VariantCollections could not be concatenated") from e

obj = ContinuousDataCollection()
Expand All @@ -41,21 +41,20 @@ def concatenate(cls, a: Self, b: Self) -> Self:

return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if set(a.keys()) != set(b.keys()):
raise NotEquivalent(
f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}"
)

for key in a.keys():
ContinuousData.check_equivalence(a[key], b[key], raise_=True)

return True

return False
def isequivalent(
self,
other: Self,
raise_: bool = False,
) -> bool:
# fmt: off
checks = {
f"VariantCollections do not contain the same variants: {self.keys()=}, {other.keys()=}": set(self.keys()) == set(other.keys()),
}
for key in self.keys():
checks[f"Continuous data ({key}) is not equivalent: {self[key]}, {other[key]}"] = \
ContinuousData.isequivalent(self[key], other[key])
# fmt: on
return super().isequivalent(other, raise_, checks)


class DuplicateContinuousDataName(Exception):
Expand Down
34 changes: 14 additions & 20 deletions eitprocessing/eit_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
import contextlib
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
Expand All @@ -17,8 +16,8 @@
)
from eitprocessing.eit_data.eit_data_variant import EITDataVariant
from eitprocessing.eit_data.vendor import Vendor
from eitprocessing.helper import NotEquivalent
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.equality import EquivalenceError
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection
from eitprocessing.variants.variant_collection import VariantCollection
Expand Down Expand Up @@ -183,7 +182,7 @@ def _ensure_vendor(vendor: Vendor | str) -> Vendor:

@classmethod
def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

a_path = cls._ensure_path_list(a.path)
b_path = cls._ensure_path_list(b.path)
Expand All @@ -209,23 +208,18 @@ def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
variants=variants,
)

@classmethod
def check_equivalence(cls, a: T, b: T, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.__class__ != b.__class__:
raise NotEquivalent(f"Classes don't match: {type(a)}, {type(b)}")

if a.framerate != b.framerate:
raise NotEquivalent(
f"Framerates do not match: {a.framerate}, {b.framerate}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False
def isequivalent(
self,
other: Self,
raise_: bool = False,
) -> bool:
# fmt: off
checks = {
f"Framerates don't match: {self.framerate}, {other.framerate}": self.framerate == other.framerate,
"VariantCollections are not equivalent": VariantCollection.isequivalent(self.variants, other.variants),
}
# fmt: on
return super().isequivalent(other, raise_, checks)

def _sliced_copy(
self,
Expand Down
2 changes: 1 addition & 1 deletion eitprocessing/eit_data/eit_data_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def global_impedance(self):

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

return cls(
label=a.label,
Expand Down
4 changes: 0 additions & 4 deletions eitprocessing/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
class NotEquivalent(Exception):
"""Raised when two objects are not equivalent."""


class NotConsecutive(Exception):
"""Raised when trying to concatenate non-consecutive datasets."""
Loading

0 comments on commit df95c03

Please sign in to comment.