From 0f6c84b899bc46f717fb61eb7fb9ac49ff50fc7d Mon Sep 17 00:00:00 2001 From: devsjc <47188100+devsjc@users.noreply.github.com> Date: Thu, 31 Oct 2024 15:17:31 +0000 Subject: [PATCH] feat(models): Add GFS Model repository --- pyproject.toml | 2 +- src/nwp_consumer/{cmd => cmd.tmp}/__init__.py | 0 src/nwp_consumer/{cmd => cmd.tmp}/main.py | 7 +- .../internal/entities/coordinates.py | 2 +- .../internal/entities/parameters.py | 32 ++ .../internal/entities/test_parameters.py | 10 + .../internal/repositories/__init__.py | 2 + .../model_repositories/__init__.py | 2 + .../model_repositories/ecmwf_realtime.py | 66 ++-- .../model_repositories/metoffice_global.py | 60 ++-- .../model_repositories/noaa_gfs.py | 286 ++++++++++++++++++ .../model_repositories/test_noaa_gfs.py | 109 +++++++ 12 files changed, 497 insertions(+), 81 deletions(-) rename src/nwp_consumer/{cmd => cmd.tmp}/__init__.py (100%) rename src/nwp_consumer/{cmd => cmd.tmp}/main.py (93%) create mode 100644 src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py create mode 100644 src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py diff --git a/pyproject.toml b/pyproject.toml index d2110286..490d3143 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ authors = [ classifiers = ["Programming Language :: Python :: 3"] dependencies = [ "dask == 2024.8.1", - "eccodes == 2.38.1", + "eccodes == 2.38.3", "ecmwf-api-client == 1.6.3", "cfgrib == 0.9.14.0", "dagster-pipes == 1.8.5", diff --git a/src/nwp_consumer/cmd/__init__.py b/src/nwp_consumer/cmd.tmp/__init__.py similarity index 100% rename from src/nwp_consumer/cmd/__init__.py rename to src/nwp_consumer/cmd.tmp/__init__.py diff --git a/src/nwp_consumer/cmd/main.py b/src/nwp_consumer/cmd.tmp/main.py similarity index 93% rename from src/nwp_consumer/cmd/main.py rename to src/nwp_consumer/cmd.tmp/main.py index 6d001863..2c63f6ee 100644 --- a/src/nwp_consumer/cmd/main.py +++ b/src/nwp_consumer/cmd.tmp/main.py @@ -18,12 +18,11 @@ def parse_env() -> Adaptors: """Parse from the environment.""" model_repository_adaptor: type[ports.ModelRepository] match os.getenv("MODEL_REPOSITORY"): - case None: - log.error("MODEL_REPOSITORY is not set in environment.") - sys.exit(1) + case None | "gfs": + model_repository_adaptor = repositories.NOAAGFSS3ModelRepository case "ceda": model_repository_adaptor = repositories.CedaMetOfficeGlobalModelRepository - case "ecmwf-realtime-s3": + case "ecmwf-realtime": model_repository_adaptor = repositories.ECMWFRealTimeS3ModelRepository case _ as model: log.error(f"Unknown model: {model}") diff --git a/src/nwp_consumer/internal/entities/coordinates.py b/src/nwp_consumer/internal/entities/coordinates.py index 565fe707..45a6d241 100644 --- a/src/nwp_consumer/internal/entities/coordinates.py +++ b/src/nwp_consumer/internal/entities/coordinates.py @@ -207,7 +207,7 @@ def to_pandas(self) -> dict[str, pd.Index]: # type: ignore This is useful for interoperability with xarray, which prefers to define DataArray coordinates using a dict pandas Index objects. - For the most part, the conversion consists of a straighforward cast + For the most part, the conversion consists of a straightforward cast to a pandas Index object. 
However, there are some caveats involving the time-centric dimensions: diff --git a/src/nwp_consumer/internal/entities/parameters.py b/src/nwp_consumer/internal/entities/parameters.py index a15fc0d2..1735755d 100644 --- a/src/nwp_consumer/internal/entities/parameters.py +++ b/src/nwp_consumer/internal/entities/parameters.py @@ -23,6 +23,8 @@ import dataclasses from enum import StrEnum, auto +from returns.result import Failure, ResultE, Success + @dataclasses.dataclass(slots=True) class ParameterLimits: @@ -77,6 +79,9 @@ class ParameterData: Used in sanity and validity checking the database values. """ + alternate_shortnames: list[str] = dataclasses.field(default_factory=list) + """Alternate names for the parameter found in the wild.""" + def __str__(self) -> str: """String representation of the parameter.""" return self.name @@ -121,6 +126,7 @@ def metadata(self) -> ParameterData: description="Temperature at screen level", units="C", limits=ParameterLimits(upper=60, lower=-90), + alternate_shortnames=["t", "t2m"], ) case self.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.name: return ParameterData( @@ -130,6 +136,7 @@ def metadata(self) -> ParameterData: "incident on the surface expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=1500, lower=0), + alternate_shortnames=["swavr", "ssrd", "dswrf"], ) case self.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.name: return ParameterData( @@ -139,6 +146,7 @@ def metadata(self) -> ParameterData: "incident on the surface expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=500, lower=0), + alternate_shortnames=["strd", "dlwrf"] ) case self.RELATIVE_HUMIDITY_SL.name: return ParameterData( @@ -148,6 +156,7 @@ def metadata(self) -> ParameterData: "to the equilibrium vapour pressure of water", units="%", limits=ParameterLimits(upper=100, lower=0), + alternate_shortnames=["r"], ) case self.VISIBILITY_SL.name: return ParameterData( @@ -157,6 +166,7 @@ def metadata(self) -> ParameterData: "horizontally in daylight conditions.", units="m", limits=ParameterLimits(upper=4500, lower=0), + alternate_shortnames=["vis"], ) case self.WIND_U_COMPONENT_10m.name: return ParameterData( @@ -166,6 +176,7 @@ def metadata(self) -> ParameterData: "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["u10"], ) case self.WIND_V_COMPONENT_10m.name: return ParameterData( @@ -176,6 +187,7 @@ def metadata(self) -> ParameterData: units="m/s", # Non-tornadic winds are usually < 100m/s limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["v10"], ) case self.WIND_U_COMPONENT_100m.name: return ParameterData( @@ -185,6 +197,7 @@ def metadata(self) -> ParameterData: "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["u100"], ) case self.WIND_V_COMPONENT_100m.name: return ParameterData( @@ -194,6 +207,7 @@ def metadata(self) -> ParameterData: "the wind in the northward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), + alternate_shortnames=["v100"], ) case self.WIND_U_COMPONENT_200m.name: return ParameterData( @@ -203,6 +217,7 @@ def metadata(self) -> ParameterData: "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=150, lower=-150), + alternate_shortnames=["u200"], ) case self.WIND_V_COMPONENT_200m.name: return ParameterData( @@ -212,6 +227,7 @@ def metadata(self) -> ParameterData: "the wind in the northward direction.", units="m/s", 
                 limits=ParameterLimits(upper=150, lower=-150),
+                alternate_shortnames=["v200"],
             )
         case self.SNOW_DEPTH_GL.name:
             return ParameterData(
@@ -219,6 +235,7 @@ def metadata(self) -> ParameterData:
                 description="Depth of snow on the ground.",
                 units="m",
                 limits=ParameterLimits(upper=12, lower=0),
+                alternate_shortnames=["sd", "sdwe"],
             )
         case self.CLOUD_COVER_HIGH.name:
             return ParameterData(
@@ -229,6 +246,7 @@ def metadata(self) -> ParameterData:
                 "to the square's total area.",
                 units="UI",
                 limits=ParameterLimits(upper=1, lower=0),
+                alternate_shortnames=["hcc"],
             )
         case self.CLOUD_COVER_MEDIUM.name:
             return ParameterData(
@@ -239,6 +257,7 @@ def metadata(self) -> ParameterData:
                 "to the square's total area.",
                 units="UI",
                 limits=ParameterLimits(upper=1, lower=0),
+                alternate_shortnames=["mcc"],
             )
         case self.CLOUD_COVER_LOW.name:
             return ParameterData(
@@ -249,6 +268,7 @@ def metadata(self) -> ParameterData:
                 "to the square's total area.",
                 units="UI",
                 limits=ParameterLimits(upper=1, lower=0),
+                alternate_shortnames=["lcc"],
             )
         case self.CLOUD_COVER_TOTAL.name:
             return ParameterData(
@@ -259,6 +279,7 @@ def metadata(self) -> ParameterData:
                 "to the square's total area.",
                 units="UI",
                 limits=ParameterLimits(upper=1, lower=0),
+                alternate_shortnames=["tcc", "clt"],
             )
         case self.TOTAL_PRECIPITATION_RATE_GL.name:
             return ParameterData(
@@ -268,6 +289,7 @@ def metadata(self) -> ParameterData:
                 "including rain, snow, and hail.",
                 units="kg/m^2/s",
                 limits=ParameterLimits(upper=0.2, lower=0),
+                alternate_shortnames=["prate", "tprate"],
             )
         case self.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.name:
             return ParameterData(
@@ -278,6 +300,7 @@ def metadata(self) -> ParameterData:
                 "expected over the next hour.",
                 units="W/m^2",
                 limits=ParameterLimits(upper=1000, lower=0),
+                alternate_shortnames=["uvb"],
             )
         case self.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.name:
             return ParameterData(
@@ -289,7 +312,17 @@ def metadata(self) -> ParameterData:
                 "expected over the next hour.",
                 units="W/m^2",
                 limits=ParameterLimits(upper=1000, lower=0),
+                alternate_shortnames=["dsrp"],
             )
         case _:
             # Shouldn't happen thanks to the test case in test_parameters.py
             raise ValueError(f"Unknown parameter: {self}")
+
+    @staticmethod
+    def try_from_alternate(name: str) -> ResultE["Parameter"]:
+        """Map an alternate name to a parameter."""
+        for p in Parameter:
+            if name in p.metadata().alternate_shortnames:
+                return Success(p)
+        return Failure(ValueError(f"Unknown shortname: {name}"))
+
diff --git a/src/nwp_consumer/internal/entities/test_parameters.py b/src/nwp_consumer/internal/entities/test_parameters.py
index 1470e28f..c8b864f2 100644
--- a/src/nwp_consumer/internal/entities/test_parameters.py
+++ b/src/nwp_consumer/internal/entities/test_parameters.py
@@ -2,6 +2,7 @@
 
 from hypothesis import given
 from hypothesis import strategies as st
+from returns.pipeline import is_successful
 
 from .parameters import Parameter
 
@@ -15,6 +16,15 @@ def test_metadata(self, p: Parameter) -> None:
         metadata = p.metadata()
         self.assertEqual(metadata.name, p.value)
 
+    @given(st.sampled_from([s for p in Parameter for s in p.metadata().alternate_shortnames]))
+    def test_try_from_alternate(self, shortname: str) -> None:
+        """Test the try_from_alternate method."""
+        p = Parameter.try_from_alternate(shortname)
+        self.assertTrue(is_successful(p))
+
+        p = Parameter.try_from_alternate("invalid")
+        self.assertFalse(is_successful(p))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/src/nwp_consumer/internal/repositories/__init__.py b/src/nwp_consumer/internal/repositories/__init__.py
index d1f60926..ae396385 100644
--- a/src/nwp_consumer/internal/repositories/__init__.py +++ b/src/nwp_consumer/internal/repositories/__init__.py @@ -26,6 +26,7 @@ from .model_repositories import ( CedaMetOfficeGlobalModelRepository, ECMWFRealTimeS3ModelRepository, + NOAAGFSS3ModelRepository, ) from .notification_repositories import ( StdoutNotificationRepository, @@ -35,6 +36,7 @@ __all__ = [ "CedaMetOfficeGlobalModelRepository", "ECMWFRealTimeS3ModelRepository", + "NOAAGFSS3ModelRepository", "StdoutNotificationRepository", "DagsterPipesNotificationRepository", ] diff --git a/src/nwp_consumer/internal/repositories/model_repositories/__init__.py b/src/nwp_consumer/internal/repositories/model_repositories/__init__.py index 3580a1cc..65e6ae2a 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/__init__.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/__init__.py @@ -1,8 +1,10 @@ from .metoffice_global import CedaMetOfficeGlobalModelRepository from .ecmwf_realtime import ECMWFRealTimeS3ModelRepository +from .noaa_gfs import NOAAGFSS3ModelRepository __all__ = [ "CedaMetOfficeGlobalModelRepository", "ECMWFRealTimeS3ModelRepository", + "NOAAGFSS3ModelRepository", ] diff --git a/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py b/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py index f726bcc8..c8c36cd6 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py @@ -65,7 +65,7 @@ def repository() -> entities.ModelRepositoryMetadata: name="ECMWF-Realtime-S3", is_archive=False, is_order_based=True, - running_hours=[0, 12], + running_hours=[0, 6, 12, 18], delay_minutes=(60 * 6), # 6 hours max_connections=100, required_env=[ @@ -196,7 +196,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: # Only download the file if not already present if not local_path.exists(): local_path.parent.mkdir(parents=True, exist_ok=True) - log.info("Requesting file from S3 at: '%s'", url) + log.debug("Requesting file from S3 at: '%s'", url) try: if not self._fs.exists(url): @@ -234,13 +234,19 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: f"Error opening '{path}' as list of xarray Datasets: {e}", )) if len(dss) == 0: - return Failure(ValueError(f"No datasets found in '{path}'")) + return Failure(ValueError( + f"No datasets found in '{path}'. File may be corrupted. 
" + "A redownload of the file may be required.", + )) processed_das: list[xr.DataArray] = [] for i, ds in enumerate(dss): try: da: xr.DataArray = ( - ds.pipe(ECMWFRealTimeS3ModelRepository._rename_vars) + ECMWFRealTimeS3ModelRepository._rename_or_drop_vars( + ds=ds, + allowed_parameters=ECMWFRealTimeS3ModelRepository.model().expected_coordinates.variable, + ) .rename(name_dict={"time": "init_time"}) .expand_dims(dim="init_time") .expand_dims(dim="step") @@ -274,36 +280,6 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: return Success(processed_das) - @staticmethod - def _rename_vars(ds: xr.Dataset) -> xr.Dataset: - """Rename variables to match the expected names.""" - rename_map: dict[str, str] = { - "dsrp": entities.Parameter.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.value, - "uvb": entities.Parameter.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.value, - "sd": entities.Parameter.SNOW_DEPTH_GL.value, - "tcc": entities.Parameter.CLOUD_COVER_TOTAL.value, - "clt": entities.Parameter.CLOUD_COVER_TOTAL.value, - "u10": entities.Parameter.WIND_U_COMPONENT_10m.value, - "v10": entities.Parameter.WIND_V_COMPONENT_10m.value, - "t2m": entities.Parameter.TEMPERATURE_SL.value, - "ssrd": entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.value, - "strd": entities.Parameter.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.value, - "lcc": entities.Parameter.CLOUD_COVER_LOW.value, - "mcc": entities.Parameter.CLOUD_COVER_MEDIUM.value, - "hcc": entities.Parameter.CLOUD_COVER_HIGH.value, - "vis": entities.Parameter.VISIBILITY_SL.value, - "u200": entities.Parameter.WIND_U_COMPONENT_200m.value, - "v200": entities.Parameter.WIND_V_COMPONENT_200m.value, - "u100": entities.Parameter.WIND_U_COMPONENT_100m.value, - "v100": entities.Parameter.WIND_V_COMPONENT_100m.value, - "tprate": entities.Parameter.TOTAL_PRECIPITATION_RATE_GL.value, - } - - for old, new in rename_map.items(): - if old in ds.data_vars: - ds = ds.rename({old: new}) - return ds - @staticmethod def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool: """Determine if the file is wanted based on the init time. @@ -329,3 +305,25 @@ def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool: "%Y%m%d%H%M%z", ) return tt < it + dt.timedelta(hours=max_step) + + + @staticmethod + def _rename_or_drop_vars(ds: xr.Dataset, allowed_parameters: list[entities.Parameter]) \ + -> xr.Dataset: + """Rename variables to match the expected names, dropping invalid ones. + + Args: + ds: The xarray dataset to rename. + allowed_parameters: The list of parameters allowed in the resultant dataset. 
+ """ + for var in ds.data_vars: + param_result = entities.Parameter.try_from_alternate(str(var)) + match param_result: + case Success(p): + if p in allowed_parameters: + ds = ds.rename_vars({var: p.value}) + continue + log.warning("Dropping invalid parameter '%s' from dataset", var) + ds = ds.drop_vars(str(var)) + return ds + diff --git a/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py b/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py index 320d6d58..8a72183f 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py @@ -285,7 +285,11 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: ) try: da: xr.DataArray = ( - ds.sel(step=[np.timedelta64(i, "h") for i in range(0, 48, 1)]) + CedaMetOfficeGlobalModelRepository._rename_or_drop_vars( + ds=ds, + allowed_parameters=CedaMetOfficeGlobalModelRepository.model().expected_coordinates.variable, + ) + .sel(step=[np.timedelta64(i, "h") for i in range(0, 48, 1)]) .expand_dims(dim={"init_time": [ds["time"].values]}) .drop_vars( names=[ @@ -294,7 +298,6 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: if v not in ["init_time", "step", "latitude", "longitude"] ], ) - .pipe(CedaMetOfficeGlobalModelRepository._rename_vars) .to_dataarray(name=CedaMetOfficeGlobalModelRepository.model().name) .transpose("init_time", "step", "variable", "latitude", "longitude") # Remove the last value of the longitude dimension as it overlaps with the next file @@ -311,47 +314,22 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: @staticmethod - def _rename_vars(ds: xr.Dataset) -> xr.Dataset: - """Rename variables to match the expected names. - - To find the names as they exist in the raw files, the following - function was used: - - >>> import xarray as xr - >>> import urllib.request - >>> import datetime as dt - >>> - >>> def download_single_file(parameter: str) -> xr.Dataset: - >>> it = dt.datetime(2021, 1, 1, 0, tzinfo=dt.UTC) - >>> base_url = "ftp://:@ftp.ceda.ac.uk/badc/ukmo-nwp/data/global-grib" - >>> url = f"{base_url}/{it:%Y/%m/%d}/" + \ - >>> f"{it:%Y%m%d%H}_WSGlobal17km_{parameter}_AreaA_000144.grib" - >>> response = urllib.request.urlopen(url) - >>> with open("/tmp/mo-global/test.grib", "wb") as f: - >>> for chunk in iter(lambda: response.read(16 * 1024), b""): - >>> f.write(chunk) - >>> f.flush() - >>> - >>> ds = xr.open_dataset("/tmp/mo-global/test.grib", engine="cfgrib") - >>> return ds + def _rename_or_drop_vars(ds: xr.Dataset, allowed_parameters: list[entities.Parameter]) \ + -> xr.Dataset: + """Rename variables to match the expected names, dropping invalid ones. Args: ds: The xarray dataset to rename. + allowed_parameters: The list of parameters allowed in the resultant dataset. 
""" - rename_map: dict[str, str] = { - "t": entities.Parameter.TEMPERATURE_SL.value, - "r": entities.Parameter.RELATIVE_HUMIDITY_SL.value, - "sf": entities.Parameter.SNOW_DEPTH_GL.value, - "prate": entities.Parameter.TOTAL_PRECIPITATION_RATE_GL.value, - "swavr": entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.value, - "u": entities.Parameter.WIND_U_COMPONENT_10m.value, - "v": entities.Parameter.WIND_V_COMPONENT_10m.value, - "vis": entities.Parameter.VISIBILITY_SL.value, - "hcc": entities.Parameter.CLOUD_COVER_HIGH.value, - "lcc": entities.Parameter.CLOUD_COVER_LOW.value, - "mcc": entities.Parameter.CLOUD_COVER_MEDIUM.value, - } - for old, new in rename_map.items(): - if old in ds.data_vars: - ds = ds.rename_vars({old: new}) + for var in ds.data_vars: + param_result = entities.Parameter.try_from_alternate(str(var)) + match param_result: + case Success(p): + if p in allowed_parameters: + ds = ds.rename_vars({var: p.value}) + continue + log.warning("Dropping invalid parameter '%s' from dataset", var) + ds = ds.drop_vars(str(var)) return ds + diff --git a/src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py b/src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py new file mode 100644 index 00000000..08fc70ab --- /dev/null +++ b/src/nwp_consumer/internal/repositories/model_repositories/noaa_gfs.py @@ -0,0 +1,286 @@ +import datetime as dt +import logging +import os +import pathlib +import re +from collections.abc import Callable, Iterator +from readline import backend +from typing import override + +import cfgrib +import s3fs +import xarray as xr +from joblib import delayed +from returns.result import Failure, ResultE, Success + +from nwp_consumer.internal import entities, ports + +log = logging.getLogger("nwp-consumer") + + +class NOAAGFSS3ModelRepository(ports.ModelRepository): + """Model repository implementation for GFS data stored in S3.""" + + @staticmethod + @override + def repository() -> entities.ModelRepositoryMetadata: + return entities.ModelRepositoryMetadata( + name="NOAA-GFS-S3", + is_archive=False, + is_order_based=False, + running_hours=[0, 6, 12, 18], + delay_minutes=(60 * 24 * 7), # 1 week + max_connections=100, + required_env=[], + optional_env={}, + postprocess_options=entities.PostProcessOptions(), + ) + + @staticmethod + @override + def model() -> entities.ModelMetadata: + return entities.ModelMetadata( + name="NCEP-GFS", + resolution="1 degree", + expected_coordinates=entities.NWPDimensionCoordinateMap( + init_time=[], + step=list(range(0, 49, 3)), + variable=sorted( + [ + entities.Parameter.TEMPERATURE_SL, + entities.Parameter.CLOUD_COVER_TOTAL, + entities.Parameter.CLOUD_COVER_HIGH, + entities.Parameter.CLOUD_COVER_MEDIUM, + entities.Parameter.CLOUD_COVER_LOW, + entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL, + entities.Parameter.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL, + entities.Parameter.TOTAL_PRECIPITATION_RATE_GL, + entities.Parameter.SNOW_DEPTH_GL, + entities.Parameter.RELATIVE_HUMIDITY_SL, + entities.Parameter.VISIBILITY_SL, + entities.Parameter.WIND_U_COMPONENT_10m, + entities.Parameter.WIND_V_COMPONENT_10m, + entities.Parameter.WIND_U_COMPONENT_100m, + entities.Parameter.WIND_V_COMPONENT_100m, + ], + ), + latitude=[float(lat) for lat in range(90, -90 - 1, -1)], + longitude=[float(lon) for lon in range(-180, 180 + 1, 1)], + ), + ) + + @override + def fetch_init_data( + self, it: dt.datetime, + ) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]: + # List relevant files in the s3 bucket + bucket_path: str = 
f"noaa-gfs-bdp-pds/gfs.{it:%Y%m%d}/{it:%H}/atmos" + try: + fs = s3fs.S3FileSystem(anon=True) + urls: list[str] = [ + f"s3://{f}" + for f in fs.ls(bucket_path) + if self._wanted_file( + filename=f.split("/")[-1], + it=it, + max_step=max(self.model().expected_coordinates.step), + ) + ] + except Exception as e: + yield delayed(Failure)( + ValueError( + f"Failed to list file in bucket path '{bucket_path}'. " + "Ensure the path exists and the bucket does not require auth. " + f"Encountered error: '{e}'", + ), + ) + return + + if len(urls) == 0: + yield delayed(Failure)( + ValueError( + f"No files found for init time '{it:%Y-%m-%d %H:%M}'. " + "in bucket path '{bucket_path}'. Ensure files exists at the given path " + "with the expected filename pattern. ", + ), + ) + + for url in urls: + yield delayed(self._download_and_convert)(url=url) + + @classmethod + @override + def authenticate(cls) -> ResultE["NOAAGFSS3ModelRepository"]: + return Success(cls()) + + def _download_and_convert(self, url: str) -> ResultE[list[xr.DataArray]]: + """Download and convert a file from S3. + + Args: + url: The URL to the S3 object. + """ + return self._download(url).bind(self._convert) + + def _download(self, url: str) -> ResultE[pathlib.Path]: + """Download an ECMWF realtime file from S3. + + Args: + url: The URL to the S3 object. + """ + local_path: pathlib.Path = ( + pathlib.Path( + os.getenv( + "RAWDIR", + f"~/.local/cache/nwp/{self.repository().name}/{self.model().name}/raw", + ), + ) / url.split("/")[-1] + ).with_suffix(".grib").expanduser() + + # Only download the file if not already present + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + log.debug("Requesting file from S3 at: '%s'", url) + + fs = s3fs.S3FileSystem(anon=True) + try: + if not fs.exists(url): + raise FileNotFoundError(f"File not found at '{url}'") + + with local_path.open("wb") as lf, fs.open(url, "rb") as rf: + for chunk in iter(lambda: rf.read(12 * 1024), b""): + lf.write(chunk) + lf.flush() + + except Exception as e: + return Failure(OSError( + f"Failed to download file from S3 at '{url}'. Encountered error: {e}", + )) + + if local_path.stat().st_size != fs.info(url)["size"]: + return Failure(ValueError( + f"Failed to download file from S3 at '{url}'. " + "File size mismatch. File may be corrupted.", + )) + + # Also download the associated index file + # * This isn't critical, but speeds up reading the file in when converting + # TODO: Re-incorporate this when https://github.com/ecmwf/cfgrib/issues/350 + # TODO: is resolved. Currently downloaded index files are ignored due to + # TODO: path differences once downloaded. + index_url: str = url + ".idx" + index_path: pathlib.Path = local_path.with_suffix(".grib.idx") + try: + with index_path.open("wb") as lf, fs.open(index_url, "rb") as rf: + for chunk in iter(lambda: rf.read(12 * 1024), b""): + lf.write(chunk) + lf.flush() + except Exception as e: + log.warning( + f"Failed to download index file from S3 at '{url}'. " + "This will require a manual indexing when converting the file. " + f"Encountered error: {e}", + ) + + return Success(local_path) + + def _convert(self, path: pathlib.Path) -> ResultE[list[xr.DataArray]]: + """Convert a GFS file to an xarray DataArray collection. + + Args: + path: The path to the local grib file. 
+ """ + try: + # Squeeze reduces length-1- dimensions to scalar coordinates, + # Thus single-level variables should not have any extra dimensions + dss: list[xr.Dataset] = cfgrib.open_datasets( + path.as_posix(), + backend_kwargs={"squeeze": True}, + ) + except Exception as e: + return Failure(ValueError( + f"Error opening '{path}' as list of xarray Datasets: {e}", + )) + + if len(dss) == 0: + return Failure(ValueError( + f"No datasets found in '{path}'. File may be corrupted. " + "A redownload of the file may be required.", + )) + + processed_das: list[xr.DataArray] = [] + for i, ds in enumerate(dss): + try: + ds = NOAAGFSS3ModelRepository._rename_or_drop_vars( + ds=ds, + allowed_parameters=self.model().expected_coordinates.variable, + ) + # Ignore datasets with no variables of interest + if len(ds.data_vars) == 0: + continue + # Ignore datasets with multi-level variables + # * This would not work without the "squeeze" option in the open_datasets call, + # which reduces single-length dimensions to scalar coordinates + if any(x not in ["latitude", "longitude" ,"time"] for x in ds.dims): + continue + da: xr.DataArray = ( + ds + .rename(name_dict={"time": "init_time"}) + .expand_dims(dim="init_time") + .expand_dims(dim="step") + .to_dataarray(name=NOAAGFSS3ModelRepository.model().name) + ) + da = ( + da.drop_vars( + names=[ + c for c in da.coords + if c not in ["init_time", "step", "variable", "latitude", "longitude"] + ], + errors="raise", + ) + .transpose("init_time", "step", "variable", "latitude", "longitude") + .assign_coords(coords={"longitude": (da.coords["longitude"] + 180) % 360 - 180}) + .sortby(variables=["step", "variable", "longitude"]) + .sortby(variables="latitude", ascending=False) + ) + except Exception as e: + return Failure(ValueError( + f"Error processing dataset {i} from '{path}' to DataArray: {e}", + )) + processed_das.append(da) + + return Success(processed_das) + + @staticmethod + def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool: + """Determine if a file is wanted based on the init time and max step. + + See module docstring for file naming convention. + """ + pattern: str = r"^gfs\.t(\d{2})z\.pgrb2\.1p00\.f(\d{3})$" + match: re.Match[str] | None = re.search(pattern=pattern, string=filename) + if match is None: + return False + if int(match.group(1)) != it.hour: + return False + return not int(match.group(2)) > max_step + + @staticmethod + def _rename_or_drop_vars(ds: xr.Dataset, allowed_parameters: list[entities.Parameter]) \ + -> xr.Dataset: + """Rename variables to match the expected names, dropping invalid ones. + + Args: + ds: The xarray dataset to rename. + allowed_parameters: The list of parameters allowed in the resultant dataset. 
+ """ + for var in ds.data_vars: + param_result = entities.Parameter.try_from_alternate(str(var)) + match param_result: + case Success(p): + if p in allowed_parameters: + ds = ds.rename_vars({var: p.value}) + continue + log.debug("Dropping invalid parameter '%s' from dataset", var) + ds = ds.drop_vars(str(var)) + return ds + diff --git a/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py b/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py new file mode 100644 index 00000000..9994e589 --- /dev/null +++ b/src/nwp_consumer/internal/repositories/model_repositories/test_noaa_gfs.py @@ -0,0 +1,109 @@ +import dataclasses +import datetime as dt +import os +import unittest +from typing import TYPE_CHECKING + +import s3fs +from returns.pipeline import is_successful + +from ...entities import NWPDimensionCoordinateMap +from .noaa_gfs import NOAAGFSS3ModelRepository + +if TYPE_CHECKING: + import xarray as xr + + from nwp_consumer.internal import entities + + +class TestECMWFRealTimeS3ModelRepository(unittest.TestCase): + """Test the business methods of the ECMWFRealTimeS3ModelRepository class.""" + + @unittest.skipIf( + condition="CI" in os.environ, + reason="Skipping integration test that requires S3 access.", + ) # TODO: Move into integration tests, or remove + def test__download_and_convert(self) -> None: + """Test the _download_and_convert method.""" + + c: NOAAGFSS3ModelRepository = NOAAGFSS3ModelRepository.authenticate().unwrap() + + test_it: dt.datetime = dt.datetime(2024, 10, 25, 0, tzinfo=dt.UTC) + test_coordinates: entities.NWPDimensionCoordinateMap = dataclasses.replace( + c.model().expected_coordinates, + init_time=[test_it], + ) + + fs = s3fs.S3FileSystem(anon=True) + bucket_path: str = f"noaa-gfs-bdp-pds/gfs.{test_it:%Y%m%d}/{test_it:%H}/atmos" + urls: list[str] = [ + f"s3://{f}" + for f in fs.ls(bucket_path) + if c._wanted_file( + filename=f.split("/")[-1], + it=test_it, + max_step=max(c.model().expected_coordinates.step), + ) + ] + + for url in urls: + with (self.subTest(url=url)): + result = c._download_and_convert(url) + + self.assertTrue(is_successful(result), msg=f"Error: {result}") + + da: xr.DataArray = result.unwrap()[0] + determine_region_result = NWPDimensionCoordinateMap.from_xarray(da).bind( + test_coordinates.determine_region, + ) + self.assertTrue( + is_successful(determine_region_result), + msg=f"Error: {determine_region_result}", + ) + + def test__wanted_file(self) -> None: + """Test the _wanted_file method.""" + + @dataclasses.dataclass + class TestCase: + name: str + filename: str + expected: bool + + test_it: dt.datetime = dt.datetime(2024, 10, 25, 0, tzinfo=dt.UTC) + + tests: list[TestCase] = [ + TestCase( + name="valid_filename", + filename=f"gfs.t{test_it:%H}z.pgrb2.1p00.f000", + expected=True, + ), + TestCase( + name="invalid_init_time", + filename="gfs.t02z.pgrb2.1p00.f000", + expected=False, + ), + TestCase( + name="invalid_prefix", + filename=f"gfs.t{test_it:%H}z.pgrb2.0p20.f006", + expected=False, + ), + TestCase( + name="unexpected_extension", + filename=f"gfs.t{test_it:%H}z.pgrb2.1p00.f030.nc", + expected=False, + ), + TestCase( + name="step_too_large", + filename=f"gfs.t{test_it:%H}z.pgrb2.1p00.f049", + expected=False, + ), + ] + + for t in tests: + with self.subTest(name=t.name): + result = NOAAGFSS3ModelRepository._wanted_file( + filename=t.filename, + it=test_it, + max_step=max(NOAAGFSS3ModelRepository.model().expected_coordinates.step)) + self.assertEqual(result, t.expected)