From dbbcf57aef3c80ce42f17066588392ff955ba9e3 Mon Sep 17 00:00:00 2001 From: devsjc <47188100+devsjc@users.noreply.github.com> Date: Thu, 31 Oct 2024 15:17:31 +0000 Subject: [PATCH] feat(models): Add GFS Model repository --- .github/workflows/branch_ci.yml | 7 +- .github/workflows/tagged_ci.yml | 4 +- Containerfile | 112 +++++-- pyproject.toml | 27 +- src/nwp_consumer/__init__.py | 3 + src/nwp_consumer/cmd/main.py | 9 +- .../internal/entities/coordinates.py | 64 +++- .../internal/entities/parameters.py | 32 ++ .../internal/entities/postprocess.py | 25 -- .../internal/entities/tensorstore.py | 315 +++++++----------- .../internal/entities/test_parameters.py | 10 + .../internal/entities/test_tensorstore.py | 243 ++++++++++---- src/nwp_consumer/internal/handlers/cli.py | 4 +- .../internal/ports/repositories.py | 6 +- .../internal/repositories/__init__.py | 6 +- .../model_repositories/__init__.py | 6 +- .../model_repositories/ecmwf_realtime.py | 76 +++-- .../model_repositories/metoffice_global.py | 72 ++-- .../model_repositories/noaa_gfs.py | 302 +++++++++++++++++ .../test_metoffice_global.py | 8 +- .../model_repositories/test_noaa_gfs.py | 111 ++++++ .../internal/services/archiver_service.py | 18 +- .../internal/services/consumer_service.py | 17 +- src/test_integration/test_integration.py | 4 +- 24 files changed, 1025 insertions(+), 456 deletions(-) create mode 100644 src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py create mode 100644 src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py diff --git a/.github/workflows/branch_ci.yml b/.github/workflows/branch_ci.yml index ba18063e..726d2f7e 100644 --- a/.github/workflows/branch_ci.yml +++ b/.github/workflows/branch_ci.yml @@ -45,7 +45,7 @@ jobs: python-version-file: "pyproject.toml" - name: Install editable package and required dependencies - run: uv sync --extra=dev + run: uv sync - name: Lint package run: uv run ruff check --output-format=github . 
@@ -76,7 +76,7 @@ jobs: python-version-file: "pyproject.toml" - name: Install editable package and required dependencies - run: uv sync --extra=dev + run: uv sync # Run unittests # * Produce JUnit XML report @@ -113,7 +113,7 @@ jobs: python-version-file: "pyproject.toml" - name: Install editable package and required dependencies - run: uv sync --extra=dev + run: uv sync - name: Build documentation run: uv run pydoctor @@ -174,3 +174,4 @@ jobs: labels: ${{ steps.meta.outputs.labels }} platforms: linux/amd64,linux/arm64 cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache + diff --git a/.github/workflows/tagged_ci.yml b/.github/workflows/tagged_ci.yml index 42fcc688..3b9bd8dd 100644 --- a/.github/workflows/tagged_ci.yml +++ b/.github/workflows/tagged_ci.yml @@ -99,7 +99,7 @@ jobs: python-version-file: "pyproject.toml" - name: Install editable package and required dependencies - run: uv sync + run: uv sync --no-dev # Building the wheel dynamically assigns the version according to git # * The setuptools_git_versioning package reads the git tags and assigns the version @@ -118,4 +118,4 @@ jobs: uses: pypa/gh-action-pypi-publish@v1.10 with: user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/Containerfile b/Containerfile index 864e8f49..66a076d6 100644 --- a/Containerfile +++ b/Containerfile @@ -1,20 +1,85 @@ -# Build a virtualenv using miniconda -# * Conda creates a completely isolated environment, -# including all required shared libraries, enabling -# just putting the virtualenv into a distroless image -# without having to faff around with linking all -# the filelist (including for each dependency) of -# https://packages.debian.org/trixie/libpython3.12-dev, e.g. 
-# -# echo "Copying symlinked python binary into venv" && \ -# cp --remove-destination /usr/local/bin/python3.12 /venv/bin/python && \ -# echo "Copying libpython package into venv" && \ -# cp -r /usr/local/lib/* /venv/lib/ && \ -# cp -r /usr/local/include/python3.12/* /venv/include/ && \ -# mkdir -p /venv/lib/aarch64-linux-gnu/ && \ -# cp -r /usr/lib/aarch64-linux-gnu/* /venv/lib/aarch64-linux-gnu/ && \ -# mkdir -p /venv/include/aarch64-linux-gnu/ && \ -# cp -r /usr/include/aarch64-linux-gnu/* /venv/include/aarch64-linux-gnu/ && \ +# POTENTIAL FOR SMALLER CONTAINERFILE IF THIS CAN BE GOT WORKING + + +# # --- Base Python image ----------------------------------------------------------------- +# FROM python:3.12-bookworm AS python-base +# +# # --- Distroless Container creation ----------------------------------------------------- +# FROM gcr.io/distroless/cc-debian12 AS python-distroless +# +# ARG CHIPSET_ARCH=aarch64-linux-gnu +# +# # Copy the python installation from the base image +# COPY --from=python-base /usr/local/lib/ /usr/local/lib/ +# COPY --from=python-base /usr/local/bin/python /usr/local/bin/python +# COPY --from=python-base /etc/ld.so.cache /etc/ld.so.cache +# +# # Add common compiled libraries +# COPY --from=python-base /usr/lib/${CHIPSET_ARCH}/libz.so.1 /lib/${CHIPSET_ARCH}/ +# COPY --from=python-base /usr/lib/${CHIPSET_ARCH}/libffi.so.8 /lib/${CHIPSET_ARCH}/ +# COPY --from=python-base /usr/lib/${CHIPSET_ARCH}/libbz2.so.1.0 /lib/${CHIPSET_ARCH}/ +# COPY --from=python-base /usr/lib/${CHIPSET_ARCH}/libm.so.6 /lib/${CHIPSET_ARCH}/ +# COPY --from=python-base /usr/lib/${CHIPSET_ARCH}/libc.so.6 /lib/${CHIPSET_ARCH}/ +# +# # Don't generate .pyc, enable tracebacks +# ENV LANG=C.UTF-8 \ +# LC_ALL=C.UTF-8 \ +# PYTHONDONTWRITEBYTECODE=1 \ +# PYTHONFAULTHANDLER=1 +# +# # Check python installation works +# COPY --from=python-base /bin/rm /bin/rm +# COPY --from=python-base /bin/sh /bin/sh +# RUN python --version +# RUN rm /bin/sh /bin/rm +# +# # --- Virtualenv builder image ---------------------------------------------------------- +# FROM python-base AS build-venv +# COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv +# +# ENV UV_LINK_MODE=copy \ +# UV_COMPILE_BYTECODE=1 \ +# UV_PYTHON_DOWNLOADS=never \ +# UV_PYTHON=python3.12 \ +# UV_NO_CACHE=1 \ +# CFLAGS="-g0 -Wl,--strip-all" \ +# VENV=/.venv +# +# COPY pyproject.toml ./ +# +# # Synchronize DEPENDENCIES without the application itself. +# # This layer is cached until uv.lock or pyproject.toml change. +# # Delete any unwanted parts of the installed packages to reduce size +# RUN uv venv ${VENV} && \ +# echo "Installing dependencies into ${VENV}" && \ +# mkdir src && \ +# du -h ${VENV}/lib/python3.12/site-packages && \ +# uv sync --no-dev --no-install-project && \ +# echo "Copying libpython package into ${VENV}" && \ +# cp --remove-destination /usr/local/bin/python3.12 ${VENV}/bin/python && \ +# cp /usr/local/lib/libpython3.12.so.1.0 ${VENV}/lib/ && \ +# echo "Optimizing site-packages" && \ +# rm -r ${VENV}/lib/python3.12/site-packages/**/tests && \ +# du -h ${VENV}/lib/python3.12/site-packages | sort -h | tail -n 4 +# +# COPY . 
/src +# RUN uv pip install --no-deps /src && ls /.venv/bin +# +# # --- Distroless App image -------------------------------------------------------------- +# FROM python-distroless +# +# COPY --from=build-venv /.venv /venv +# +# ENV RAWDIR=/work/raw \ +# ZARRDIR=/work/data \ +# ECCODES_DEFINITION_PATH=.venv/share/eccodes/definitions +# +# ENTRYPOINT ["/venv/bin/nwp-consumer-cli"] +# VOLUME /work +# STOPSIGNAL SIGINT + + +# WORKING CONTAINERFILE FROM quay.io/condaforge/miniforge3:latest AS build-venv @@ -30,25 +95,22 @@ COPY pyproject.toml /_lock/ # Synchronize DEPENDENCIES without the application itself. # This layer is cached until uv.lock or pyproject.toml change. # Delete any unwanted parts of the installed packages to reduce size -RUN --mount=type=cache,target=/root/.cache \ - apt-get update && apt-get install build-essential -y && \ +RUN apt-get -qq update && apt-get -qq -y install gcc && \ echo "Creating virtualenv at /venv" && \ - conda create -qy -p /venv python=3.12 numcodecs -RUN which gcc + conda create --quiet --yes -p /venv python=3.12 numcodecs eccodes RUN echo "Installing dependencies into /venv" && \ cd /_lock && \ mkdir src && \ uv sync --no-dev --no-install-project && \ echo "Optimizing /venv site-packages" && \ rm -r /venv/lib/python3.12/site-packages/**/tests && \ - rm -r /venv/lib/python3.12/site-packages/**/_*cache* - + rm -r /venv/lib/python3.12/site-packages/**/_*cache* && \ + rm -r /venv/share/eccodes/definitions/bufr # Then install the application itself # * Delete the test and cache folders from installed packages to reduce size COPY . /src -RUN --mount=type=cache,target=/root/.cache \ - uv pip install --no-deps --python=$UV_PROJECT_ENVIRONMENT /src +RUN uv pip install --no-deps --python=$UV_PROJECT_ENVIRONMENT /src # Copy the virtualenv into a distroless image # * These are small images that only contain the runtime dependencies diff --git a/pyproject.toml b/pyproject.toml index d2110286..bcf74161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,44 +18,39 @@ authors = [ classifiers = ["Programming Language :: Python :: 3"] dependencies = [ "dask == 2024.8.1", - "eccodes == 2.38.1", + "eccodes == 2.38.3", "ecmwf-api-client == 1.6.3", - "cfgrib == 0.9.14.0", + "cfgrib == 0.9.14.1", "dagster-pipes == 1.8.5", "joblib == 1.4.2", "numpy == 2.1.0", "ocf-blosc2 == 0.0.11", "psutil == 6.0.0", - "requests == 2.32.3", "returns == 0.23.0", "s3fs == 2024.9.0", "xarray == 2024.9.0", - "zarr == 2.18.2" + "zarr == 2.18.3" ] -[project.optional-dependencies] -test = [ +[dependency-groups] +dev = [ + # Testing + "botocore == 1.33.7", # Required for moto, prevents installing the whole of boto3 "flask == 3.0.0", "flask-cors == 4.0.0", "moto[s3,server] == 4.2.11", "unittest-xml-reporting == 3.2.0", "hypothesis == 6.115.3", -] -lint = [ + # Linting "returns[compatible-mypy]", "ruff == 0.6.9", "pandas-stubs", "types-psutil", "types-pytz", "types-pyyaml", -] -docs = [ + # Docs "pydoctor >= 24.3.0", -] -dev = [ - "nwp-consumer[test,lint,docs]", -] -lsp = [ + # IDE support "python-lsp-server", "pylsp-mypy", "python-lsp-ruff", @@ -102,6 +97,8 @@ plugins = [ # If they are ever made, remove from here! 
module = [ "cfgrib", + "botocore.session", + "botocore.client", "joblib", "ocf_blosc2", "s3fs", diff --git a/src/nwp_consumer/__init__.py b/src/nwp_consumer/__init__.py index b6f4d458..2cba477d 100644 --- a/src/nwp_consumer/__init__.py +++ b/src/nwp_consumer/__init__.py @@ -23,6 +23,8 @@ +-------------------------------+-------------------------------------+---------------------------------------------+ | MODEL_REPOSITORY | The model repository to use. | ceda-metoffice-global | +-------------------------------+-------------------------------------+---------------------------------------------+ +| CONCURRENCY | Whether to use concurrency. | True | ++-------------------------------+-------------------------------------+---------------------------------------------+ Development Documentation @@ -149,6 +151,7 @@ "gribapi", "aiobotocore", "s3fs", + "fsspec", "asyncio", "botocore", "cfgrib", diff --git a/src/nwp_consumer/cmd/main.py b/src/nwp_consumer/cmd/main.py index 6d001863..cc8b8c50 100644 --- a/src/nwp_consumer/cmd/main.py +++ b/src/nwp_consumer/cmd/main.py @@ -18,12 +18,11 @@ def parse_env() -> Adaptors: """Parse from the environment.""" model_repository_adaptor: type[ports.ModelRepository] match os.getenv("MODEL_REPOSITORY"): - case None: - log.error("MODEL_REPOSITORY is not set in environment.") - sys.exit(1) + case None | "gfs": + model_repository_adaptor = repositories.NOAAS3ModelRepository case "ceda": - model_repository_adaptor = repositories.CedaMetOfficeGlobalModelRepository - case "ecmwf-realtime-s3": + model_repository_adaptor = repositories.CEDAFTPModelRepository + case "ecmwf-realtime": model_repository_adaptor = repositories.ECMWFRealTimeS3ModelRepository case _ as model: log.error(f"Unknown model: {model}") diff --git a/src/nwp_consumer/internal/entities/coordinates.py b/src/nwp_consumer/internal/entities/coordinates.py index 565fe707..8e192d09 100644 --- a/src/nwp_consumer/internal/entities/coordinates.py +++ b/src/nwp_consumer/internal/entities/coordinates.py @@ -37,7 +37,10 @@ import dataclasses import datetime as dt +import json +from importlib.metadata import PackageNotFoundError, version +import dask.array import numpy as np import pandas as pd import pytz @@ -46,6 +49,11 @@ from .parameters import Parameter +try: + __version__ = version("nwp-consumer") +except PackageNotFoundError: + __version__ = "v?" + @dataclasses.dataclass(slots=True) class NWPDimensionCoordinateMap: @@ -69,7 +77,7 @@ class NWPDimensionCoordinateMap: """The forecast step times. This corresponds to the horizon of the values, which is the time - difference between the forecast initialisation time and the target + difference between the forecast initialization time and the target time at which the forecast data is valid. """ variable: list[Parameter] @@ -207,7 +215,7 @@ def to_pandas(self) -> dict[str, pd.Index]: # type: ignore This is useful for interoperability with xarray, which prefers to define DataArray coordinates using a dict pandas Index objects. - For the most part, the conversion consists of a straighforward cast + For the most part, the conversion consists of a straightforward cast to a pandas Index object. However, there are some caveats involving the time-centric dimensions: @@ -367,11 +375,57 @@ def default_chunking(self) -> dict[str, int]: that wants to cover the entire dimension should have a size equal to the dimension length. - It defaults to a single chunk per init time and step, and a single chunk - for each entire other dimension. 
+ It defaults to a single chunk per init time and step, and 8 chunks + for each entire other dimension. These are purposefully small, to ensure + that when performing parallel writes, chunk boundaries are not crossed. """ out_dict: dict[str, int] = { "init_time": 1, "step": 1, - } | {dim: len(getattr(self, dim)) for dim in self.dims if dim not in ["init_time", "step"]} + } | { + dim: len(getattr(self, dim)) // 8 if len(getattr(self, dim)) > 8 else 1 + for dim in self.dims + if dim not in ["init_time", "step"] + } + return out_dict + + + def as_zeroed_dataarray(self, name: str) -> xr.DataArray: + """Express the coordinates as an xarray DataArray. + + Data is populated with zeros and a default chunking scheme is applied. + + Args: + name: The name of the DataArray. + + See Also: + - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes + """ + # Create a dask array of zeros with the shape of the dataset + # * The values of this are ignored, only the shape and chunks are used + dummy_values = dask.array.zeros( # type: ignore + shape=list(self.shapemap.values()), + chunks=tuple([self.default_chunking()[k] for k in self.shapemap]), + ) + attrs: dict[str, str] = { + "produced_by": "".join(( + f"nwp-consumer {__version__} at ", + f"{dt.datetime.now(tz=dt.UTC).strftime('%Y-%m-%d %H:%M')}", + )), + "variables": json.dumps({ + p.value: { + "description": p.metadata().description, + "units": p.metadata().units, + } for p in self.variable + }), + } + # Create a DataArray object with the given coordinates and dummy values + da: xr.DataArray = xr.DataArray( + name=name, + data=dummy_values, + coords=self.to_pandas(), + attrs=attrs, + ) + return da + diff --git a/src/nwp_consumer/internal/entities/parameters.py b/src/nwp_consumer/internal/entities/parameters.py index a15fc0d2..c5c0e398 100644 --- a/src/nwp_consumer/internal/entities/parameters.py +++ b/src/nwp_consumer/internal/entities/parameters.py @@ -23,6 +23,8 @@ import dataclasses from enum import StrEnum, auto +from returns.result import Failure, ResultE, Success + @dataclasses.dataclass(slots=True) class ParameterLimits: @@ -77,6 +79,9 @@ class ParameterData: Used in sanity and validity checking the database values. 
""" + alternate_shortnames: list[str] = dataclasses.field(default_factory=list) + """Alternate names for the parameter found in the wild.""" + def __str__(self) -> str: """String representation of the parameter.""" return self.name @@ -121,6 +126,7 @@ def metadata(self) -> ParameterData: description="Temperature at screen level", units="C", limits=ParameterLimits(upper=60, lower=-90), + alternate_shortnames=["t", "t2m"], ) case self.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.name: return ParameterData( @@ -130,6 +136,7 @@ def metadata(self) -> ParameterData: "incident on the surface expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=1500, lower=0), + alternate_shortnames=["swavr", "ssrd", "dswrf"], ) case self.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.name: return ParameterData( @@ -139,6 +146,7 @@ def metadata(self) -> ParameterData: "incident on the surface expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=500, lower=0), + alternate_shortnames=["strd", "dlwrf"], ) case self.RELATIVE_HUMIDITY_SL.name: return ParameterData( @@ -148,6 +156,7 @@ def metadata(self) -> ParameterData: "to the equilibrium vapour pressure of water", units="%", limits=ParameterLimits(upper=100, lower=0), + alternate_shortnames=["r"], ) case self.VISIBILITY_SL.name: return ParameterData( @@ -157,6 +166,7 @@ def metadata(self) -> ParameterData: "horizontally in daylight conditions.", units="m", limits=ParameterLimits(upper=4500, lower=0), + alternate_shortnames=["vis"], ) case self.WIND_U_COMPONENT_10m.name: return ParameterData( @@ -166,6 +176,7 @@ def metadata(self) -> ParameterData: "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["u10"], ) case self.WIND_V_COMPONENT_10m.name: return ParameterData( @@ -176,6 +187,7 @@ def metadata(self) -> ParameterData: units="m/s", # Non-tornadic winds are usually < 100m/s limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["v10"], ) case self.WIND_U_COMPONENT_100m.name: return ParameterData( @@ -185,6 +197,7 @@ def metadata(self) -> ParameterData: "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["u100"], ) case self.WIND_V_COMPONENT_100m.name: return ParameterData( @@ -194,6 +207,7 @@ def metadata(self) -> ParameterData: "the wind in the northward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["v100"], ) case self.WIND_U_COMPONENT_200m.name: return ParameterData( @@ -203,6 +217,7 @@ def metadata(self) -> ParameterData: "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=150, lower=-150), + alternate_shortnames=["u200"], ) case self.WIND_V_COMPONENT_200m.name: return ParameterData( @@ -212,6 +227,7 @@ def metadata(self) -> ParameterData: "the wind in the northward direction.", units="m/s", limits=ParameterLimits(upper=150, lower=-150), + alternate_shortnames=["v200"], ) case self.SNOW_DEPTH_GL.name: return ParameterData( @@ -219,6 +235,7 @@ def metadata(self) -> ParameterData: description="Depth of snow on the ground.", units="m", limits=ParameterLimits(upper=12, lower=0), + alternate_shortnames=["sd", "sdwe"], ) case self.CLOUD_COVER_HIGH.name: return ParameterData( @@ -229,6 +246,7 @@ def metadata(self) -> ParameterData: "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), + alternate_shortnames=["hcc"], ) case self.CLOUD_COVER_MEDIUM.name: return ParameterData( @@ 
-239,6 +257,7 @@ def metadata(self) -> ParameterData: "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), + alternate_shortnames=["mcc"], ) case self.CLOUD_COVER_LOW.name: return ParameterData( @@ -249,6 +268,7 @@ def metadata(self) -> ParameterData: "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), + alternate_shortnames=["lcc"], ) case self.CLOUD_COVER_TOTAL.name: return ParameterData( @@ -259,6 +279,7 @@ def metadata(self) -> ParameterData: "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), + alternate_shortnames=["tcc", "clt"], ) case self.TOTAL_PRECIPITATION_RATE_GL.name: return ParameterData( @@ -268,6 +289,7 @@ def metadata(self) -> ParameterData: "including rain, snow, and hail.", units="kg/m^2/s", limits=ParameterLimits(upper=0.2, lower=0), + alternate_shortnames=["prate", "tprate"], ) case self.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.name: return ParameterData( @@ -278,6 +300,7 @@ def metadata(self) -> ParameterData: "expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=1000, lower=0), + alternate_shortnames=["uvb"], ) case self.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.name: return ParameterData( @@ -289,7 +312,16 @@ def metadata(self) -> ParameterData: "expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=1000, lower=0), + alternate_shortnames=["dsrp"], ) case _: # Shouldn't happen thanks to the test case in test_parameters.py raise ValueError(f"Unknown parameter: {self}") + + def try_from_alternate(name: str) -> ResultE["Parameter"]: + """Map an alternate name to a parameter.""" + for p in Parameter: + if name in p.metadata().alternate_shortnames: + return Success(p) + return Failure(ValueError(f"Unknown shortname: {name}")) + diff --git a/src/nwp_consumer/internal/entities/postprocess.py b/src/nwp_consumer/internal/entities/postprocess.py index 1b0a0187..b17cd90d 100644 --- a/src/nwp_consumer/internal/entities/postprocess.py +++ b/src/nwp_consumer/internal/entities/postprocess.py @@ -30,22 +30,6 @@ class PostProcessOptions: i.e. nothing occurs by default. """ - standardize_coordinates: bool = False - """Whether to standardize the coordinates of the data. - - Note that this doesn't refer to interpolation: rather, it makes - the coordinates adhere to the usual directionality and regionality - within the circular space, i.e.: - - - Latitude values should be in the range and direction [+90, -90] - - Longitude values should be in the range and direction [-180, 180] - - Y values should be in descending order - - X values should be in ascending order - """ - - rechunk: bool = False - """Whether to rechunk the data.""" - validate: bool = False """Whether to validate the data. @@ -60,9 +44,6 @@ class PostProcessOptions: `_. 
""" - zip: bool = False - """Whether to zip the data.""" - plot: bool = False """Whether to save a plot of the data.""" @@ -71,10 +52,7 @@ def requires_rewrite(self) -> bool: """Boolean indicating whether the specified options necessitate a rewrite.""" return any( [ - self.standardize_coordinates, - self.rechunk, self.codec, - self.zip, ], ) @@ -82,11 +60,8 @@ def requires_postprocessing(self) -> bool: """Boolean indicating whether the specified options necessitate post-processing.""" return any( [ - self.standardize_coordinates, - self.rechunk, self.validate, self.codec, - self.zip, self.plot, ], ) diff --git a/src/nwp_consumer/internal/entities/tensorstore.py b/src/nwp_consumer/internal/entities/tensorstore.py index 7b857eed..480973a7 100644 --- a/src/nwp_consumer/internal/entities/tensorstore.py +++ b/src/nwp_consumer/internal/entities/tensorstore.py @@ -7,18 +7,14 @@ This module provides a class for storing metadata about a Zarr store. """ +import abc import dataclasses import datetime as dt -import json import logging import os import pathlib -import shutil -from importlib.metadata import PackageNotFoundError, version from typing import Any -import dask.array -import numpy as np import pandas as pd import xarray as xr import zarr @@ -30,11 +26,6 @@ log = logging.getLogger("nwp-consumer") -try: - __version__ = version("nwp-consumer") -except PackageNotFoundError: - __version__ = "v?" - @dataclasses.dataclass(slots=True) class ParameterScanResult: @@ -53,25 +44,25 @@ class ParameterScanResult: @dataclasses.dataclass(slots=True) -class TensorStore: +class TensorStore(abc.ABC): """Store class for multidimensional data. This class is used to store data in a Zarr store. - Each store instance is associated with a single init time, + Each store instance has defined coordinates for the data, and is capable of handling parallel, region-based updates. """ name: str """Identifier for the store and the data within.""" - path: pathlib.Path + path: str """The path to the store.""" coordinate_map: NWPDimensionCoordinateMap """The coordinates of the store.""" - size_mb: int - """The size of the store in megabytes.""" + size_kb: int + """The size of the store in kilobytes.""" encoding: dict[str, Any] """The encoding passed to Zarr whilst writing.""" @@ -79,18 +70,16 @@ class TensorStore: @classmethod def initialize_empty_store( cls, - name: str, + model: str, + repository: str, coords: NWPDimensionCoordinateMap, - overwrite_existing: bool = True, ) -> ResultE["TensorStore"]: """Initialize a store for a given init time. This method writes a blank dataarray to disk based on the input coordinates, which define the dimension labels and tick values of the output dataset object. - If the store already exists, it will be overwritten, unless the 'overwrite_existing' - flag is set to False. In this case, the existing store will be used only if its - coordinates are consistent with the expected coordinates. + .. note: If a store already exists at the expected path, it will be overwritten! The dataarray is 'blank' because it is written via:: @@ -112,9 +101,10 @@ def initialize_empty_store( - As above for the init_time dimension. Args: - name: The name of the tensor. + model: The name of the model providing the tensor data. + This is also used as the name of the tensor. + repository: The name of the repository providing the tensor data. coords: The coordinates of the store. - overwrite_existing: Whether to overwrite an existing store. 
Returns: An indicator of a successful store write containing the number of bytes written. @@ -131,123 +121,124 @@ def initialize_empty_store( ValueError( "Cannot initialize store with 'init_time' dimension coordinates not " "specified via a populated list. Check instantiation of " - "NWPDimensionCoordinateMap. " + "NWPDimensionCoordinateMap passed to this function. " f"Got: {coords.init_time} (not a list, or empty).", ), ) - store_range: str = f"{coords.init_time[0]:%Y%m%d%H}" - if len(coords.init_time) > 1: - store_range = f"{coords.init_time[0]:%Y%m%d%H}-{coords.init_time[-1]:%Y%m%d%H}" - store_path = pathlib.Path( - f"{os.getenv('ZARRDIR', f'~/.local/cache/nwp/{name}/data')}/{store_range}.zarr", - ) - # * Define a set of chunks allowing for intermediate parallel writes - # NOTE: This is not the same as the final chunking of the dataset! - # Merely a chunksize that is small enough to allow for parallel writes - # to different regions of the init store. - intermediate_chunks: dict[str, int] = { - "init_time": 1, - "step": 1, - "variable": 1, - "latitude": coords.shapemap.get("latitude", 400) // 4, - "longitude": coords.shapemap.get("longitude", 400) // 8, - "values": coords.shapemap.get("values", 100), - } - # Create a dask array of zeros with the shape of the dataset - # * The values of this are ignored, only the shape and chunks are used - dummy_values = dask.array.zeros( # type: ignore - shape=list(coords.shapemap.values()), - chunks=tuple([intermediate_chunks[k] for k in coords.shapemap]), - ) - attrs: dict[str, str] = { - "produced_by": "".join(( - f"nwp-consumer {__version__} at ", - f"{dt.datetime.now(tz=dt.UTC).strftime('%Y-%m-%d %H:%M')}", - )), - "variables": json.dumps({ - p.value: { - "description": p.metadata().description, - "units": p.metadata().units, - } for p in coords.variable - }), - } - # Create a DataArray object with the given coordinates and dummy values - da: xr.DataArray = xr.DataArray( - name=name, - data=dummy_values, - coords=coords.to_pandas(), - attrs=attrs, - ) - encoding: dict[str, Any] ={ + zarrdir = os.getenv("ZARRDIR", f"~/.local/cache/nwp/{repository}/{model}/data") + store: zarr.storage.Store + path: str + try: + path = pathlib.Path( + "/".join((zarrdir, TensorStore.gen_store_filename(coords=coords))), + ).expanduser().as_posix() + store = zarr.storage.DirectoryStore(path) + if zarrdir.startswith("s3"): + import s3fs + log.debug("Attempting AWS connection using credential discovery") + try: + fs = s3fs.S3FileSystem( + anon=False, + client_kwargs={ + "region_name": os.getenv("AWS_REGION", "eu-west-1"), + "endpoint_url": os.getenv("AWS_ENDPOINT_URL", None), + }, + ) + path = zarrdir + "/" + TensorStore.gen_store_filename(coords=coords) + fs.mkdirs(path=path, exist_ok=True) + store = s3fs.mapping.S3Map(path, fs, check=False, create=True) + except Exception as e: + return Failure(OSError( + f"Unable to create file mapping for path '{path}'. " + "Ensure ZARRDIR environment variable is specified correctly, " + "and AWS credentials are discoverable by botocore. " + f"Error context: {e}", + )) + except Exception as e: + return Failure(OSError( + f"Unable to create Zarr Store at dir '{zarrdir}'. " + "Ensure ZARRDIR environment variable is specified correctly. 
" + f"Error context: {e}", + )) + + # Write the coordinates to a skeleton Zarr store + # * 'compute=False' enables only saving metadata + # * 'mode="w"' overwrites any existing store + log.info("initializing zarr store at '%s'", path) + da: xr.DataArray = coords.as_zeroed_dataarray(name=model) + encoding = { + model: {"write_empty_chunks": False}, "init_time": {"units": "nanoseconds since 1970-01-01"}, "step": {"units": "hours"}, } + try: + _ = da.to_zarr( + store=store, + compute=False, + mode="w", + consolidated=True, + encoding=encoding, + ) + # Ensure the store is readable + store_da = xr.open_dataarray(store, engine="zarr") + except Exception as e: + return Failure( + OSError( + f"Failed writing blank store to '{path}': {e}", + ), + ) - match (os.path.exists(store_path), overwrite_existing): - case (True, False): - store_da: xr.DataArray = xr.open_dataarray(store_path, engine="zarr") - for dim in store_da.dims: - if dim not in da.dims: - return Failure( - ValueError( - "Cannot use existing store due to mismatched coordinates. " - f"Dimension '{dim}' in existing store not found in new store. " - "Use 'overwrite_existing=True' or move the existing store at " - f"'{store_path}' to a new location. ", - ), - ) - if not np.array_equal(store_da.coords[dim].values, da.coords[dim].values): - return Failure( - ValueError( - "Cannot use existing store due to mismatched coordinates. " - f"Dimension '{dim}' in existing store has different coordinate " - "values from specified. " - "Use 'overwrite_existing=True' or move the existing store at " - f"'{store_path}' to a new location.", - ), - ) - case (_, _): - try: - # Write the dataset to a skeleton zarr file - # * 'compute=False' enables only saving metadata - # * 'mode="w"' overwrites any existing store - _ = da.to_zarr( - store=store_path, - compute=False, - mode="w", - consolidated=True, - encoding=encoding, - ) - # Ensure the store is readable - store_da = xr.open_dataarray(store_path, engine="zarr") - except Exception as e: - return Failure( - OSError( - f"Failed writing blank store to disk: {e}", - ), - ) # Check the resultant array's coordinates can be converted back coordinate_map_result = NWPDimensionCoordinateMap.from_xarray(store_da) if isinstance(coordinate_map_result, Failure): return Failure( OSError( f"Error reading back coordinates of initialized store " - f"from disk (possible corruption): {coordinate_map_result}", + f"from path '{path}' (possible corruption): {coordinate_map_result}", ), ) return Success( cls( - name=name, - path=store_path, + name=model, + path=path, coordinate_map=coordinate_map_result.unwrap(), - size_mb=0, + size_kb=0, encoding=encoding, ), ) + #def from_existing_store( + # model: str, + # repository: str, + # expected_coords: NWPDimensionCoordinateMap, + #) -> ResultE["TensorStore"]: + # """Create a TensorStore instance from an existing store.""" + # pass # TODO + + # for dim in store_da.dims: + # if dim not in da.dims: + # return Failure( + # ValueError( + # "Cannot use existing store due to mismatched coordinates. " + # f"Dimension '{dim}' in existing store not found in new store. " + # "Use 'overwrite_existing=True' or move the existing store at " + # f"'{store}' to a new location. ", + # ), + # ) + # if not np.array_equal(store_da.coords[dim].values, da.coords[dim].values): + # return Failure( + # ValueError( + # "Cannot use existing store due to mismatched coordinates. " + # f"Dimension '{dim}' in existing store has different coordinate " + # "values from specified. 
" + # "Use 'overwrite_existing=True' or move the existing store at " + # f"'{store}' to a new location.", + # ), + # ) + # --- Business logic methods --- # def write_to_region( self, @@ -291,7 +282,7 @@ def write_to_region( # Calculate the number of bytes written nbytes: int = da.nbytes del da - self.size_mb += nbytes // (1024**2) + self.size_kb += nbytes // 1024 return Success(nbytes) def validate_store(self) -> ResultE[bool]: @@ -367,104 +358,33 @@ def scan_parameter_values(self, p: Parameter) -> ResultE[ParameterScanResult]: ), ) - - def postprocess(self, options: PostProcessOptions) -> ResultE[pathlib.Path]: + def postprocess(self, options: PostProcessOptions) -> ResultE[str]: """Post-process the store. This creates a new store, as many of the postprocess options require modifications to the underlying file structure of the store. """ + # TODO: Implement postprocessing options if options.requires_postprocessing(): log.info("Applying postprocessing options to store %s", self.name) if options.validate: log.warning("Validation not yet implemented in efficient manner. Skipping option.") - store_da: xr.DataArray = xr.open_dataarray( - self.path, - engine="zarr", - ) - - if options.codec: - log.debug("Applying codec %s to store %s", options.codec.name, self.name) - self.encoding = self.encoding | {"compressor": options.codec.value} - - if options.rechunk: - store_da = store_da.chunk(chunks=self.coordinate_map.default_chunking()) - - if options.standardize_coordinates: - # Make the longitude values range from -180 to 180 - store_da = store_da.assign_coords({ - "longitude": ((store_da.coords["longitude"] + 180) % 360) - 180, - }) - # Find the index of the maximum value - idx: int = store_da.coords["longitude"].argmax().values - # Move the maximum value to the end, and do the same to the underlying data - store_da = store_da.roll( - longitude=len(store_da.coords["longitude"]) - idx - 1, - roll_coords=True, - ) - coordinates_result = NWPDimensionCoordinateMap.from_xarray(store_da) - match coordinates_result: - case Failure(e): - return Failure(e) - case Success(coords): - self.coordinate_map = coords - - if options.requires_rewrite(): - processed_path = self.path.parent / (self.path.name + ".processed") - try: - log.debug( - "Writing postprocessed store to %s", - processed_path, - ) - # Clear the encoding for any variables indexed as an 'object' type - # * e.g. Dimensions with string labels -> the variable dim - # * See https://github.com/sgkit-dev/sgkit/issues/991 - # * and https://github.com/pydata/xarray/issues/3476 - store_da.coords["variable"].encoding.clear() - _ = store_da.to_zarr( - store=processed_path, - mode="w", - encoding=self.encoding, - consolidated=True, - ) - self.path = processed_path - except Exception as e: - return Failure( - OSError( - f"Error encountered writing postprocessed store: {e}", - ), - ) - - if options.zip: - log.debug( - "Postprocessor: Zipping store to " - f"{self.path.with_suffix(".zarr.zip")}", - ) - try: - shutil.make_archive(self.path.name, "zip", self.path) - except Exception as e: - return Failure( - OSError( - f"Error encountered zipping store: {e}", - ), - ) - log.debug("Postprocessing complete for store %s", self.name) return Success(self.path) else: return Success(self.path) - def update_attrs(self, attrs: dict[str, str]) -> ResultE[pathlib.Path]: + def update_attrs(self, attrs: dict[str, str]) -> ResultE[str]: """Update the attributes of the store. This method updates the attributes of the store with the given dictionary. 
""" - group: zarr.Group = zarr.open_group(self.path.as_posix()) + group: zarr.Group = zarr.open_group(self.path) group.attrs.update(attrs) - zarr.consolidate_metadata(self.path.as_posix()) + zarr.consolidate_metadata(self.path) return Success(self.path) def missing_times(self) -> ResultE[list[dt.datetime]]: @@ -490,4 +410,17 @@ def missing_times(self) -> ResultE[list[dt.datetime]]: missing_times.append(pd.Timestamp(it).to_pydatetime().replace(tzinfo=dt.UTC)) return Success(missing_times) + @staticmethod + def gen_store_filename(coords: NWPDimensionCoordinateMap) -> str: + """Create a filename for the store. + + If the store only covers a single init_time, the filename is the init time. + Else, if it covers multiple init_times, the filename is the range of init times. + The extension is '.zarr'. + """ + store_range: str = coords.init_time[0].strftime("%Y%m%d%H") + if len(coords.init_time) > 1: + store_range = f"{coords.init_time[0]:%Y%m%d%H}-{coords.init_time[-1]:%Y%m%d%H}" + + return store_range + ".zarr" diff --git a/src/nwp_consumer/internal/entities/test_parameters.py b/src/nwp_consumer/internal/entities/test_parameters.py index 1470e28f..c8b864f2 100644 --- a/src/nwp_consumer/internal/entities/test_parameters.py +++ b/src/nwp_consumer/internal/entities/test_parameters.py @@ -2,6 +2,7 @@ from hypothesis import given from hypothesis import strategies as st +from returns.pipeline import is_successful from .parameters import Parameter @@ -15,6 +16,15 @@ def test_metadata(self, p: Parameter) -> None: metadata = p.metadata() self.assertEqual(metadata.name, p.value) + @given(st.sampled_from([s for p in Parameter for s in p.metadata().alternate_shortnames])) + def test_try_from_shortname(self, shortname: str) -> None: + """Test the try_from_shortname method.""" + p = Parameter.try_from_alternate(shortname) + self.assertTrue(is_successful(p)) + + p = Parameter.try_from_alternate("invalid") + self.assertFalse(is_successful(p)) + if __name__ == "__main__": unittest.main() diff --git a/src/nwp_consumer/internal/entities/test_tensorstore.py b/src/nwp_consumer/internal/entities/test_tensorstore.py index c7c9d142..68844ebb 100644 --- a/src/nwp_consumer/internal/entities/test_tensorstore.py +++ b/src/nwp_consumer/internal/entities/test_tensorstore.py @@ -1,27 +1,77 @@ +import contextlib import dataclasses import datetime as dt +import logging +import os +import shutil import unittest +from collections.abc import Generator +from unittest.mock import patch import numpy as np import xarray as xr +from botocore.client import BaseClient as BotocoreClient +from botocore.session import Session +from moto.server import ThreadedMotoServer from returns.pipeline import is_successful -from returns.result import Failure, Success from .coordinates import NWPDimensionCoordinateMap from .parameters import Parameter from .postprocess import PostProcessOptions from .tensorstore import TensorStore +logging.getLogger("werkzeug").setLevel(logging.ERROR) + + +class MockS3Bucket(contextlib.ContextDecorator): + + client: BotocoreClient + server: ThreadedMotoServer + bucket: str = "test-bucket" + + def __enter__(self) -> None: + self.server = ThreadedMotoServer() + self.server.start() + + session = Session() + self.client = session.create_client( + service_name="s3", + region_name="us-east-1", + endpoint_url="http://localhost:5000", + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + ) + + self.client.create_bucket( + Bucket=self.bucket, + ) + + def 
__exit__(self, *exc) -> bool: # type:ignore + response = self.client.list_objects_v2( + Bucket=self.bucket, + ) + if "Contents" in response: + for obj in response["Contents"]: + self.client.delete_object( + Bucket=self.bucket, + Key=obj["Key"], + ) + self.server.stop() + return False + class TestTensorStore(unittest.TestCase): """Test the business methods of the TensorStore class.""" - test_coords: NWPDimensionCoordinateMap - test_store: TensorStore + @contextlib.contextmanager + def store(self, year: int) -> Generator[TensorStore, None, None]: + """Create an instance of the TensorStore class.""" - def setUp(self) -> None: - self.test_coords = NWPDimensionCoordinateMap( - init_time=[dt.datetime(2021, 1, 1, h, tzinfo=dt.UTC) for h in [0, 6, 12, 18]], + test_coords: NWPDimensionCoordinateMap = NWPDimensionCoordinateMap( + init_time=[ + dt.datetime(year, 1, 1, h, tzinfo=dt.UTC) + for h in [0, 6, 12, 18] + ], step=[1, 2, 3, 4], variable=[Parameter.TEMPERATURE_SL], latitude=np.linspace(90, -90, 12).tolist(), @@ -29,25 +79,74 @@ def setUp(self) -> None: ) init_result = TensorStore.initialize_empty_store( - name="test_da", - coords=self.test_coords, + model="test_da", + repository="dummy_repository", + coords=test_coords, + ) + self.assertTrue( + is_successful(init_result), + msg=f"Unable to initialize store: {init_result}", ) - match init_result: - case Success(store): - self.test_store = store - case Failure(e): - raise ValueError(f"Failed to initialize test store: {e}.") + store = init_result.unwrap() + yield store + shutil.rmtree(store.path) + + @patch.dict(os.environ, { + "AWS_ENDPOINT_URL": "http://localhost:5000", + "AWS_ACCESS_KEY_ID": "test-key", + "AWS_SECRET_ACCESS_KEY": "test-secret", + "ZARRDIR": "s3://test-bucket/data", + }, clear=True) + def test_initialize_empty_store_s3(self) -> None: + """Test the initialize_empty_store method.""" + test_coords: NWPDimensionCoordinateMap = NWPDimensionCoordinateMap( + init_time=[ + dt.datetime(2024, 1, 1, h, tzinfo=dt.UTC) + for h in [0, 6, 12, 18] + ], + step=[1, 2, 3, 4], + variable=[Parameter.TEMPERATURE_SL], + latitude=np.linspace(90, -90, 12).tolist(), + longitude=np.linspace(0, 360, 18).tolist(), + ) - def test_initialize_empty_store(self) -> None: - """Test the initialize_empty_store method.""" - # TODO - pass + with MockS3Bucket(): + init_result = TensorStore.initialize_empty_store( + model="test_da", + repository="dummy_repository", + coords=test_coords, + ) + self.assertTrue(is_successful(init_result)) + + # Assert it overwrites existing stores successfully + init_result = TensorStore.initialize_empty_store( + model="new_test_da", + repository="dummy_repository", + coords=test_coords, + ) + self.assertTrue(is_successful(init_result)) def test_write_to_region(self) -> None: """Test the write_to_region method.""" - # TODO - pass + with self.store(year=2022) as ts: + test_da: xr.DataArray = xr.DataArray( + name="test_da", + data=np.ones( + shape=list(ts.coordinate_map.shapemap.values()), + ), + coords=ts.coordinate_map.to_pandas(), + ) + + # Write each init time and step one at a time + for it in test_da.coords["init_time"].values: + for step in test_da.coords["step"].values: + write_result = ts.write_to_region( + da=test_da.where( + test_da["init_time"] == it, drop=True, + ).where(test_da["step"] == step, drop=True), + ) + self.assertTrue(is_successful(write_result), msg=write_result) def test_postprocess(self) -> None: """Test the postprocess method.""" @@ -64,25 +163,19 @@ class TestCase: options=PostProcessOptions(), 
should_error=False, ), - TestCase( - name="standardize_coordinates", - options=PostProcessOptions( - standardize_coordinates=True, - ), - should_error=False, - ), ] - for t in tests: - with self.subTest(name=t.name): - result = self.test_store.postprocess(t.options) - if t.should_error: - self.assertTrue( - isinstance(result, Exception), - msg="Expected error to be returned.", - ) - else: - self.assertTrue(is_successful(result)) + with self.store(year=1971) as ts: + for t in tests: + with self.subTest(name=t.name): + result = ts.postprocess(t.options) + if t.should_error: + self.assertTrue( + isinstance(result, Exception), + msg="Expected error to be returned.", + ) + else: + self.assertTrue(is_successful(result)) def test_missing_times(self) -> None: """Test the missing_times method.""" @@ -93,45 +186,47 @@ class TestCase: times_to_write: list[dt.datetime] expected: list[dt.datetime] - tests: list[TestCase] = [ - TestCase( - name="all_missing_times", - times_to_write=[], - expected=self.test_coords.init_time, - ), - TestCase( - name="some_missing_times", - times_to_write=[self.test_coords.init_time[0], self.test_coords.init_time[2]], - expected=[self.test_coords.init_time[1], self.test_coords.init_time[3]], - ), - TestCase( - name="no_missing_times", - times_to_write=self.test_coords.init_time, - expected=[], - ), - ] - - for t in tests: - with self.subTest(name=t.name): - for i in t.times_to_write: - write_result = self.test_store.write_to_region( - da=xr.DataArray( - name="test_da", - data=np.ones( - shape=[ - 1 if k == "init_time" else v - for k, v in self.test_coords.shapemap.items() - ], + with self.store(year=2024) as ts: + tests: list[TestCase] = [ + TestCase( + name="all_missing_times", + times_to_write=[], + expected=ts.coordinate_map.init_time, + ), + TestCase( + name="some_missing_times", + times_to_write=[ts.coordinate_map.init_time[0], ts.coordinate_map.init_time[2]], + expected=[ts.coordinate_map.init_time[1], ts.coordinate_map.init_time[3]], + ), + TestCase( + name="no_missing_times", + times_to_write=ts.coordinate_map.init_time, + expected=[], + ), + ] + + for t in tests: + with self.subTest(name=t.name): + for i in t.times_to_write: + write_result = ts.write_to_region( + da=xr.DataArray( + name="test_da", + data=np.ones( + shape=[ + 1 if k == "init_time" else v + for k, v in ts.coordinate_map.shapemap.items() + ], + ), + coords=ts.coordinate_map.to_pandas() | { + "init_time": [np.datetime64(i.replace(tzinfo=None), "ns")], + }, ), - coords=self.test_coords.to_pandas() | { - "init_time": [np.datetime64(i.replace(tzinfo=None), "ns")], - }, - ), - ) - write_result.unwrap() - result = self.test_store.missing_times() - missing_times = result.unwrap() - self.assertListEqual(missing_times, t.expected) + ) + write_result.unwrap() + result = ts.missing_times() + missing_times = result.unwrap() + self.assertListEqual(missing_times, t.expected) if __name__ == "__main__": unittest.main() + diff --git a/src/nwp_consumer/internal/handlers/cli.py b/src/nwp_consumer/internal/handlers/cli.py index 6cceead4..e69956fa 100644 --- a/src/nwp_consumer/internal/handlers/cli.py +++ b/src/nwp_consumer/internal/handlers/cli.py @@ -89,7 +89,7 @@ def run(self) -> int: log.error(f"Failed to consume NWP data: {e}") return 1 case Success(path): - log.info(f"Successfully consumed NWP data to '{path.as_posix()}'") + log.info(f"Successfully consumed NWP data to '{path}'") return 0 case "archive": @@ -100,7 +100,7 @@ def run(self) -> int: log.error(f"Failed to archive NWP data: {e}") return 1 case 
Success(path): - log.info(f"Successfully archived NWP data to '{path.as_posix()}'") + log.info(f"Successfully archived NWP data to '{path}'") return 0 case "info": diff --git a/src/nwp_consumer/internal/ports/repositories.py b/src/nwp_consumer/internal/ports/repositories.py index aea603a9..858d0650 100644 --- a/src/nwp_consumer/internal/ports/repositories.py +++ b/src/nwp_consumer/internal/ports/repositories.py @@ -87,9 +87,9 @@ def fetch_init_data(self, it: dt.datetime) \ ``_download_and_convert`` in the example above. This is to allow for parallelization of the download and processing. - .. note:: It is however, worth considering the most efficient way to download and process the data. - The above assumes that the data comes in many files, but there is a possibility of the - case where the source provides one large file with many underlying datasets within. + .. note:: It is however, worth considering the most efficient way to download and process + the data. The above assumes that the data comes in many files, but there is a possibility + of the case where the source provides one large file with many underlying datasets within. In this case, it may be more efficient to download the large file in the `fetch_init_data` method and then process the datasets within via the yielded functions. diff --git a/src/nwp_consumer/internal/repositories/__init__.py b/src/nwp_consumer/internal/repositories/__init__.py index d1f60926..7d34747b 100644 --- a/src/nwp_consumer/internal/repositories/__init__.py +++ b/src/nwp_consumer/internal/repositories/__init__.py @@ -24,8 +24,9 @@ """ from .model_repositories import ( - CedaMetOfficeGlobalModelRepository, + CEDAFTPModelRepository, ECMWFRealTimeS3ModelRepository, + NOAAS3ModelRepository, ) from .notification_repositories import ( StdoutNotificationRepository, @@ -33,8 +34,9 @@ ) __all__ = [ - "CedaMetOfficeGlobalModelRepository", + "CEDAFTPModelRepository", "ECMWFRealTimeS3ModelRepository", + "NOAAS3ModelRepository", "StdoutNotificationRepository", "DagsterPipesNotificationRepository", ] diff --git a/src/nwp_consumer/internal/repositories/model_repositories/__init__.py b/src/nwp_consumer/internal/repositories/model_repositories/__init__.py index 3580a1cc..f7e444d3 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/__init__.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/__init__.py @@ -1,8 +1,10 @@ -from .metoffice_global import CedaMetOfficeGlobalModelRepository +from .metoffice_global import CEDAFTPModelRepository from .ecmwf_realtime import ECMWFRealTimeS3ModelRepository +from .noaa_gfs import NOAAS3ModelRepository __all__ = [ - "CedaMetOfficeGlobalModelRepository", + "CEDAFTPModelRepository", "ECMWFRealTimeS3ModelRepository", + "NOAAS3ModelRepository", ] diff --git a/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py b/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py index f726bcc8..35b3ae38 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py @@ -65,7 +65,7 @@ def repository() -> entities.ModelRepositoryMetadata: name="ECMWF-Realtime-S3", is_archive=False, is_order_based=True, - running_hours=[0, 12], + running_hours=[0, 6, 12, 18], delay_minutes=(60 * 6), # 6 hours max_connections=100, required_env=[ @@ -118,7 +118,6 @@ def model() -> entities.ModelMetadata: @override def fetch_init_data(self, it: dt.datetime) \ -> Iterator[Callable[..., 
ResultE[list[xr.DataArray]]]]: - # List relevant files in the S3 bucket try: urls: list[str] = [ @@ -145,6 +144,10 @@ def fetch_init_data(self, it: dt.datetime) \ "named with the expected pattern, e.g. 'A2S10250000102603001.", )) + log.debug( + f"Found {len(urls)} files for init time '{it.strftime('%Y-%m-%d %H:%M')}' " + f"in bucket path '{self.bucket}/ecmwf'.", + ) for url in urls: yield delayed(self._download_and_convert)(url=url) @@ -167,6 +170,7 @@ def authenticate(cls) -> ResultE["ECMWFRealTimeS3ModelRepository"]: f"Credentials may be wrong or undefined. Encountered error: {e}", )) + log.debug(f"Successfully authenticated with S3 instance '{bucket}'") return Success(cls(bucket=bucket, fs=_fs)) @@ -194,9 +198,9 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: ).with_suffix(".grib").expanduser() # Only download the file if not already present - if not local_path.exists(): + if not local_path.exists() or local_path.stat().st_size == 0: local_path.parent.mkdir(parents=True, exist_ok=True) - log.info("Requesting file from S3 at: '%s'", url) + log.debug("Requesting file from S3 at: '%s'", url) try: if not self._fs.exists(url): @@ -234,13 +238,19 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: f"Error opening '{path}' as list of xarray Datasets: {e}", )) if len(dss) == 0: - return Failure(ValueError(f"No datasets found in '{path}'")) + return Failure(ValueError( + f"No datasets found in '{path}'. File may be corrupted. " + "A redownload of the file may be required.", + )) processed_das: list[xr.DataArray] = [] for i, ds in enumerate(dss): try: da: xr.DataArray = ( - ds.pipe(ECMWFRealTimeS3ModelRepository._rename_vars) + ECMWFRealTimeS3ModelRepository._rename_or_drop_vars( + ds=ds, + allowed_parameters=ECMWFRealTimeS3ModelRepository.model().expected_coordinates.variable, + ) .rename(name_dict={"time": "init_time"}) .expand_dims(dim="init_time") .expand_dims(dim="step") @@ -274,43 +284,13 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: return Success(processed_das) - @staticmethod - def _rename_vars(ds: xr.Dataset) -> xr.Dataset: - """Rename variables to match the expected names.""" - rename_map: dict[str, str] = { - "dsrp": entities.Parameter.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.value, - "uvb": entities.Parameter.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.value, - "sd": entities.Parameter.SNOW_DEPTH_GL.value, - "tcc": entities.Parameter.CLOUD_COVER_TOTAL.value, - "clt": entities.Parameter.CLOUD_COVER_TOTAL.value, - "u10": entities.Parameter.WIND_U_COMPONENT_10m.value, - "v10": entities.Parameter.WIND_V_COMPONENT_10m.value, - "t2m": entities.Parameter.TEMPERATURE_SL.value, - "ssrd": entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.value, - "strd": entities.Parameter.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.value, - "lcc": entities.Parameter.CLOUD_COVER_LOW.value, - "mcc": entities.Parameter.CLOUD_COVER_MEDIUM.value, - "hcc": entities.Parameter.CLOUD_COVER_HIGH.value, - "vis": entities.Parameter.VISIBILITY_SL.value, - "u200": entities.Parameter.WIND_U_COMPONENT_200m.value, - "v200": entities.Parameter.WIND_V_COMPONENT_200m.value, - "u100": entities.Parameter.WIND_U_COMPONENT_100m.value, - "v100": entities.Parameter.WIND_V_COMPONENT_100m.value, - "tprate": entities.Parameter.TOTAL_PRECIPITATION_RATE_GL.value, - } - - for old, new in rename_map.items(): - if old in ds.data_vars: - ds = ds.rename({old: new}) - return ds - @staticmethod def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool: """Determine if the file is wanted based on 
the init time. See module docstring for the file naming convention. Returns True if the filename describes data corresponding to the input - initialisation time and model metadata. + initialization time and model metadata. Args: filename: The name of the file. @@ -329,3 +309,25 @@ def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool: "%Y%m%d%H%M%z", ) return tt < it + dt.timedelta(hours=max_step) + + + @staticmethod + def _rename_or_drop_vars(ds: xr.Dataset, allowed_parameters: list[entities.Parameter]) \ + -> xr.Dataset: + """Rename variables to match the expected names, dropping invalid ones. + + Args: + ds: The xarray dataset to rename. + allowed_parameters: The list of parameters allowed in the resultant dataset. + """ + for var in ds.data_vars: + param_result = entities.Parameter.try_from_alternate(str(var)) + match param_result: + case Success(p): + if p in allowed_parameters: + ds = ds.rename_vars({var: p.value}) + continue + log.warning("Dropping invalid parameter '%s' from dataset", var) + ds = ds.drop_vars(str(var)) + return ds + diff --git a/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py b/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py index 320d6d58..f3aa1de8 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py @@ -11,7 +11,7 @@ `this PDF `_. For further details on the repository, see the -`CedaMetOfficeGlobalModelRepository.repository` implementation. +`CEDAFTPModelRepository.repository` implementation. Data discrepancies and corrections ================================== @@ -94,7 +94,7 @@ log = logging.getLogger("nwp-consumer") -class CedaMetOfficeGlobalModelRepository(ports.ModelRepository): +class CEDAFTPModelRepository(ports.ModelRepository): """Repository implementation for the MetOffice global model data.""" url_base: str = "ftp.ceda.ac.uk/badc/ukmo-nwp/data/global-grib" @@ -119,9 +119,7 @@ def repository() -> entities.ModelRepositoryMetadata: max_connections=20, required_env=["CEDA_FTP_USER", "CEDA_FTP_PASS"], optional_env={}, - postprocess_options=entities.PostProcessOptions( - standardize_coordinates=True, - ), + postprocess_options=entities.PostProcessOptions(), ) @staticmethod @@ -204,7 +202,7 @@ def _download_and_convert(self, url: str) -> ResultE[list[xr.DataArray]]: @classmethod @override - def authenticate(cls) -> ResultE["CedaMetOfficeGlobalModelRepository"]: + def authenticate(cls) -> ResultE["CEDAFTPModelRepository"]: """Authenticate with the CEDA FTP server. 
Returns: @@ -285,7 +283,11 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: ) try: da: xr.DataArray = ( - ds.sel(step=[np.timedelta64(i, "h") for i in range(0, 48, 1)]) + CEDAFTPModelRepository._rename_or_drop_vars( + ds=ds, + allowed_parameters=CEDAFTPModelRepository.model().expected_coordinates.variable, + ) + .sel(step=[np.timedelta64(i, "h") for i in range(0, 48, 1)]) .expand_dims(dim={"init_time": [ds["time"].values]}) .drop_vars( names=[ @@ -294,8 +296,7 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: if v not in ["init_time", "step", "latitude", "longitude"] ], ) - .pipe(CedaMetOfficeGlobalModelRepository._rename_vars) - .to_dataarray(name=CedaMetOfficeGlobalModelRepository.model().name) + .to_dataarray(name=CEDAFTPModelRepository.model().name) .transpose("init_time", "step", "variable", "latitude", "longitude") # Remove the last value of the longitude dimension as it overlaps with the next file # Reverse the latitude dimension to be in descending order @@ -311,47 +312,22 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: @staticmethod - def _rename_vars(ds: xr.Dataset) -> xr.Dataset: - """Rename variables to match the expected names. - - To find the names as they exist in the raw files, the following - function was used: - - >>> import xarray as xr - >>> import urllib.request - >>> import datetime as dt - >>> - >>> def download_single_file(parameter: str) -> xr.Dataset: - >>> it = dt.datetime(2021, 1, 1, 0, tzinfo=dt.UTC) - >>> base_url = "ftp://:@ftp.ceda.ac.uk/badc/ukmo-nwp/data/global-grib" - >>> url = f"{base_url}/{it:%Y/%m/%d}/" + \ - >>> f"{it:%Y%m%d%H}_WSGlobal17km_{parameter}_AreaA_000144.grib" - >>> response = urllib.request.urlopen(url) - >>> with open("/tmp/mo-global/test.grib", "wb") as f: - >>> for chunk in iter(lambda: response.read(16 * 1024), b""): - >>> f.write(chunk) - >>> f.flush() - >>> - >>> ds = xr.open_dataset("/tmp/mo-global/test.grib", engine="cfgrib") - >>> return ds + def _rename_or_drop_vars(ds: xr.Dataset, allowed_parameters: list[entities.Parameter]) \ + -> xr.Dataset: + """Rename variables to match the expected names, dropping invalid ones. Args: ds: The xarray dataset to rename. + allowed_parameters: The list of parameters allowed in the resultant dataset. 
""" - rename_map: dict[str, str] = { - "t": entities.Parameter.TEMPERATURE_SL.value, - "r": entities.Parameter.RELATIVE_HUMIDITY_SL.value, - "sf": entities.Parameter.SNOW_DEPTH_GL.value, - "prate": entities.Parameter.TOTAL_PRECIPITATION_RATE_GL.value, - "swavr": entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.value, - "u": entities.Parameter.WIND_U_COMPONENT_10m.value, - "v": entities.Parameter.WIND_V_COMPONENT_10m.value, - "vis": entities.Parameter.VISIBILITY_SL.value, - "hcc": entities.Parameter.CLOUD_COVER_HIGH.value, - "lcc": entities.Parameter.CLOUD_COVER_LOW.value, - "mcc": entities.Parameter.CLOUD_COVER_MEDIUM.value, - } - for old, new in rename_map.items(): - if old in ds.data_vars: - ds = ds.rename_vars({old: new}) + for var in ds.data_vars: + param_result = entities.Parameter.try_from_alternate(str(var)) + match param_result: + case Success(p): + if p in allowed_parameters: + ds = ds.rename_vars({var: p.value}) + continue + log.warning("Dropping invalid parameter '%s' from dataset", var) + ds = ds.drop_vars(str(var)) return ds + diff --git a/src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py b/src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py new file mode 100644 index 00000000..c0e63df5 --- /dev/null +++ b/src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py @@ -0,0 +1,302 @@ +"""Repository implementation for NOAA GFS data stored in S3. + +This module contains the implementation of the model repository for the +NOAA GFS data stored in an S3 bucket. +""" + +import datetime as dt +import logging +import os +import pathlib +import re +from collections.abc import Callable, Iterator +from typing import override + +import cfgrib +import s3fs +import xarray as xr +from joblib import delayed +from returns.result import Failure, ResultE, Success + +from nwp_consumer.internal import entities, ports + +log = logging.getLogger("nwp-consumer") + + +class NOAAS3ModelRepository(ports.ModelRepository): + """Model repository implementation for GFS data stored in S3.""" + + @staticmethod + @override + def repository() -> entities.ModelRepositoryMetadata: + return entities.ModelRepositoryMetadata( + name="NOAA-GFS-S3", + is_archive=False, + is_order_based=False, + running_hours=[0, 6, 12, 18], + delay_minutes=(60 * 24 * 7), # 1 week + max_connections=100, + required_env=[], + optional_env={}, + postprocess_options=entities.PostProcessOptions(), + ) + + @staticmethod + @override + def model() -> entities.ModelMetadata: + return entities.ModelMetadata( + name="NCEP-GFS", + resolution="1 degree", + expected_coordinates=entities.NWPDimensionCoordinateMap( + init_time=[], + step=list(range(0, 49, 3)), + variable=sorted( + [ + entities.Parameter.TEMPERATURE_SL, + entities.Parameter.CLOUD_COVER_TOTAL, + entities.Parameter.CLOUD_COVER_HIGH, + entities.Parameter.CLOUD_COVER_MEDIUM, + entities.Parameter.CLOUD_COVER_LOW, + entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL, + entities.Parameter.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL, + entities.Parameter.TOTAL_PRECIPITATION_RATE_GL, + entities.Parameter.SNOW_DEPTH_GL, + entities.Parameter.RELATIVE_HUMIDITY_SL, + entities.Parameter.VISIBILITY_SL, + entities.Parameter.WIND_U_COMPONENT_10m, + entities.Parameter.WIND_V_COMPONENT_10m, + entities.Parameter.WIND_U_COMPONENT_100m, + entities.Parameter.WIND_V_COMPONENT_100m, + ], + ), + latitude=[float(lat) for lat in range(90, -90 - 1, -1)], + longitude=[float(lon) for lon in range(-180, 180 + 1, 1)], + ), + ) + + @override + def fetch_init_data( + 
self, it: dt.datetime, + ) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]: + # List relevant files in the S3 bucket + bucket_path: str = f"noaa-gfs-bdp-pds/gfs.{it:%Y%m%d}/{it:%H}/atmos" + try: + fs = s3fs.S3FileSystem(anon=True) + urls: list[str] = [ + f"s3://{f}" + for f in fs.ls(bucket_path) + if self._wanted_file( + filename=f.split("/")[-1], + it=it, + max_step=max(self.model().expected_coordinates.step), + ) + ] + except Exception as e: + yield delayed(Failure)( + ValueError( + f"Failed to list files in bucket path '{bucket_path}'. " + "Ensure the path exists and the bucket does not require auth. " + f"Encountered error: '{e}'", + ), + ) + return + + if len(urls) == 0: + yield delayed(Failure)( + ValueError( + f"No files found for init time '{it:%Y-%m-%d %H:%M}' " + f"in bucket path '{bucket_path}'. Ensure files exist at the given path " + "with the expected filename pattern.", + ), + ) + + for url in urls: + yield delayed(self._download_and_convert)(url=url) + + @classmethod + @override + def authenticate(cls) -> ResultE["NOAAS3ModelRepository"]: + return Success(cls()) + + def _download_and_convert(self, url: str) -> ResultE[list[xr.DataArray]]: + """Download and convert a file from S3. + + Args: + url: The URL to the S3 object. + """ + return self._download(url).bind(self._convert) + + def _download(self, url: str) -> ResultE[pathlib.Path]: + """Download a GFS file from S3. + + Args: + url: The URL to the S3 object. + """ + local_path: pathlib.Path = ( + pathlib.Path( + os.getenv( + "RAWDIR", + f"~/.local/cache/nwp/{self.repository().name}/{self.model().name}/raw", + ), + ) / url.split("/")[-1] + ).with_suffix(".grib").expanduser() + + # Only download the file if not already present + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + log.debug("Requesting file from S3 at: '%s'", url) + + fs = s3fs.S3FileSystem(anon=True) + try: + if not fs.exists(url): + raise FileNotFoundError(f"File not found at '{url}'") + + with local_path.open("wb") as lf, fs.open(url, "rb") as rf: + for chunk in iter(lambda: rf.read(12 * 1024), b""): + lf.write(chunk) + lf.flush() + + except Exception as e: + return Failure(OSError( + f"Failed to download file from S3 at '{url}'. Encountered error: {e}", + )) + + if local_path.stat().st_size != fs.info(url)["size"]: + return Failure(ValueError( + f"Failed to download file from S3 at '{url}'. " + "File size mismatch. File may be corrupted.", + )) + + # Also download the associated index file + # * This isn't critical, but speeds up reading the file in when converting + # TODO: Re-incorporate this when https://github.com/ecmwf/cfgrib/issues/350 + # TODO: is resolved. Currently downloaded index files are ignored due to + # TODO: path differences once downloaded. + index_url: str = url + ".idx" + index_path: pathlib.Path = local_path.with_suffix(".grib.idx") + try: + with index_path.open("wb") as lf, fs.open(index_url, "rb") as rf: + for chunk in iter(lambda: rf.read(12 * 1024), b""): + lf.write(chunk) + lf.flush() + except Exception as e: + log.warning( + f"Failed to download index file from S3 at '{url}'. " + "This will require manual indexing when converting the file. " + f"Encountered error: {e}", + ) + + return Success(local_path) + + def _convert(self, path: pathlib.Path) -> ResultE[list[xr.DataArray]]: + """Convert a GFS file to an xarray DataArray collection. + + Args: + path: The path to the local grib file.
+ """ + try: + # Use some options when opening the datasets: + # * 'squeeze' reduces length-1- dimensions to scalar coordinates, + # thus single-level variables should not have any extra dimensions + # * 'filter_by_keys' reduces the number of variables loaded to only those + # in the expected list + dss: list[xr.Dataset] = cfgrib.open_datasets( + path.as_posix(), + backend_kwargs={ + "squeeze": True, + "filter_by_keys": { + "shortName": [ + x for v in self.model().expected_coordinates.variable + for x in v.metadata().alternate_shortnames + ], + }, + }, + ) + except Exception as e: + return Failure(ValueError( + f"Error opening '{path}' as list of xarray Datasets: {e}", + )) + + if len(dss) == 0: + return Failure(ValueError( + f"No datasets found in '{path}'. File may be corrupted. " + "A redownload of the file may be required.", + )) + + processed_das: list[xr.DataArray] = [] + for i, ds in enumerate(dss): + try: + ds = NOAAS3ModelRepository._rename_or_drop_vars( + ds=ds, + allowed_parameters=self.model().expected_coordinates.variable, + ) + # Ignore datasets with no variables of interest + if len(ds.data_vars) == 0: + continue + # Ignore datasets with multi-level variables + # * This would not work without the "squeeze" option in the open_datasets call, + # which reduces single-length dimensions to scalar coordinates + if any(x not in ["latitude", "longitude" ,"time"] for x in ds.dims): + continue + da: xr.DataArray = ( + ds + .rename(name_dict={"time": "init_time"}) + .expand_dims(dim="init_time") + .expand_dims(dim="step") + .to_dataarray(name=NOAAS3ModelRepository.model().name) + ) + da = ( + da.drop_vars( + names=[ + c for c in da.coords + if c not in ["init_time", "step", "variable", "latitude", "longitude"] + ], + errors="raise", + ) + .transpose("init_time", "step", "variable", "latitude", "longitude") + .assign_coords(coords={"longitude": (da.coords["longitude"] + 180) % 360 - 180}) + .sortby(variables=["step", "variable", "longitude"]) + .sortby(variables="latitude", ascending=False) + ) + except Exception as e: + return Failure(ValueError( + f"Error processing dataset {i} from '{path}' to DataArray: {e}", + )) + processed_das.append(da) + + return Success(processed_das) + + @staticmethod + def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool: + """Determine if a file is wanted based on the init time and max step. + + See module docstring for file naming convention. + """ + pattern: str = r"^gfs\.t(\d{2})z\.pgrb2\.1p00\.f(\d{3})$" + match: re.Match[str] | None = re.search(pattern=pattern, string=filename) + if match is None: + return False + if int(match.group(1)) != it.hour: + return False + return not int(match.group(2)) > max_step + + @staticmethod + def _rename_or_drop_vars(ds: xr.Dataset, allowed_parameters: list[entities.Parameter]) \ + -> xr.Dataset: + """Rename variables to match the expected names, dropping invalid ones. + + Args: + ds: The xarray dataset to rename. + allowed_parameters: The list of parameters allowed in the resultant dataset. 
+ """ + for var in ds.data_vars: + param_result = entities.Parameter.try_from_alternate(str(var)) + match param_result: + case Success(p): + if p in allowed_parameters: + ds = ds.rename_vars({var: p.value}) + continue + log.debug("Dropping invalid parameter '%s' from dataset", var) + ds = ds.drop_vars(str(var)) + return ds + diff --git a/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py b/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py index 41280637..65799504 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py @@ -8,11 +8,11 @@ from nwp_consumer.internal import entities -from .metoffice_global import CedaMetOfficeGlobalModelRepository +from .metoffice_global import CEDAFTPModelRepository -class TestCedaMetOfficeGlobalModelRepository(unittest.TestCase): - """Test the business methods of the CedaMetOfficeGlobalModelRepository class.""" +class TestCEDAFTPModelRepository(unittest.TestCase): + """Test the business methods of the CEDAFTPModelRepository class.""" @unittest.skipIf( condition="CI" in os.environ, @@ -21,7 +21,7 @@ class TestCedaMetOfficeGlobalModelRepository(unittest.TestCase): def test__download_and_convert(self) -> None: """Test the _download_and_convert method.""" - auth_result = CedaMetOfficeGlobalModelRepository.authenticate() + auth_result = CEDAFTPModelRepository.authenticate() self.assertTrue(is_successful(auth_result), msg=f"Error: {auth_result.failure}") c = auth_result.unwrap() diff --git a/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py b/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py new file mode 100644 index 00000000..e6a23f58 --- /dev/null +++ b/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py @@ -0,0 +1,111 @@ +import dataclasses +import datetime as dt +import os +import unittest +from typing import TYPE_CHECKING + +import s3fs +from returns.pipeline import is_successful + +from ...entities import NWPDimensionCoordinateMap +from .noaa_gfs import NOAAS3ModelRepository + +if TYPE_CHECKING: + import xarray as xr + + from nwp_consumer.internal import entities + + +class TestECMWFRealTimeS3ModelRepository(unittest.TestCase): + """Test the business methods of the ECMWFRealTimeS3ModelRepository class.""" + + @unittest.skipIf( + condition="CI" in os.environ, + reason="Skipping integration test that requires S3 access.", + ) # TODO: Move into integration tests, or remove + def test__download_and_convert(self) -> None: + """Test the _download_and_convert method.""" + + c: NOAAS3ModelRepository = NOAAS3ModelRepository.authenticate().unwrap() + + test_it: dt.datetime = dt.datetime(2024, 10, 24, 12, tzinfo=dt.UTC) + test_coordinates: entities.NWPDimensionCoordinateMap = dataclasses.replace( + c.model().expected_coordinates, + init_time=[test_it], + ) + + fs = s3fs.S3FileSystem(anon=True) + bucket_path: str = f"noaa-gfs-bdp-pds/gfs.{test_it:%Y%m%d}/{test_it:%H}/atmos" + urls: list[str] = [ + f"s3://{f}" + for f in fs.ls(bucket_path) + if c._wanted_file( + filename=f.split("/")[-1], + it=test_it, + max_step=max(c.model().expected_coordinates.step), + ) + ] + + for url in urls: + with self.subTest(url=url): + result = c._download_and_convert(url) + + self.assertTrue(is_successful(result), msg=f"Error: {result}") + + da: xr.DataArray = result.unwrap()[0] + determine_region_result = 
NWPDimensionCoordinateMap.from_xarray(da).bind( + test_coordinates.determine_region, + ) + self.assertTrue( + is_successful(determine_region_result), + msg=f"Error: {determine_region_result}", + ) + + def test__wanted_file(self) -> None: + """Test the _wanted_file method.""" + + @dataclasses.dataclass + class TestCase: + name: str + filename: str + expected: bool + + test_it: dt.datetime = dt.datetime(2024, 10, 25, 0, tzinfo=dt.UTC) + + tests: list[TestCase] = [ + TestCase( + name="valid_filename", + filename=f"gfs.t{test_it:%H}z.pgrb2.1p00.f000", + expected=True, + ), + TestCase( + name="invalid_init_time", + filename="gfs.t02z.pgrb2.1p00.f000", + expected=False, + ), + TestCase( + name="invalid_prefix", + filename=f"gfs.t{test_it:%H}z.pgrb2.0p20.f006", + expected=False, + ), + TestCase( + name="unexpected_extension", + filename=f"gfs.t{test_it:%H}z.pgrb2.1p00.f030.nc", + expected=False, + ), + TestCase( + name="step_too_large", + filename=f"gfs.t{test_it:%H}z.pgrb2.1p00.f049", + expected=False, + ), + ] + + for t in tests: + with self.subTest(name=t.name): + result = NOAAS3ModelRepository._wanted_file( + filename=t.filename, + it=test_it, + max_step=max(NOAAS3ModelRepository.model().expected_coordinates.step), + ) + self.assertEqual(result, t.expected) + diff --git a/src/nwp_consumer/internal/services/archiver_service.py b/src/nwp_consumer/internal/services/archiver_service.py index 160e2517..7b8be59d 100644 --- a/src/nwp_consumer/internal/services/archiver_service.py +++ b/src/nwp_consumer/internal/services/archiver_service.py @@ -2,10 +2,11 @@ import dataclasses import logging +import os import pathlib from typing import TYPE_CHECKING, override -from joblib import Parallel +from joblib import Parallel, cpu_count from returns.result import Failure, ResultE, Success from nwp_consumer.internal import entities, ports @@ -45,12 +46,12 @@ def archive(self, year: int, month: int) -> ResultE[pathlib.Path]: # Create a store for the archive init_store_result: ResultE[entities.TensorStore] = \ entities.TensorStore.initialize_empty_store( - name=self.mr.repository().name, + model=self.mr.model().name, + repository=self.mr.repository().name, coords=dataclasses.replace( self.mr.model().expected_coordinates, init_time=init_times, ), - overwrite_existing=False, ) match init_store_result: @@ -85,8 +86,12 @@ def archive(self, year: int, month: int) -> ResultE[pathlib.Path]: amr = amr_result.unwrap() # Create a generator to fetch and process raw data + n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections) + if os.getenv("CONCURRENCY", "True").capitalize() == "False": + n_jobs = 1 + log.debug(f"Downloading using {n_jobs} concurrent thread(s)") da_result_generator = Parallel( - n_jobs=self.mr.repository().max_connections - 1, + n_jobs=n_jobs, prefer="threads", return_as="generator_unordered", )(amr.fetch_init_data(it=it)) @@ -115,8 +120,8 @@ def archive(self, year: int, month: int) -> ResultE[pathlib.Path]: monitor.join() notify_result = self.nr().notify( message=entities.StoreCreatedNotification( - filename=store.path.name, - size_mb=store.size_mb, + filename=pathlib.Path(store.path).name, + size_mb=store.size_kb // 1024, performance=entities.PerformanceMetadata( duration_seconds=monitor.get_runtime(), memory_mb=max(monitor.memory_buffer) / 1e6, @@ -135,3 +140,4 @@ def archive(self, year: int, month: int) -> ResultE[pathlib.Path]: return Failure( TypeError(f"Unexpected result type: {type(init_store_result)}"), ) + diff --git a/src/nwp_consumer/internal/services/consumer_service.py 
b/src/nwp_consumer/internal/services/consumer_service.py index dd470848..29aaca4f 100644 --- a/src/nwp_consumer/internal/services/consumer_service.py +++ b/src/nwp_consumer/internal/services/consumer_service.py @@ -3,10 +3,11 @@ import dataclasses import datetime as dt import logging +import os import pathlib from typing import override -from joblib import Parallel +from joblib import Parallel, cpu_count from returns.result import Failure, ResultE, Success from nwp_consumer.internal import entities, ports @@ -50,7 +51,8 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: # Create a store for the init time init_store_result: ResultE[entities.TensorStore] = \ entities.TensorStore.initialize_empty_store( - name=self.mr.model().name, + model=self.mr.model().name, + repository=self.mr.repository().name, coords=dataclasses.replace(self.mr.model().expected_coordinates, init_time=[it]), ) @@ -71,8 +73,13 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: )) amr = amr_result.unwrap() + n_jobs: int = max(cpu_count() - 1, self.mr.repository().max_connections) + if os.getenv("CONCURRENCY", "True").capitalize() == "False": + n_jobs = 1 + log.debug(f"Downloading using {n_jobs} concurrent thread(s)") fetch_result_generator = Parallel( - n_jobs=1, # TODO - fix segfault when using multiple threads + # TODO - fix segfault when using multiple threads + n_jobs=n_jobs, prefer="threads", return_as="generator_unordered", )(amr.fetch_init_data(it=it)) @@ -117,8 +124,8 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: monitor.join() notify_result = self.nr().notify( message=entities.StoreCreatedNotification( - filename=store.path.name, - size_mb=store.size_mb, + filename=pathlib.Path(store.path).name, + size_mb=store.size_kb // 1024, performance=entities.PerformanceMetadata( duration_seconds=monitor.get_runtime(), memory_mb=max(monitor.memory_buffer) / 1e6, diff --git a/src/test_integration/test_integration.py b/src/test_integration/test_integration.py index fe1854a2..d2d0cae5 100644 --- a/src/test_integration/test_integration.py +++ b/src/test_integration/test_integration.py @@ -11,11 +11,11 @@ class TestIntegration(unittest.TestCase): def test_ceda_metoffice_global_model(self) -> None: c = handlers.CLIHandler( consumer_usecase=services.ConsumerService( - model_repository=repositories.CedaMetOfficeGlobalModelRepository, + model_repository=repositories.CEDAFTPModelRepository, notification_repository=repositories.StdoutNotificationRepository, ), archiver_usecase=services.ArchiverService( - model_repository=repositories.CedaMetOfficeGlobalModelRepository, + model_repository=repositories.CEDAFTPModelRepository, notification_repository=repositories.StdoutNotificationRepository, ), )
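The sketch below shows one way to exercise the new NOAAS3ModelRepository outside of the service layer, mirroring how ConsumerService drives it: authenticate, generate the delayed fetch jobs for a single init time, and run them through joblib. It is an illustrative sketch only, not part of the patch, and assumes anonymous read access to the public noaa-gfs-bdp-pds bucket, an init time drawn from the repository's running_hours, and an arbitrary n_jobs value.

import datetime as dt

from joblib import Parallel
from returns.pipeline import is_successful

from nwp_consumer.internal.repositories.model_repositories.noaa_gfs import (
    NOAAS3ModelRepository,
)

# Authentication is a no-op for this repository (anonymous S3), but keeps the
# interface consistent with the other model repositories.
repo = NOAAS3ModelRepository.authenticate().unwrap()
it = dt.datetime(2024, 10, 24, 12, tzinfo=dt.UTC)

# fetch_init_data yields joblib-delayed callables, one per wanted grib file.
results = Parallel(n_jobs=4, prefer="threads", return_as="generator_unordered")(
    repo.fetch_init_data(it=it),
)
for result in results:
    if is_successful(result):
        for da in result.unwrap():
            print(da.sizes)
    else:
        print(f"Fetch failed: {result}")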
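A worked example of the longitude handling in NOAAS3ModelRepository._convert: GFS grib files use longitudes on the 0..359 degree convention, which the conversion remaps to -180..179 with (lon + 180) % 360 - 180 and then re-sorts, since the remapped axis is no longer monotonic. The values below are an illustrative subset only.

import numpy as np

# GFS-style longitudes (degrees east, 0..359).
lons = np.array([0.0, 90.0, 179.0, 180.0, 270.0, 359.0])

# Same arithmetic as the assign_coords call in _convert.
remapped = (lons + 180) % 360 - 180
print(remapped)           # [   0.   90.  179. -180.  -90.   -1.]

# The axis is no longer monotonic, hence the subsequent sortby on "longitude".
print(np.sort(remapped))  # [-180.  -90.   -1.    0.   90.  179.]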
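Both ConsumerService and ArchiverService now derive their download worker count in the same way. The helper below is a hypothetical restatement of that selection (the function name does not exist in the patch), included only to make the effect of the CONCURRENCY environment variable explicit.

import os

from joblib import cpu_count


def select_n_jobs(max_connections: int) -> int:
    """Mirror the n_jobs selection added to the consumer and archiver services.

    Takes the larger of (cpu_count() - 1) and the repository's declared
    max_connections, unless CONCURRENCY is set to "False", in which case
    downloads run on a single thread.
    """
    if os.getenv("CONCURRENCY", "True").capitalize() == "False":
        return 1
    return max(cpu_count() - 1, max_connections)

For example, NOAAS3ModelRepository declares max_connections=100, so on an 8-core machine this selects 100 download threads unless CONCURRENCY=False is set in the environment.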