-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1a6534e
commit 2793b79
Showing
6 changed files
with
241 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from __future__ import annotations | ||
from typing import Optional | ||
|
||
|
||
class Interval(object): | ||
def __init__(self, start: Optional[float] = None, end: Optional[float] = None): | ||
self._start = start | ||
self._end = end | ||
|
||
def __contains__(self, value: float) -> bool: | ||
in_interval = True | ||
|
||
if self._start is not None: | ||
in_interval &= self._start < value | ||
|
||
if self._end is not None: | ||
in_interval &= value < self._end | ||
|
||
return in_interval | ||
|
||
def does_intersect(self, other: Interval) -> bool: | ||
if (other._start is None and other._end is None) or (self._start is None and self._end is None): | ||
return True | ||
|
||
if self == other: | ||
return True | ||
|
||
does_intersect = False | ||
|
||
if other._start is not None: | ||
does_intersect |= other._start in self | ||
|
||
if other._end is not None: | ||
does_intersect |= other._end in self | ||
|
||
if self._start is not None: | ||
does_intersect |= self._start in other | ||
|
||
if self._end is not None: | ||
does_intersect |= self._end in other | ||
|
||
return does_intersect | ||
|
||
def __eq__(self, other: object) -> bool: | ||
if not isinstance(other, Interval): | ||
return False | ||
|
||
return self._start == other._start and self._end == other._end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import abc | ||
import datetime | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from pyobs_cloudcover.cloud_coverage_info import CloudCoverageInfo | ||
|
||
|
||
class Pipeline(object, metaclass=abc.ABCMeta): | ||
@abc.abstractmethod | ||
def __call__(self, image: npt.NDArray[np.float_], obs_time: datetime.datetime) -> CloudCoverageInfo: | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import datetime | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
from astroplan import Observer | ||
|
||
from pyobs_cloudcover.cloud_coverage_info import CloudCoverageInfo | ||
from pyobs_cloudcover.pipeline.intervall import Interval | ||
from pyobs_cloudcover.pipeline.pipeline import Pipeline | ||
|
||
|
||
class PipelineController(object): | ||
def __init__(self, pipelines: List[Pipeline], sun_alt_intervals: List[Interval], observer: Observer) -> None: | ||
self._pipelines = pipelines | ||
self._sun_alt_intervals = sun_alt_intervals | ||
self._observer = observer | ||
|
||
self._check_arg_length() | ||
self._check_interval_overlap() | ||
|
||
def _check_arg_length(self) -> None: | ||
if len(self._pipelines) != len(self._sun_alt_intervals): | ||
raise ValueError("Number of pipelines must equal the intervals") | ||
|
||
def _check_interval_overlap(self) -> None: | ||
overlap = [ | ||
other.does_intersect(interval) | ||
for interval in self._sun_alt_intervals | ||
for other in self._sun_alt_intervals | ||
if other is not interval | ||
] | ||
|
||
if True in overlap: | ||
raise ValueError("Sun altitude intervals can't overlap!") | ||
|
||
def __call__(self, image: npt.NDArray[np.float_], obs_time: datetime.datetime) -> Optional[CloudCoverageInfo]: | ||
sun_alt = self._observer.sun_altaz(obs_time).alt.deg | ||
|
||
for pipeline, alt_interval in zip(self._pipelines, self._sun_alt_intervals): | ||
if sun_alt in alt_interval: | ||
return pipeline(image, obs_time) | ||
|
||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from pyobs_cloudcover.pipeline.intervall import Interval | ||
|
||
|
||
def test_upper_lower_bound(): | ||
interval = Interval(start=0, end=10) | ||
assert (5 in interval) == True | ||
assert (-1 in interval) == False | ||
|
||
|
||
def test_upper_bound(): | ||
interval = Interval(start=None, end=10) | ||
assert (-1 in interval) == True | ||
assert (11 in interval) == False | ||
|
||
|
||
def test_lower_bound(): | ||
interval = Interval(start=0, end=None) | ||
assert (-1 in interval) == False | ||
assert (11 in interval) == True | ||
|
||
|
||
def test_no_bound(): | ||
interval = Interval(start=None, end=None) | ||
assert (-1 in interval) == True | ||
assert (11 in interval) == True | ||
|
||
|
||
def test_intersect_none(): | ||
first = Interval(start=0, end=10) | ||
second = Interval(start=None, end=None) | ||
|
||
assert first.does_intersect(second) == True | ||
|
||
|
||
def test_intersect_same(): | ||
first = Interval(start=0, end=10) | ||
second = Interval(start=0, end=10) | ||
|
||
assert first.does_intersect(second) == True | ||
|
||
|
||
def test_intersect_superset(): | ||
first = Interval(start=-10, end=20) | ||
second = Interval(start=0, end=10) | ||
|
||
assert first.does_intersect(second) == True | ||
|
||
|
||
def test_intersect_half_open(): | ||
first = Interval(start=None, end=5) | ||
second = Interval(start=0, end=None) | ||
|
||
assert first.does_intersect(second) == True | ||
|
||
|
||
def test_equal_different(): | ||
first = Interval(start=None, end=5) | ||
second = Interval(start=0) | ||
|
||
assert (first == second) == False | ||
|
||
|
||
def test_equal_same(): | ||
first = Interval(start=None) | ||
second = Interval(start=None) | ||
|
||
assert (first == second) == True | ||
|
||
|
||
def test_equal_invalid(): | ||
first = Interval(start=None) | ||
second = object() | ||
|
||
assert (first == second) == False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import datetime | ||
|
||
import astropy.units as u | ||
import numpy as np | ||
import pytest | ||
from astroplan import Observer | ||
from astropy.coordinates import SkyCoord | ||
from numpy import typing as npt | ||
|
||
from pyobs_cloudcover.cloud_coverage_info import CloudCoverageInfo | ||
from pyobs_cloudcover.pipeline.intervall import Interval | ||
from pyobs_cloudcover.pipeline.pipeline import Pipeline | ||
from pyobs_cloudcover.pipeline.pipeline_controller import PipelineController | ||
|
||
|
||
class MockPipeline(Pipeline): | ||
|
||
def __call__(self, image: npt.NDArray[np.float_], obs_time: datetime.datetime) -> CloudCoverageInfo: | ||
return CloudCoverageInfo(np.array([]), 0, 0, 0) | ||
|
||
|
||
@pytest.fixture() | ||
def observer(): | ||
return Observer(latitude=51.559299 * u.deg, longitude=9.945472 * u.deg, elevation=201 * u.m) | ||
|
||
|
||
def test_invalid_init_args_list(observer): | ||
pipelines = [MockPipeline()] | ||
interval = [] | ||
|
||
with pytest.raises(ValueError): | ||
PipelineController(pipelines, interval, observer) | ||
|
||
|
||
def test_invalid_init_overlapping_intervals(observer): | ||
pipelines = [MockPipeline(), MockPipeline()] | ||
interval = [Interval(None, 10), Interval(0, None)] | ||
|
||
with pytest.raises(ValueError): | ||
PipelineController(pipelines, interval, observer) | ||
|
||
|
||
def test_pipeline_call_outside_interval(mocker, observer): | ||
mocker.patch.object(observer, "sun_altaz", return_value=SkyCoord(alt=10, az=0, frame="altaz", unit="deg")) | ||
|
||
pipelines = [MockPipeline()] | ||
interval = [Interval(0, 1)] | ||
controller = PipelineController(pipelines, interval, observer) | ||
|
||
assert controller(np.array([]), datetime.datetime.now()) is None | ||
|
||
|
||
def test_pipeline_call_inside_interval(mocker, observer): | ||
mocker.patch.object(observer, "sun_altaz", return_value=SkyCoord(alt=5, az=0, frame="altaz", unit="deg")) | ||
|
||
pipelines = [MockPipeline()] | ||
interval = [Interval(0, 10)] | ||
controller = PipelineController(pipelines, interval, observer) | ||
|
||
assert isinstance(controller(np.array([]), datetime.datetime.now()), CloudCoverageInfo) |