diff --git a/pyproject.toml b/pyproject.toml index 539c0710..866156d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "numpy == 2.1.0", "ocf-blosc2 == 0.0.11", "psutil == 6.0.0", - "returns == 0.23.0", + "returns == 0.24.0", "s3fs == 2024.9.0", "xarray == 2024.9.0", "zarr == 2.18.3" diff --git a/src/nwp_consumer/cmd/main.py b/src/nwp_consumer/cmd/main.py index 4e6c166c..43e3d27c 100644 --- a/src/nwp_consumer/cmd/main.py +++ b/src/nwp_consumer/cmd/main.py @@ -5,7 +5,7 @@ import sys from typing import NamedTuple -from nwp_consumer.internal import handlers, ports, repositories, services +from nwp_consumer.internal import handlers, ports, repositories log = logging.getLogger("nwp-consumer") @@ -61,17 +61,10 @@ def parse_env() -> Adaptors: def run_cli() -> None: """Entrypoint for the CLI handler.""" - # TODO: InfoUseCase adaptors = parse_env() c = handlers.CLIHandler( - consumer_usecase=services.ConsumerService( - model_repository=adaptors.model_repository, - notification_repository=adaptors.notification_repository, - ), - archiver_usecase=services.ArchiverService( - model_repository=adaptors.model_repository, - notification_repository=adaptors.notification_repository, - ), + model_adaptor=adaptors.model_repository, + notification_adaptor=adaptors.notification_repository, ) returncode: int = c.run() sys.exit(returncode) diff --git a/src/nwp_consumer/internal/entities/performance.py b/src/nwp_consumer/internal/entities/performance.py index 4b36fe27..c241850f 100644 --- a/src/nwp_consumer/internal/entities/performance.py +++ b/src/nwp_consumer/internal/entities/performance.py @@ -23,7 +23,7 @@ class PerformanceMonitor(Thread): memory_buffer: list[int] cpu_buffer: list[float] start_time: float - end_time: float + end_time: float | None stop: bool = True def __enter__(self) -> None: @@ -60,6 +60,8 @@ def get_usage(self) -> tuple[int, float]: def get_runtime(self) -> int: """Get the runtime of the thread in seconds.""" + if self.end_time is None: + return int(time.time() - self.start_time) return int(self.end_time - self.start_time) def run(self) -> None: diff --git a/src/nwp_consumer/internal/handlers/cli.py b/src/nwp_consumer/internal/handlers/cli.py index e69956fa..aeec6654 100644 --- a/src/nwp_consumer/internal/handlers/cli.py +++ b/src/nwp_consumer/internal/handlers/cli.py @@ -4,9 +4,9 @@ import datetime as dt import logging -from returns.result import Failure, Success +from returns.result import Failure, ResultE -from nwp_consumer.internal import ports +from nwp_consumer.internal import ports, services log = logging.getLogger("nwp-consumer") @@ -14,15 +14,17 @@ class CLIHandler: """CLI driving actor.""" + model_adaptor: type[ports.ModelRepository] + notification_adaptor: type[ports.NotificationRepository] + def __init__( self, - consumer_usecase: ports.ConsumeUseCase, - archiver_usecase: ports.ArchiveUseCase, - ) -> None: + model_adaptor: type[ports.ModelRepository], + notification_adaptor: type[ports.NotificationRepository], + ) -> None: """Create a new instance.""" - self._consumer_usecase = consumer_usecase - self._archiver_usecase = archiver_usecase - + self.model_adaptor = model_adaptor + self.notification_adaptor = notification_adaptor @property def parser(self) -> argparse.ArgumentParser: @@ -82,30 +84,35 @@ def run(self) -> int: args = self.parser.parse_args() match args.command: case "consume": - result = self._consumer_usecase.consume(it=args.init_time) - - match result: - case Failure(e): - log.error(f"Failed to consume NWP data: {e}") - return 1 
-                    case Success(path):
-                        log.info(f"Successfully consumed NWP data to '{path}'")
-                        return 0
+                service_result = services.ConsumerService.from_adaptors(
+                    model_adaptor=self.model_adaptor,
+                    notification_adaptor=self.notification_adaptor,
+                )
+                result: ResultE[str] = service_result.do(
+                    consume_result
+                    for service in service_result
+                    for consume_result in service.consume(period=args.init_time)
+                )
+                if isinstance(result, Failure):
+                    log.error(f"Failed to consume NWP data: {result!s}")
+                    return 1
             case "archive":
-                result = self._archiver_usecase.archive(year=args.year, month=args.month)
-
-                match result:
-                    case Failure(e):
-                        log.error(f"Failed to archive NWP data: {e}")
-                        return 1
-                    case Success(path):
-                        log.info(f"Successfully archived NWP data to '{path}'")
-                        return 0
+                service_result = services.ConsumerService.from_adaptors(
+                    model_adaptor=self.model_adaptor,
+                    notification_adaptor=self.notification_adaptor,
+                )
+                result = service_result.do(
+                    archive_result
+                    for service in service_result
+                    for archive_result in service.archive(period=dt.date(args.year, args.month, 1))
+                )
+                if isinstance(result, Failure):
+                    log.error(f"Failed to archive NWP data: {result!s}")
+                    return 1
             case "info":
                 log.error("Info command is coming soon! :)")
-                return 0
             case _:
                 log.error(f"Unknown command: {args.command}")
diff --git a/src/nwp_consumer/internal/ports/__init__.py b/src/nwp_consumer/internal/ports/__init__.py
index a51e2b33..1d7b8f41 100644
--- a/src/nwp_consumer/internal/ports/__init__.py
+++ b/src/nwp_consumer/internal/ports/__init__.py
@@ -7,13 +7,11 @@
 in the `repositories` module.
 """
 
-from .services import ConsumeUseCase, ArchiveUseCase
-from .repositories import ModelRepository, ZarrRepository, NotificationRepository
+from .services import ConsumeUseCase
+from .repositories import ModelRepository, NotificationRepository
 
 __all__ = [
     "ConsumeUseCase",
-    "ArchiveUseCase",
     "ModelRepository",
-    "ZarrRepository",
     "NotificationRepository",
 ]
diff --git a/src/nwp_consumer/internal/ports/repositories.py b/src/nwp_consumer/internal/ports/repositories.py
index 4ac0094b..cfa9e57c 100644
--- a/src/nwp_consumer/internal/ports/repositories.py
+++ b/src/nwp_consumer/internal/ports/repositories.py
@@ -14,7 +14,6 @@
 import abc
 import datetime as dt
 import logging
-import pathlib
 from collections.abc import Callable, Iterator
 
 import xarray as xr
@@ -124,16 +123,6 @@
     def model() -> entities.ModelMetadata:
         """Metadata about the model."""
         pass
 
-
-class ZarrRepository(abc.ABC):
-    """Interface for a repository that stores Zarr NWP data."""
-
-    @abc.abstractmethod
-    def save(self, src: pathlib.Path, dst: pathlib.Path) -> ResultE[str]:
-        """Save NWP store data in the repository."""
-        pass
-
-
 class NotificationRepository(abc.ABC):
     """Interface for a repository that sends notifications.
 
diff --git a/src/nwp_consumer/internal/ports/services.py b/src/nwp_consumer/internal/ports/services.py
index f19a1888..6bb88dc1 100644
--- a/src/nwp_consumer/internal/ports/services.py
+++ b/src/nwp_consumer/internal/ports/services.py
@@ -3,7 +3,7 @@
 These interfaces define the signatures that *driving* actors must conform to
 in order to interact with the core.
 
-Also sometimes referred to as *primary ports*.
+Sometimes referred to as *primary ports*.
 """
 
 import abc
@@ -11,8 +11,6 @@
 
 from returns.result import ResultE
 
-from nwp_consumer.internal import entities
-
 
 class ConsumeUseCase(abc.ABC):
     """Interface for the consumer use case.
@@ -24,16 +22,15 @@ class ConsumeUseCase(abc.ABC): @abc.abstractmethod - def consume(self, it: dt.datetime | None = None) -> ResultE[str]: - """Consume NWP data to Zarr format for desired init time. + def consume(self, period: dt.datetime | dt.date | None = None) -> ResultE[str]: + """Consume NWP data to Zarr format for desired time period. Where possible the implementation should be as memory-efficient as possible. The designs of the repository methods also enable parallel processing within the implementation. Args: - it: The initialization time for which to consume data. - If None, the latest available forecast should be consumed. + period: The period for which to gather init time data. Returns: The path to the produced Zarr store. @@ -46,51 +43,13 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[str]: pass @abc.abstractmethod - def postprocess(self, options: entities.PostProcessOptions) -> ResultE[str]: - """Postprocess the produced Zarr according to given options.""" - pass - - -class ArchiveUseCase(abc.ABC): - """Interface for the archive use case. - - Defines the business-critical methods for the following use cases: - - - 'A user should be able to archive NWP data for a given time period.' - """ - - @abc.abstractmethod - def archive(self, year: int, month: int) -> ResultE[str]: - """Archive NWP data to Zarr format for the given month. + def archive(self, period: dt.date) -> ResultE[str]: + """Archive NWP data to Zarr format for desired time period. Args: - year: The year for which to archive data. - month: The month for which to archive data. + period: The period for which to gather init time data. Returns: The path to the produced Zarr store. """ pass - -class InfoUseCase(abc.ABC): - """Interface for the notification use case. - - Defines the business-critical methods for the following use cases: - - - 'A user should be able to retrieve information about the service.' - """ - - @abc.abstractmethod - def available_models(self) -> list[str]: - """Get a list of available models.""" - pass - - @abc.abstractmethod - def model_repository_info(self) -> str: - """Get information about the model repository.""" - pass - - @abc.abstractmethod - def model_info(self) -> str: - """Get information about the model.""" - pass diff --git a/src/nwp_consumer/internal/repositories/model_repositories/noaa_s3.py b/src/nwp_consumer/internal/repositories/model_repositories/noaa_s3.py index 6c980df9..f48d4126 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/noaa_s3.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/noaa_s3.py @@ -121,26 +121,33 @@ def fetch_init_data( ) for url in urls: - yield delayed(self._download_and_convert)(url=url) + yield delayed(self._download_and_convert)(url=url, it=it) @classmethod @override def authenticate(cls) -> ResultE["NOAAS3ModelRepository"]: return Success(cls()) - def _download_and_convert(self, url: str) -> ResultE[list[xr.DataArray]]: + def _download_and_convert(self, url: str, it: dt.datetime) -> ResultE[list[xr.DataArray]]: """Download and convert a file from S3. Args: url: The URL to the S3 object. + it: The init time of the object in question, used in the saved path """ - return self._download(url).bind(self._convert) + return self._download(url=url, it=it).bind(self._convert) - def _download(self, url: str) -> ResultE[pathlib.Path]: - """Download an ECMWF realtime file from S3. + def _download(self, url: str, it: dt.datetime) -> ResultE[pathlib.Path]: + """Download a grib file from NOAA S3. 
+
+        The URLs have the following format::
+
+            https://noaa-gfs-bdp-pds.s3.amazonaws.com/gfs.20230911/06/atmos/gfs.t06z.pgrb2.1p00.f087
+            <------------------bucket---------------><---inittime---> <-------filename----step>
 
         Args:
             url: The URL to the S3 object.
+            it: The init time of the object in question, used in the saved path
         """
         local_path: pathlib.Path = (
             pathlib.Path(
@@ -148,54 +155,58 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
                     "RAWDIR",
                     f"~/.local/cache/nwp/{self.repository().name}/{self.model().name}/raw",
                 ),
-            ) / url.split("/")[-1]
-        ).with_suffix(".grib").expanduser()
+            ) / it.strftime("%Y/%m/%d/%H") / (url.split("/")[-1] + ".grib")
+        ).expanduser()
 
         # Only download the file if not already present
-        if not local_path.exists():
-            local_path.parent.mkdir(parents=True, exist_ok=True)
-            log.debug("Requesting file from S3 at: '%s'", url)
+        if local_path.exists():
+            return Success(local_path)
 
-            fs = s3fs.S3FileSystem(anon=True)
-            try:
-                if not fs.exists(url):
-                    raise FileNotFoundError(f"File not found at '{url}'")
+        local_path.parent.mkdir(parents=True, exist_ok=True)
+        log.debug("Requesting file from S3 at: '%s'", url)
 
-                with local_path.open("wb") as lf, fs.open(url, "rb") as rf:
-                    for chunk in iter(lambda: rf.read(12 * 1024), b""):
-                        lf.write(chunk)
-                        lf.flush()
+        fs = s3fs.S3FileSystem(anon=True)
+        try:
+            if not fs.exists(url):
+                raise FileNotFoundError(f"File not found at '{url}'")
 
-            except Exception as e:
-                return Failure(OSError(
-                    f"Failed to download file from S3 at '{url}'. Encountered error: {e}",
-                ))
+            with local_path.open("wb") as lf, fs.open(url, "rb") as rf:
+                for chunk in iter(lambda: rf.read(12 * 1024), b""):
+                    lf.write(chunk)
+                    lf.flush()
 
-            if local_path.stat().st_size != fs.info(url)["size"]:
-                return Failure(ValueError(
-                    f"File size mismatch from file at '{url}': "
-                    f"{local_path.stat().st_size} != {fs.info(url)['size']} (remote). "
-                    "File may be corrupted.",
-                ))
+        except Exception as e:
+            return Failure(OSError(
+                f"Failed to download file from S3 at '{url}'. Encountered error: {e}",
+            ))
+
+        # For some reason, the GFS files are about 2MB larger when downloaded
+        # than their listed size in AWS. I'd be interested to know why!
+        if local_path.stat().st_size < fs.info(url)["size"]:
+            return Failure(ValueError(
+                f"File size mismatch from file at '{url}': "
+                f"{local_path.stat().st_size} != {fs.info(url)['size']} (remote). "
+                "File may be corrupted.",
+            ))
 
-            # Also download the associated index file
-            # * This isn't critical, but speeds up reading the file in when converting
-            # TODO: Re-incorporate this when https://github.com/ecmwf/cfgrib/issues/350
-            # TODO: is resolved. Currently downloaded index files are ignored due to
-            # TODO: path differences once downloaded.
-            # index_url: str = url + ".idx"
-            # index_path: pathlib.Path = local_path.with_suffix(".grib.idx")
-            # try:
-            #     with index_path.open("wb") as lf, fs.open(index_url, "rb") as rf:
-            #         for chunk in iter(lambda: rf.read(12 * 1024), b""):
-            #             lf.write(chunk)
-            #             lf.flush()
-            # except Exception as e:
-            #     log.warning(
-            #         f"Failed to download index file from S3 at '{url}'. "
-            #         "This will require a manual indexing when converting the file. "
-            #         f"Encountered error: {e}",
-            #     )
+        # Also download the associated index file
+        # * This isn't critical, but speeds up reading the file in when converting
+        # TODO: Re-incorporate this when https://github.com/ecmwf/cfgrib/issues/350
+        # TODO: is resolved.
Currently downloaded index files are ignored due to + # TODO: path differences once downloaded. + # index_url: str = url + ".idx" + # index_path: pathlib.Path = local_path.with_suffix(".grib.idx") + # try: + # with index_path.open("wb") as lf, fs.open(index_url, "rb") as rf: + # for chunk in iter(lambda: rf.read(12 * 1024), b""): + # lf.write(chunk) + # lf.flush() + # except Exception as e: + # log.warning( + # f"Failed to download index file from S3 at '{url}'. " + # "This will require a manual indexing when converting the file. " + # f"Encountered error: {e}", + # ) return Success(local_path) @@ -219,6 +230,8 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: "ignore_keys": { "levelType": ["isobaricInhPa", "depthBelowLandLayer", "meanSea"], }, + "errors": "raise", + "indexpath": "", # TODO: Change when above TODO is resolved }, ) except Exception as e: diff --git a/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_s3.py b/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_s3.py index 944fc809..b639ea98 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_s3.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_s3.py @@ -49,7 +49,7 @@ def test__download_and_convert(self) -> None: for url in urls: with self.subTest(url=url): - result = c._download_and_convert(url) + result = c._download_and_convert(url=url, it=test_it) self.assertIsInstance(result, Success, msg=f"{result!s}") diff --git a/src/nwp_consumer/internal/services/__init__.py b/src/nwp_consumer/internal/services/__init__.py index 90d20dce..556d8099 100644 --- a/src/nwp_consumer/internal/services/__init__.py +++ b/src/nwp_consumer/internal/services/__init__.py @@ -5,9 +5,7 @@ """ from .consumer_service import ConsumerService -from .archiver_service import ArchiverService __all__ = [ "ConsumerService", - "ArchiverService" ] diff --git a/src/nwp_consumer/internal/services/archiver_service.py b/src/nwp_consumer/internal/services/archiver_service.py deleted file mode 100644 index 08b43e22..00000000 --- a/src/nwp_consumer/internal/services/archiver_service.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Implementation of the NWP consumer services.""" - -import dataclasses -import logging -import os -import pathlib -from typing import TYPE_CHECKING, override - -from joblib import Parallel, cpu_count -from returns.result import Failure, ResultE, Success - -from nwp_consumer.internal import entities, ports - -if TYPE_CHECKING: - import datetime as dt - -log = logging.getLogger("nwp-consumer") - - -class ArchiverService(ports.ArchiveUseCase): - """Service implementation of the consumer use case. - - This services contains the business logic required to enact - the consumer use case. It is responsible for consuming NWP data - and writing it to a Zarr store. 
- """ - - mr: type[ports.ModelRepository] - nr: type[ports.NotificationRepository] - - def __init__( - self, - model_repository: type[ports.ModelRepository], - notification_repository: type[ports.NotificationRepository], - ) -> None: - """Create a new instance.""" - self.mr = model_repository - self.nr = notification_repository - - @override - def archive(self, year: int, month: int) -> ResultE[str]: - monitor = entities.PerformanceMonitor() - with monitor: - - init_times = self.mr.repository().month_its(year=year, month=month) - - # Create a store for the archive - init_store_result: ResultE[entities.TensorStore] = \ - entities.TensorStore.initialize_empty_store( - model=self.mr.model().name, - repository=self.mr.repository().name, - coords=dataclasses.replace( - self.mr.model().expected_coordinates, - init_time=init_times, - ), - ) - - if isinstance(init_store_result, Failure): - return Failure(OSError( - f"Failed to initialize store for {year}-{month}: {init_store_result!s}"), - ) - store = init_store_result.unwrap() - - missing_times_result = store.missing_times() - if isinstance(missing_times_result, Failure): - return Failure(missing_times_result.failure()) - log.info(f"{len(missing_times_result.unwrap())} missing init_times in store.") - - failed_times: list[dt.datetime] = [] - for n, it in enumerate(missing_times_result.unwrap()): - log.info( - f"Consuming data from {self.mr.repository().name} for {it:%Y-%m-%d %H:%M} " - f"(time {n + 1}/{len(missing_times_result.unwrap())})", - ) - - # Authenticate with the model repository - amr_result = self.mr.authenticate() - if isinstance(amr_result, Failure): - store.delete_store() - return Failure(OSError( - "Unable to authenticate with model repository " - f"'{self.mr.repository().name}': " - f"{amr_result.failure()}", - )) - amr = amr_result.unwrap() - - # Create a generator to fetch and process raw data - n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections) - if os.getenv("CONCURRENCY", "True").capitalize() == "False": - n_jobs = 1 - log.debug(f"Downloading using {n_jobs} concurrent thread(s)") - fetch_result_generator = Parallel( - n_jobs=n_jobs, - prefer="threads", - return_as="generator_unordered", - )(amr.fetch_init_data(it=it)) - - # Regionally write the results of the generator as they are ready - for fetch_result in fetch_result_generator: - if isinstance(fetch_result, Failure): - log.error( - f"Error fetching data for init time '{it:%Y-%m-%d %H:%M}' " - f"and model {self.mr.model().name}: {fetch_result.failure()!s}", - ) - failed_times.append(it) - continue - for da in fetch_result.unwrap(): - write_result = store.write_to_region(da) - # Fail soft if a region fails to write - if isinstance(write_result, Failure): - log.error(f"Failed to write time {it:%Y-%m-%d %H:%M}: {write_result}") - failed_times.append(it) - - del fetch_result_generator - - # Add the failed times to the store's metadata - store.update_attrs({ - "failed_times": ", ".join([t.strftime("Day %d %H:%M") for t in failed_times]), - }) - - if len(failed_times) == len(missing_times_result.unwrap()): - store.delete_store() - return Failure(OSError( - "Failed to write any regions for all init times. 
" - "Check error logs for details.", - )) - - # Postprocess the dataset as required - # postprocess_result = store.postprocess(self._mr.metadata().postprocess_options) - # if isinstance(postprocess_result, Failure): - # return Failure(postprocess_result.failure()) - - notify_result = self.nr().notify( - message=entities.StoreCreatedNotification( - filename=pathlib.Path(store.path).name, - size_mb=store.size_kb // 1024, - performance=entities.PerformanceMetadata( - duration_seconds=monitor.get_runtime(), - memory_mb=monitor.max_memory_mb(), - ), - ), - ) - if isinstance(notify_result, Failure): - return Failure(OSError( - "Failed to notify of store creation: " - f"{notify_result.failure()}", - )) - - return Success(store.path) - diff --git a/src/nwp_consumer/internal/services/consumer_service.py b/src/nwp_consumer/internal/services/consumer_service.py index 11155a74..3673f3fa 100644 --- a/src/nwp_consumer/internal/services/consumer_service.py +++ b/src/nwp_consumer/internal/services/consumer_service.py @@ -1,149 +1,240 @@ -"""Implementation of the NWP consumer services.""" +"""Implementation of the NWP consumer service.""" import dataclasses import datetime as dt +import functools import logging import os import pathlib +from collections.abc import Callable, Iterator from typing import override +import xarray as xr from joblib import Parallel, cpu_count +from returns.methods import partition +from returns.pipeline import flow from returns.result import Failure, ResultE, Success from nwp_consumer.internal import entities, ports log = logging.getLogger("nwp-consumer") - class ConsumerService(ports.ConsumeUseCase): - """Service implementation of the consumer use case. + """Service implementation for the NWP Consumer. - This services contains the business logic required to enact - the consumer use case. It is responsible for consuming NWP data - and writing it to a Zarr store. + Defines the business-critical methods and logic. """ - mr: type[ports.ModelRepository] - nr: type[ports.NotificationRepository] + mr: ports.ModelRepository + nr: ports.NotificationRepository def __init__( - self, - model_repository: type[ports.ModelRepository], - notification_repository: type[ports.NotificationRepository], - ) -> None: - """Create a new instance.""" + self, + model_repository: ports.ModelRepository, + notification_repository: ports.NotificationRepository, + ) -> None: + """Create a new instance of the service.""" self.mr = model_repository self.nr = notification_repository + + @classmethod + def from_adaptors( + cls, + model_adaptor: type[ports.ModelRepository], + notification_adaptor: type[ports.NotificationRepository], + ) -> ResultE["ConsumerService"]: + """Create a new instance of the service from adaptors.""" + notification_repository = notification_adaptor() + model_repository_result = model_adaptor.authenticate() + return model_repository_result.do( + cls( + model_repository=model_repository, + notification_repository=notification_repository, + ) + for model_repository in model_repository_result + ) + + @staticmethod + def _fold_dataarrays_generator( + generator: Iterator[ResultE[list[xr.DataArray]]], + store: entities.TensorStore, + ) -> ResultE[int]: + """Process data from data generator. + + Args: + generator: A generator of ResultE objects containing either a list data arrays + or a Failure object. + store: The store to write the data to. + + Returns: + A ResultE object containing the sum of the write results or a Failure object. 
+ """ + results: list[ResultE[int]] = [] + for value in generator: + if isinstance(value, Failure): + results.extend([value]) + else: + results.extend([store.write_to_region(da=da) for da in value.unwrap()]) + successes, failures = partition(results) + # TODO: Define the failure threshold for number of write attempts properly + log.info(f"Processed {len(successes)} DataArrays successfully with {len(failures)} errors.") + if len(failures) > 0: + for i, exc in enumerate(failures): + if i < 5: + log.error(str(exc)) + else: + break + return Failure(OSError( + "Error threshold exceeded: " + f"{len(failures)} errors (>0) occurred during processing.", + )) + else: + return Success(sum(successes)) + + @staticmethod + def _parallelize_generator[T]( + delayed_generator: Iterator[Callable[..., T]], + max_connections: int, + ) -> Iterator[T]: + """Parallelize a generator of delayed functions. + + Args: + delayed_generator: An iterable of delayed items. + The creation of these items must be delayed, either via joblib.delayed + or functools.partial, so they can be executed lazily. + max_connections: The maximum number of connections to use. + """ + # TODO: Change this based on threads instead of CPU count + n_jobs: int = max(cpu_count() - 1, max_connections) + if os.getenv("CONCURRENCY", "True").capitalize() == "False": + n_jobs = 1 + log.debug(f"Using {n_jobs} concurrent thread(s)") + + return Parallel( # type: ignore + n_jobs=n_jobs, + prefer="threads", + verbose=0, + return_as="generator_unordered", + )(delayed_generator) + + @staticmethod + def _create_suitable_store( + repository_metadata: entities.ModelRepositoryMetadata, + model_metadata: entities.ModelMetadata, + period: dt.datetime | dt.date | None = None, + ) -> ResultE[entities.TensorStore]: + """Create a store for the data with the relevant init time coordinates. + + Args: + repository_metadata: The metadata for the repository. + model_metadata: The metadata for the model. + period: The period for which to gather init time data. + """ + its: list[dt.datetime] = [] + match period: + case _ if period is None: + its = [repository_metadata.determine_latest_it_from(dt.datetime.now(tz=dt.UTC))] + case single_it if isinstance(period, dt.datetime): + its = [single_it] # type: ignore + case multiple_its if isinstance(period, dt.date): + its = repository_metadata.month_its( + year=multiple_its.year, + month=multiple_its.month, + ) + + # Create a store for the data with the relevant init time coordinates + return entities.TensorStore.initialize_empty_store( + model=model_metadata.name, + repository=repository_metadata.name, + coords=dataclasses.replace( + model_metadata.expected_coordinates, + init_time=its, + ), + ) + @override - def consume(self, it: dt.datetime | None = None) -> ResultE[str]: - # Note that the usage of the returns here is not in the spirit of - # 'railway orientated programming', mostly due to to the number of - # generators involved - it seemed clearer to be explicit. However, - # it would be much neater to refactor this to be more functional. + def consume( + self, + period: dt.datetime | dt.date | None = None, + ) -> ResultE[str]: + """Consume NWP data to Zarr format for desired time period. + + Where possible the implementation should be as memory-efficient as possible. + The designs of the repository methods also enable parallel processing within + the implementation. + + Args: + period: The period for which to gather init time data. + + Returns: + The path to the produced Zarr store. 
+ + See Also: + - `repositories.ModelRepository.fetch_init_data` + - `tensorstore.TensorStore.write_to_region` + - https://joblib.readthedocs.io/en/stable/auto_examples/parallel_generator.html + """ monitor = entities.PerformanceMonitor() with monitor: - if it is None: - it = self.mr.repository().determine_latest_it_from(dt.datetime.now(tz=dt.UTC)) - log.info( - f"Consuming data from repository '{self.mr.repository().name}' " - f"for the '{self.mr.model().name}' model " - f"spanning init time '{it:%Y-%m-%d %H:%M}'", + init_store_result = self._create_suitable_store( + repository_metadata=self.mr.repository(), + model_metadata=self.mr.model(), + period=period, ) - it = it.replace(tzinfo=dt.UTC) - - # Create a store for the init time - init_store_result: ResultE[entities.TensorStore] = \ - entities.TensorStore.initialize_empty_store( - model=self.mr.model().name, - repository=self.mr.repository().name, - coords=dataclasses.replace( - self.mr.model().expected_coordinates, - init_time=[it], - ), - ) - if isinstance(init_store_result, Failure): return Failure(OSError( f"Failed to initialize store for init time: {init_store_result!s}", )) store = init_store_result.unwrap() - amr_result = self.mr.authenticate() - if isinstance(amr_result, Failure): - store.delete_store() - return Failure(OSError( - "Unable to authenticate with model repository " - f"'{self.mr.repository().name}': " - f"{amr_result.failure()}", - )) - amr = amr_result.unwrap() - - # Create a generator to fetch and process raw data - n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections) - if os.getenv("CONCURRENCY", "True").capitalize() == "False": - n_jobs = 1 - log.debug(f"Downloading using {n_jobs} concurrent thread(s)") - fetch_result_generator = Parallel( - n_jobs=n_jobs, - prefer="threads", - return_as="generator_unordered", - )(amr.fetch_init_data(it=it)) - - # Regionally write the results of the generator as they are ready - failed_etls: int = 0 - for fetch_result in fetch_result_generator: - if isinstance(fetch_result, Failure): - log.error( - f"Error fetching data for init time '{it:%Y-%m-%d %H:%M}' " - f"and model {self.mr.model().name}: {fetch_result.failure()!s}", - ) - failed_etls += 1 - continue - for da in fetch_result.unwrap(): - write_result = store.write_to_region(da) - if isinstance(write_result, Failure): - log.error( - f"Error writing data for init time '{it:%Y-%m-%d %H:%M}' " - f"and model {self.mr.model().name}: " - f"{write_result.failure()!s}", - ) - failed_etls += 1 - - del fetch_result_generator - # Fail hard if any of the writes failed - # * TODO: Consider just how hard we want to fail in this instance - if failed_etls > 0: - store.delete_store() - return Failure(OSError( - f"Failed to write {failed_etls} regions " - f"for init time '{it:%Y-%m-%d %H:%M}'. 
" - "See error logs for details.", - )) + missing_times_result = store.missing_times() + if isinstance(missing_times_result, Failure): + return missing_times_result - # Postprocess the dataset as required - # postprocess_result = store.postprocess(self.mr.repository().postprocess_options) - # if isinstance(postprocess_result, Failure): - # return Failure(postprocess_result.failure()) - - notify_result = self.nr().notify( - message=entities.StoreCreatedNotification( - filename=pathlib.Path(store.path).name, - size_mb=store.size_kb // 1024, # TODO: 2024-11-19 check this is right - performance=entities.PerformanceMetadata( - duration_seconds=monitor.get_runtime(), - memory_mb=monitor.max_memory_mb(), - ), + for n, it in enumerate(missing_times_result.unwrap()): + log.info( + f"Consuming data from {self.mr.repository().name} for {it:%Y-%m-%d %H:%M} " + f"(time {n + 1}/{len(missing_times_result.unwrap())})", + ) + process_result = flow( + self._parallelize_generator( + self.mr.fetch_init_data(it), + max_connections=self.mr.repository().max_connections, + ), + functools.partial(self._fold_dataarrays_generator, store=store), + ) + if isinstance(process_result, Failure): + return process_result + + notification_message = entities.StoreCreatedNotification( + filename=pathlib.Path(store.path).name, + size_mb=store.size_kb // 1024, + performance=entities.PerformanceMetadata( + duration_seconds=monitor.get_runtime(), + memory_mb=monitor.max_memory_mb(), ), ) + notify_result = self.nr.notify(message=notification_message) if isinstance(notify_result, Failure): - return Failure(OSError( + log.error( "Failed to notify of store creation: " - f"{notify_result.failure()}", - )) + f"{notify_result.failure()}. " + f"Notification: {notification_message}", + ) + log.info(f"Successfully processed data to '{store.path}'") return Success(store.path) @override - def postprocess(self, options: entities.PostProcessOptions) -> ResultE[str]: - return Failure(NotImplementedError("Postprocessing not yet implemented")) + def archive(self, period: dt.date) -> ResultE[str]: + return self.consume(period=period) + + @staticmethod + def info( + model_adaptor: type[ports.ModelRepository], + notification_adaptor: type[ports.NotificationRepository], + ) -> str: + """Get information about the service.""" + raise NotImplementedError("Not yet implemented") + diff --git a/src/nwp_consumer/internal/services/test_archiver.py b/src/nwp_consumer/internal/services/test_archiver.py deleted file mode 100644 index 3dd6cd86..00000000 --- a/src/nwp_consumer/internal/services/test_archiver.py +++ /dev/null @@ -1,40 +0,0 @@ -import shutil -import unittest - -import xarray as xr -from returns.pipeline import is_successful - -from nwp_consumer.internal.services.archiver_service import ArchiverService - -from ._dummy_adaptors import DummyModelRepository, DummyNotificationRepository - - -class TestParallelConsumer(unittest.TestCase): - - @unittest.skip("Takes an age to run, need to figure out a better way.") - def test_archive(self) -> None: - """Test the consume method of the ParallelConsumer class.""" - - test_consumer = ArchiverService( - model_repository=DummyModelRepository, - notification_repository=DummyNotificationRepository, - ) - - result = test_consumer.archive(year=2021, month=1) - - self.assertTrue(is_successful(result), msg=result) - - da: xr.DataArray = xr.open_dataarray(result.unwrap(), engine="zarr") - - self.assertEqual( - list(da.sizes.keys()), - ["init_time", "step", "variable", "latitude", "longitude"], - ) - - path = 
result.unwrap()
-        shutil.rmtree(path)
-
-
-if __name__ == "__main__":
-    unittest.main()
-
diff --git a/src/nwp_consumer/internal/services/test_consumer.py b/src/nwp_consumer/internal/services/test_consumer_service.py
similarity index 75%
rename from src/nwp_consumer/internal/services/test_consumer.py
rename to src/nwp_consumer/internal/services/test_consumer_service.py
index 85baea11..10524e2c 100644
--- a/src/nwp_consumer/internal/services/test_consumer.py
+++ b/src/nwp_consumer/internal/services/test_consumer_service.py
@@ -15,12 +15,12 @@ class TestParallelConsumer(unittest.TestCase):
 
     def test_consume(self) -> None:
         """Test the consume method of the ParallelConsumer class."""
 
-        test_consumer = ConsumerService(
-            model_repository=DummyModelRepository,
-            notification_repository=DummyNotificationRepository,
-        )
+        test_consumer = ConsumerService.from_adaptors(
+            model_adaptor=DummyModelRepository,
+            notification_adaptor=DummyNotificationRepository,
+        ).unwrap()
 
-        result = test_consumer.consume(it=dt.datetime(2021, 1, 1, tzinfo=dt.UTC))
+        result = test_consumer.consume(period=dt.datetime(2021, 1, 1, tzinfo=dt.UTC))
 
         self.assertTrue(is_successful(result), msg=result)
diff --git a/src/test_integration/test_integration.py b/src/test_integration/test_integration.py
index a60e2ec5..490ce76c 100644
--- a/src/test_integration/test_integration.py
+++ b/src/test_integration/test_integration.py
@@ -1,28 +1,31 @@
 import datetime as dt
 import unittest
+from typing import TYPE_CHECKING
 
 import xarray as xr
 from returns.pipeline import is_successful
 
-from nwp_consumer.internal import handlers, repositories, services
+from nwp_consumer.internal import repositories, services
+if TYPE_CHECKING:
+    from returns.result import ResultE
 
 
 class TestIntegration(unittest.TestCase):
     def test_ceda_metoffice_global_model(self) -> None:
-        c = handlers.CLIHandler(
-            consumer_usecase=services.ConsumerService(
-                model_repository=repositories.model_repositories.CEDAFTPModelRepository,
-                notification_repository=repositories.notification_repositories.StdoutNotificationRepository,
-            ),
-            archiver_usecase=services.ArchiverService(
-                model_repository=repositories.model_repositories.CEDAFTPModelRepository,
-                notification_repository=repositories.notification_repositories.StdoutNotificationRepository,
-            ),
+        test_it = dt.datetime(2021, 1, 1, tzinfo=dt.UTC)
+        service_result = services.ConsumerService.from_adaptors(
+            model_adaptor=repositories.model_repositories.CEDAFTPModelRepository,
+            notification_adaptor=repositories.notification_repositories.StdoutNotificationRepository,
+        )
+        result: ResultE[str] = service_result.do(
+            consume_result
+            for service in service_result
+            for consume_result in service.consume(period=test_it)
         )
-        result = c._consumer_usecase.consume(it=dt.datetime(2021, 1, 1, tzinfo=dt.UTC))
 
         self.assertTrue(is_successful(result), msg=f"{result}")
 
         da = xr.open_dataarray(result.unwrap(), engine="zarr")
         self.assertTrue(da.sizes["init_time"] > 0)
 
+
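Usage reference (not part of the patch above): a minimal sketch of how the refactored service is wired and driven, mirroring the new CLIHandler and test_integration code. The repository classes are the ones already used in the integration test; any other setup shown here is an assumption for illustration only.

import datetime as dt

from returns.result import ResultE

from nwp_consumer.internal import repositories, services

# Build the service from adaptor classes; authentication happens inside
# from_adaptors, and any failure is captured in the returned ResultE container.
service_result = services.ConsumerService.from_adaptors(
    model_adaptor=repositories.model_repositories.CEDAFTPModelRepository,
    notification_adaptor=repositories.notification_repositories.StdoutNotificationRepository,
)

# Railway-style composition via returns do-notation: consume only runs when
# authentication succeeded, and the result is a ResultE[str] holding either
# the Zarr store path or the first error encountered.
result: ResultE[str] = service_result.do(
    consume_result
    for service in service_result
    for consume_result in service.consume(period=dt.datetime(2021, 1, 1, tzinfo=dt.UTC))
)
print(result)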