diff --git a/pyobs_cloudcover/pipeline/intervall.py b/pyobs_cloudcover/pipeline/intervall.py new file mode 100644 index 0000000..b61dc4a --- /dev/null +++ b/pyobs_cloudcover/pipeline/intervall.py @@ -0,0 +1,48 @@ +from __future__ import annotations +from typing import Optional + + +class Interval(object): + def __init__(self, start: Optional[float] = None, end: Optional[float] = None): + self._start = start + self._end = end + + def __contains__(self, value: float) -> bool: + in_interval = True + + if self._start is not None: + in_interval &= self._start < value + + if self._end is not None: + in_interval &= value < self._end + + return in_interval + + def does_intersect(self, other: Interval) -> bool: + if (other._start is None and other._end is None) or (self._start is None and self._end is None): + return True + + if self == other: + return True + + does_intersect = False + + if other._start is not None: + does_intersect |= other._start in self + + if other._end is not None: + does_intersect |= other._end in self + + if self._start is not None: + does_intersect |= self._start in other + + if self._end is not None: + does_intersect |= self._end in other + + return does_intersect + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Interval): + return False + + return self._start == other._start and self._end == other._end diff --git a/pyobs_cloudcover/pipeline/night/pipeline.py b/pyobs_cloudcover/pipeline/night/pipeline.py index 97b9bd2..c70b1f2 100644 --- a/pyobs_cloudcover/pipeline/night/pipeline.py +++ b/pyobs_cloudcover/pipeline/night/pipeline.py @@ -25,11 +25,11 @@ def __init__(self, self._cloud_map_generator = cloud_map_generator self._coverage_info_calculator = coverage_info_calculator - def __call__(self, image: npt.NDArray[np.float_], obstime: datetime.datetime) -> CloudCoverageInfo: + def __call__(self, image: npt.NDArray[np.float_], obs_time: datetime.datetime) -> CloudCoverageInfo: preprocessed_image = self._preprocess(image) img_height, img_width = preprocessed_image.shape - catalog = self._catalog_constructor(obstime, img_height, img_width) + catalog = self._catalog_constructor(obs_time, img_height, img_width) matches = self._star_reverse_matcher(preprocessed_image, catalog) cloud_map = self._cloud_map_generator(catalog, matches, img_height, img_width) diff --git a/pyobs_cloudcover/pipeline/pipeline.py b/pyobs_cloudcover/pipeline/pipeline.py new file mode 100644 index 0000000..bd09e44 --- /dev/null +++ b/pyobs_cloudcover/pipeline/pipeline.py @@ -0,0 +1,13 @@ +import abc +import datetime + +import numpy as np +import numpy.typing as npt + +from pyobs_cloudcover.cloud_coverage_info import CloudCoverageInfo + + +class Pipeline(object, metaclass=abc.ABCMeta): + @abc.abstractmethod + def __call__(self, image: npt.NDArray[np.float_], obs_time: datetime.datetime) -> CloudCoverageInfo: + ... diff --git a/pyobs_cloudcover/pipeline/pipeline_controller.py b/pyobs_cloudcover/pipeline/pipeline_controller.py new file mode 100644 index 0000000..f02f17e --- /dev/null +++ b/pyobs_cloudcover/pipeline/pipeline_controller.py @@ -0,0 +1,44 @@ +import datetime +from typing import List, Optional + +import numpy as np +import numpy.typing as npt +from astroplan import Observer + +from pyobs_cloudcover.cloud_coverage_info import CloudCoverageInfo +from pyobs_cloudcover.pipeline.intervall import Interval +from pyobs_cloudcover.pipeline.pipeline import Pipeline + + +class PipelineController(object): + def __init__(self, pipelines: List[Pipeline], sun_alt_intervals: List[Interval], observer: Observer) -> None: + self._pipelines = pipelines + self._sun_alt_intervals = sun_alt_intervals + self._observer = observer + + self._check_arg_length() + self._check_interval_overlap() + + def _check_arg_length(self) -> None: + if len(self._pipelines) != len(self._sun_alt_intervals): + raise ValueError("Number of pipelines must equal the intervals") + + def _check_interval_overlap(self) -> None: + overlap = [ + other.does_intersect(interval) + for interval in self._sun_alt_intervals + for other in self._sun_alt_intervals + if other is not interval + ] + + if True in overlap: + raise ValueError("Sun altitude intervals can't overlap!") + + def __call__(self, image: npt.NDArray[np.float_], obs_time: datetime.datetime) -> Optional[CloudCoverageInfo]: + sun_alt = self._observer.sun_altaz(obs_time).alt.deg + + for pipeline, alt_interval in zip(self._pipelines, self._sun_alt_intervals): + if sun_alt in alt_interval: + return pipeline(image, obs_time) + + return None diff --git a/tests/unit/pipeline/test_interval.py b/tests/unit/pipeline/test_interval.py new file mode 100644 index 0000000..25188d9 --- /dev/null +++ b/tests/unit/pipeline/test_interval.py @@ -0,0 +1,74 @@ +from pyobs_cloudcover.pipeline.intervall import Interval + + +def test_upper_lower_bound(): + interval = Interval(start=0, end=10) + assert (5 in interval) == True + assert (-1 in interval) == False + + +def test_upper_bound(): + interval = Interval(start=None, end=10) + assert (-1 in interval) == True + assert (11 in interval) == False + + +def test_lower_bound(): + interval = Interval(start=0, end=None) + assert (-1 in interval) == False + assert (11 in interval) == True + + +def test_no_bound(): + interval = Interval(start=None, end=None) + assert (-1 in interval) == True + assert (11 in interval) == True + + +def test_intersect_none(): + first = Interval(start=0, end=10) + second = Interval(start=None, end=None) + + assert first.does_intersect(second) == True + + +def test_intersect_same(): + first = Interval(start=0, end=10) + second = Interval(start=0, end=10) + + assert first.does_intersect(second) == True + + +def test_intersect_superset(): + first = Interval(start=-10, end=20) + second = Interval(start=0, end=10) + + assert first.does_intersect(second) == True + + +def test_intersect_half_open(): + first = Interval(start=None, end=5) + second = Interval(start=0, end=None) + + assert first.does_intersect(second) == True + + +def test_equal_different(): + first = Interval(start=None, end=5) + second = Interval(start=0) + + assert (first == second) == False + + +def test_equal_same(): + first = Interval(start=None) + second = Interval(start=None) + + assert (first == second) == True + + +def test_equal_invalid(): + first = Interval(start=None) + second = object() + + assert (first == second) == False diff --git a/tests/unit/pipeline/test_pipeline_controller.py b/tests/unit/pipeline/test_pipeline_controller.py new file mode 100644 index 0000000..2b68ed5 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_controller.py @@ -0,0 +1,60 @@ +import datetime + +import astropy.units as u +import numpy as np +import pytest +from astroplan import Observer +from astropy.coordinates import SkyCoord +from numpy import typing as npt + +from pyobs_cloudcover.cloud_coverage_info import CloudCoverageInfo +from pyobs_cloudcover.pipeline.intervall import Interval +from pyobs_cloudcover.pipeline.pipeline import Pipeline +from pyobs_cloudcover.pipeline.pipeline_controller import PipelineController + + +class MockPipeline(Pipeline): + + def __call__(self, image: npt.NDArray[np.float_], obs_time: datetime.datetime) -> CloudCoverageInfo: + return CloudCoverageInfo(np.array([]), 0, 0, 0) + + +@pytest.fixture() +def observer(): + return Observer(latitude=51.559299 * u.deg, longitude=9.945472 * u.deg, elevation=201 * u.m) + + +def test_invalid_init_args_list(observer): + pipelines = [MockPipeline()] + interval = [] + + with pytest.raises(ValueError): + PipelineController(pipelines, interval, observer) + + +def test_invalid_init_overlapping_intervals(observer): + pipelines = [MockPipeline(), MockPipeline()] + interval = [Interval(None, 10), Interval(0, None)] + + with pytest.raises(ValueError): + PipelineController(pipelines, interval, observer) + + +def test_pipeline_call_outside_interval(mocker, observer): + mocker.patch.object(observer, "sun_altaz", return_value=SkyCoord(alt=10, az=0, frame="altaz", unit="deg")) + + pipelines = [MockPipeline()] + interval = [Interval(0, 1)] + controller = PipelineController(pipelines, interval, observer) + + assert controller(np.array([]), datetime.datetime.now()) is None + + +def test_pipeline_call_inside_interval(mocker, observer): + mocker.patch.object(observer, "sun_altaz", return_value=SkyCoord(alt=5, az=0, frame="altaz", unit="deg")) + + pipelines = [MockPipeline()] + interval = [Interval(0, 10)] + controller = PipelineController(pipelines, interval, observer) + + assert isinstance(controller(np.array([]), datetime.datetime.now()), CloudCoverageInfo)