From ac7b666266cf2d54eabdba927be89f3dfe03b0d9 Mon Sep 17 00:00:00 2001 From: devsjc <47188100+devsjc@users.noreply.github.com> Date: Fri, 22 Nov 2024 14:01:13 +0000 Subject: [PATCH] style(service): Refactor tests --- .../internal/services/_dummy_adaptors.py | 89 +++++++++++++++++++ .../internal/services/test_archiver.py | 42 ++++++++- .../internal/services/test_consumer.py | 82 +---------------- 3 files changed, 131 insertions(+), 82 deletions(-) create mode 100644 src/nwp_consumer/internal/services/_dummy_adaptors.py diff --git a/src/nwp_consumer/internal/services/_dummy_adaptors.py b/src/nwp_consumer/internal/services/_dummy_adaptors.py new file mode 100644 index 00000000..f63ab097 --- /dev/null +++ b/src/nwp_consumer/internal/services/_dummy_adaptors.py @@ -0,0 +1,89 @@ + +from collections.abc import Callable, Iterator +from typing import override + +import datatime as dt +import numpy as np +import xarray as xr +from joblib import delayed +from returns.result import ResultE, Success + +from nwp_consumer.internal import entities, ports + + +class DummyModelRepository(ports.ModelRepository): + + @classmethod + @override + def authenticate(cls) -> ResultE["DummyModelRepository"]: + return Success(cls()) + + @staticmethod + @override + def repository() -> entities.ModelRepositoryMetadata: + return entities.ModelRepositoryMetadata( + name="ACME-Test-Models", + is_archive=False, + is_order_based=False, + running_hours=[0, 6, 12, 18], + delay_minutes=60, + max_connections=4, + required_env=[], + optional_env={}, + postprocess_options=entities.PostProcessOptions(), + ) + + @staticmethod + @override + def model() -> entities.ModelMetadata: + return entities.ModelMetadata( + name="simple-random", + resolution="17km", + expected_coordinates=entities.NWPDimensionCoordinateMap( + init_time=[dt.datetime(2021, 1, 1, 0, 0, tzinfo=dt.UTC)], + step=list(range(0, 48, 1)), + variable=[ + entities.Parameter.TEMPERATURE_SL, + entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL, + entities.Parameter.CLOUD_COVER_HIGH, + ], + latitude=np.linspace(90, -90, 721).tolist(), + longitude=np.linspace(-180, 179.8, 1440).tolist(), + ), + ) + + + @override + def fetch_init_data(self, it: dt.datetime) \ + -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]: + + def gen_dataset(step: int, variable: str) -> ResultE[list[xr.DataArray]]: + """Define a generator that provides one variable at one step.""" + da = xr.DataArray( + name=self.model().name, + dims=["init_time", "step", "variable", "latitude", "longitude"], + data=np.random.rand(1, 1, 1, 721, 1440), + coords=self.model().expected_coordinates.to_pandas() | { + "init_time": [np.datetime64(it.replace(tzinfo=None), "ns")], + "step": [step], + "variable": [variable], + }, + ) + return Success([da]) + + + for s in self.model().expected_coordinates.step: + for v in self.model().expected_coordinates.variable: + yield delayed(gen_dataset)(s, v.value) + + +class DummyNotificationRepository(ports.NotificationRepository): + + @override + def notify( + self, + message: entities.StoreAppendedNotification | entities.StoreCreatedNotification, + ) -> ResultE[str]: + return Success(str(message)) + + diff --git a/src/nwp_consumer/internal/services/test_archiver.py b/src/nwp_consumer/internal/services/test_archiver.py index dc79346c..0bf4cbe3 100644 --- a/src/nwp_consumer/internal/services/test_archiver.py +++ b/src/nwp_consumer/internal/services/test_archiver.py @@ -1 +1,41 @@ -# TODO: tests for archiver +import datetime as dt +import shutil +import unittest + +import xarray as xr +from returns.pipeline import is_successful + +from nwp_consumer.internal.services.archiver_service import ArchiverService + +from ._dummy_adaptors import DummyModelRepository, DummyNotificationRepository + + +class TestParallelConsumer(unittest.TestCase): + + @unittest.skip("Takes an age to run, need to figure out a better way.") + def test_archive(self) -> None: + """Test the consume method of the ParallelConsumer class.""" + + test_consumer = ArchiverService( + model_repository=DummyModelRepository, + notification_repository=DummyNotificationRepository, + ) + + result = test_consumer.archive(year=2021, month=1) + + self.assertTrue(is_successful(result), msg=result) + + da: xr.DataArray = xr.open_dataarray(result.unwrap(), engine="zarr") + + self.assertEqual( + list(da.sizes.keys()), + ["init_time", "step", "variable", "latitude", "longitude"], + ) + + path = result.unwrap() + shutil.rmtree(path) + + +if __name__ == "__main__": + unittest.main() + diff --git a/src/nwp_consumer/internal/services/test_consumer.py b/src/nwp_consumer/internal/services/test_consumer.py index 8936835d..85baea11 100644 --- a/src/nwp_consumer/internal/services/test_consumer.py +++ b/src/nwp_consumer/internal/services/test_consumer.py @@ -1,93 +1,13 @@ import datetime as dt import shutil import unittest -from collections.abc import Callable, Iterator -from typing import override -import numpy as np import xarray as xr -from joblib import delayed from returns.pipeline import is_successful -from returns.result import ResultE, Success -from nwp_consumer.internal import entities, ports from nwp_consumer.internal.services.consumer_service import ConsumerService - -class DummyModelRepository(ports.ModelRepository): - - @classmethod - @override - def authenticate(cls) -> ResultE["DummyModelRepository"]: - return Success(cls()) - - @staticmethod - @override - def repository() -> entities.ModelRepositoryMetadata: - return entities.ModelRepositoryMetadata( - name="ACME-Test-Models", - is_archive=False, - is_order_based=False, - running_hours=[0, 6, 12, 18], - delay_minutes=60, - max_connections=4, - required_env=[], - optional_env={}, - postprocess_options=entities.PostProcessOptions(), - ) - - @staticmethod - @override - def model() -> entities.ModelMetadata: - return entities.ModelMetadata( - name="simple-random", - resolution="17km", - expected_coordinates=entities.NWPDimensionCoordinateMap( - init_time=[dt.datetime(2021, 1, 1, 0, 0, tzinfo=dt.UTC)], - step=list(range(0, 48, 1)), - variable=[ - entities.Parameter.TEMPERATURE_SL, - entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL, - entities.Parameter.CLOUD_COVER_HIGH, - ], - latitude=np.linspace(90, -90, 721).tolist(), - longitude=np.linspace(-180, 179.8, 1440).tolist(), - ), - ) - - - @override - def fetch_init_data(self, it: dt.datetime) \ - -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]: - - def gen_dataset(step: int, variable: str) -> ResultE[list[xr.DataArray]]: - """Define a generator that provides one variable at one step.""" - da = xr.DataArray( - name=self.model().name, - dims=["init_time", "step", "variable", "latitude", "longitude"], - data=np.random.rand(1, 1, 1, 721, 1440), - coords=self.model().expected_coordinates.to_pandas() | { - "init_time": [np.datetime64(it.replace(tzinfo=None), "ns")], - "step": [step], - "variable": [variable], - }, - ) - return Success([da]) - - - for s in self.model().expected_coordinates.step: - for v in self.model().expected_coordinates.variable: - yield delayed(gen_dataset)(s, v.value) - - -class DummyNotificationRepository(ports.NotificationRepository): - - @override - def notify( - self, - message: entities.StoreAppendedNotification | entities.StoreCreatedNotification, - ) -> ResultE[str]: - return Success(str(message)) +from ._dummy_adaptors import DummyModelRepository, DummyNotificationRepository class TestParallelConsumer(unittest.TestCase):