Skip to content

Commit

Permalink
use generic equivalence method
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Feb 5, 2024
1 parent c4a2671 commit a443a81
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 191 deletions.
32 changes: 5 additions & 27 deletions eitprocessing/continuous_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import contextlib
from dataclasses import dataclass
from dataclasses import field
from dataclasses import dataclass, field
from typing import Any

import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
Expand All @@ -16,6 +15,7 @@ class ContinuousData(Equivalence):
name: str
unit: str
description: str
category: str
time: NDArray
loaded: bool
calculated_from: Any | list[Any] | None = None
Expand All @@ -28,10 +28,11 @@ def __post_init__(self):
raise DataSourceUnknown(
"Data must be loaded or calculated form another dataset."
)
self._check_equivalence = ["unit", "category"]

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

calculcated_from = None if a.loaded else [a.calculated_from, b.calculated_from]

Expand All @@ -48,29 +49,6 @@ def concatenate(cls, a: Self, b: Self) -> Self:
)
return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_: bool = False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.name != b.name:
raise NotEquivalent(f"Names do not match: {a.name}, {b.name}")
if a.unit != b.unit:
raise NotEquivalent(f"Units do not match: {a.unit}, {b.unit}")
if a.description != b.description:
raise NotEquivalent(
f"Descriptions do not match: {a.description}, {b.description}"
)
if a.loaded != b.loaded:
raise NotEquivalent(
f"Only one of the datasets is loaded: {a.loaded=}, {b.loaded=}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False


class DataSourceUnknown(Exception):
"""Raised when the source of data is unknown."""
33 changes: 10 additions & 23 deletions eitprocessing/continuous_data/continuous_data_collection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import contextlib
from typing import Any

from typing_extensions import Self
from ..helper import NotEquivalent
from . import ContinuousData

from eitprocessing.continuous_data import ContinuousData
from eitprocessing.mixins.equality import Equivalence, EquivalenceError


class ContinuousDataCollection(dict):
class ContinuousDataCollection(dict, Equivalence):
def __setitem__(self, __key: Any, __value: Any) -> None:
self._check_data(__value, key=__key)
return super().__setitem__(__key, __value)
Expand All @@ -31,32 +32,18 @@ def _check_data(
@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
try:
cls.check_equivalence(a, b, raise_=True)
except NotEquivalent as e:
raise ValueError("VariantCollections could not be concatenated") from e
cls.isequivalent(a, b, raise_=True)
except EquivalenceError as e:
raise EquivalenceError(
"ContinuousDataCollections could not be concatenated"
) from e

obj = ContinuousDataCollection()
for key in a.keys() & b.keys():
obj.add(ContinuousData.concatenate(a[key], b[key]))

return obj

@classmethod
def check_equivalence(cls, a: Self, b: Self, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if set(a.keys()) != set(b.keys()):
raise NotEquivalent(
f"VariantCollections do not contain the same variants: {a.keys()=}, {b.keys()=}"
)

for key in a.keys():
ContinuousData.check_equivalence(a[key], b[key], raise_=True)

return True

return False


class DuplicateContinuousDataName(Exception):
"""Raised when a variant with the same name already exists in the collection."""
39 changes: 9 additions & 30 deletions eitprocessing/eit_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
from __future__ import annotations
import contextlib
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import reduce
from pathlib import Path
from typing import TypeAlias
from typing import TypeVar
from typing import TypeAlias, TypeVar

import numpy as np
from numpy.typing import NDArray
from typing_extensions import Self
from typing_extensions import override
from typing_extensions import Self, override

from eitprocessing.continuous_data.continuous_data_collection import (
ContinuousDataCollection,
)
from eitprocessing.eit_data.eit_data_variant import EITDataVariant
from eitprocessing.eit_data.vendor import Vendor
from eitprocessing.helper import NotEquivalent
from eitprocessing.mixins.equality import Equivalence
from eitprocessing.mixins.slicing import SelectByTime
from eitprocessing.sparse_data.sparse_data_collection import SparseDataCollection
from eitprocessing.variants.variant_collection import VariantCollection


PathLike: TypeAlias = str | Path
PathArg: TypeAlias = PathLike | list[PathLike]
T = TypeVar("T", bound="EITData")
Expand All @@ -46,6 +42,7 @@ class EITData(SelectByTime, Equivalence, ABC):
def __post_init__(self):
if not self.label:
self.label = f"{self.__class__.__name__}_{id(self)}"
self._check_equivalence = ["vendor", "framerate"]

@classmethod
def from_path( # pylint: disable=too-many-arguments,too-many-locals
Expand Down Expand Up @@ -183,7 +180,7 @@ def _ensure_vendor(vendor: Vendor | str) -> Vendor:

@classmethod
def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

a_path = cls._ensure_path_list(a.path)
b_path = cls._ensure_path_list(b.path)
Expand All @@ -209,24 +206,6 @@ def concatenate(cls, a: T, b: T, label: str | None = None) -> T:
variants=variants,
)

@classmethod
def check_equivalence(cls, a: T, b: T, raise_=False) -> bool:
cm = contextlib.nullcontext() if raise_ else contextlib.suppress(NotEquivalent)
with cm:
if a.__class__ != b.__class__:
raise NotEquivalent(f"Classes don't match: {type(a)}, {type(b)}")

if a.framerate != b.framerate:
raise NotEquivalent(
f"Framerates do not match: {a.framerate}, {b.framerate}"
)

VariantCollection.check_equivalence(a.variants, b.variants, raise_=True)

return True

return False

def _sliced_copy(
self,
start_index: int,
Expand Down
2 changes: 1 addition & 1 deletion eitprocessing/eit_data/eit_data_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def global_impedance(self):

@classmethod
def concatenate(cls, a: Self, b: Self) -> Self:
cls.check_equivalence(a, b, raise_=True)
cls.isequivalent(a, b, raise_=True)

return cls(
label=a.label,
Expand Down
4 changes: 0 additions & 4 deletions eitprocessing/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
class NotEquivalent(Exception):
"""Raised when two objects are not equivalent."""


class NotConsecutive(Exception):
"""Raised when trying to concatenate non-consecutive datasets."""
Loading

0 comments on commit a443a81

Please sign in to comment.