fix(coordinate): Log warning on unsafe regional writes (#216)
devsjc authored Jan 7, 2025
1 parent 45d0c0c commit 6c3e9ed
Showing 7 changed files with 98 additions and 42 deletions.
50 changes: 29 additions & 21 deletions src/nwp_consumer/internal/entities/coordinates.py
@@ -38,6 +38,7 @@
import dataclasses
import datetime as dt
import json
import logging
from importlib.metadata import PackageNotFoundError, version

import dask.array
@@ -54,12 +55,14 @@
except PackageNotFoundError:
__version__ = "v?"

log = logging.getLogger("nwp-consumer")


@dataclasses.dataclass(slots=True)
class NWPDimensionCoordinateMap:
"""Container for dimensions names and their coordinate index values.
Each field in the container is a dimension label, and the corresponding
Each public field in the container is a dimension label, and the corresponding
value is a list of the coordinate values for each index along the dimension.
All NWP data has an associated init time, step, and variable,
@@ -91,12 +94,6 @@ class NWPDimensionCoordinateMap:
"""
longitude: list[float] | None = None
"""The longitude coordinates of the forecast grid in degrees. """
maximum_number_of_chunks_in_one_dim: int = 8
""" The maximum number of chunks in one dimension.
When saving to S3 we might want this to be small, to reduce the number of files saved.
Will be truncated to 4 decimal places, and ordered as -180 -> 180.
"""

def __post_init__(self) -> None:
"""Rigidly set input value ordering and precision."""
@@ -119,9 +116,7 @@ def dims(self) -> list[str]:
Ignores any dimensions that do not have a corresponding coordinate
index value list.
"""
return [f.name for f in dataclasses.fields(self) if
getattr(self, f.name) is not None
and f.name != "maximum_number_of_chunks_in_one_dim"]
return [f.name for f in dataclasses.fields(self) if getattr(self, f.name) is not None]

@property
def shapemap(self) -> dict[str, int]:
@@ -384,6 +379,8 @@ def determine_region(
# TODO: of which might loop around the edges of the grid. In this case, it would
# TODO: be useful to determine if the run is non-contiguous only in that it wraps
# TODO: around that boundary, and in that case, split it and write it in two goes.
# TODO: 2025-01-06: I think this is a resolved problem now that fetch_init_data
# can return a list of DataArrays.
return Failure(
ValueError(
f"Coordinate values for dimension '{inner_dim_label}' do not correspond "
@@ -398,47 +395,58 @@ def determine_region(

return Success(slices)
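For orientation (not part of the diff): the successful result here is a mapping of dimension name to a slice into the parent grid, which is what `write_to_region` later consumes. A hypothetical value for a regional fetch might look like:

# Hypothetical example of a returned region mapping (values illustrative only):
slices: dict[str, slice] = {
    "init_time": slice(0, 1),
    "step": slice(0, 48),
    "variable": slice(0, 1),
    "latitude": slice(45, 90),
    "longitude": slice(90, 180),
}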

def default_chunking(self) -> dict[str, int]:
def chunking(self, chunk_count_overrides: dict[str, int]) -> dict[str, int]:
"""The expected chunk sizes for each dimension.
A dictionary mapping of dimension labels to the size of a chunk along that
dimension. Note that the number is chunk size, not chunk number, so a chunk
that wants to cover the entire dimension should have a size equal to the
dimension length.
It defaults to a single chunk per init time and step, and 8 chunks
for each entire other dimension. These are purposefully small, to ensure
that when performing parallel writes, chunk boundaries are not crossed.
It defaults to a single chunk per init time, step, and variable coordinate,
and 2 chunks for each entire other dimension, unless overridden by the
`chunk_count_overrides` argument.
The defaults are purposefully small, to ensure that when performing parallel
writes, chunk boundaries are not crossed.
Args:
chunk_count_overrides: A dictionary mapping dimension labels to the
number of chunks to split the dimension into.
"""
out_dict: dict[str, int] = {
"init_time": 1,
"step": 1,
"variable": 1,
} | {
dim: len(getattr(self, dim)) // self.maximum_number_of_chunks_in_one_dim
if len(getattr(self, dim)) > self.maximum_number_of_chunks_in_one_dim else 1
dim: len(getattr(self, dim)) // chunk_count_overrides.get(dim, 2)
if len(getattr(self, dim)) > 8 else 1
for dim in self.dims
if dim not in ["init_time", "step"]
if dim not in ["init_time", "step", "variable"]
}

return out_dict
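Not part of the commit, but a minimal sketch of the arithmetic the new `chunking` method performs, using hypothetical dimension lengths and a longitude override:

# Hypothetical dimension lengths; only the arithmetic mirrors the method above.
dim_lengths = {"init_time": 1, "step": 48, "variable": 2, "latitude": 181, "longitude": 360}
chunk_count_overrides = {"longitude": 8}

chunk_sizes = {"init_time": 1, "step": 1, "variable": 1} | {
    dim: length // chunk_count_overrides.get(dim, 2) if length > 8 else 1
    for dim, length in dim_lengths.items()
    if dim not in ("init_time", "step", "variable")
}
# -> {"init_time": 1, "step": 1, "variable": 1, "latitude": 90, "longitude": 45}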


def as_zeroed_dataarray(self, name: str) -> xr.DataArray:
def as_zeroed_dataarray(self, name: str, chunks: dict[str, int]) -> xr.DataArray:
"""Express the coordinates as an xarray DataArray.
Data is populated with zeros and a default chunking scheme is applied.
The underlying dask array is a zeroed array with the shape of the dataset,
that is chunked according to the given chunking scheme.
Args:
name: The name of the DataArray.
chunks: A mapping of dimension names to the size of the chunks
along the dimensions.
See Also:
- https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
"""
# Create a dask array of zeros with the shape of the dataset
# * The values of this are ignored, only the shape and chunks are used
dummy_values = dask.array.zeros( # type: ignore
shape=list(self.shapemap.values()),
chunks=tuple([self.default_chunking()[k] for k in self.shapemap]),
chunks=tuple([chunks[k] for k in self.shapemap]),
)
attrs: dict[str, str] = {
"produced_by": "".join((
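To make the distributed-writes pattern concrete, here is a standalone sketch under assumed shapes and chunk sizes; it is illustrative and not the library's API:

import dask.array
import xarray as xr

# Assumed shapes and chunk sizes for illustration only.
shape = {"init_time": 1, "step": 48, "variable": 1, "latitude": 181, "longitude": 360}
chunks = {"init_time": 1, "step": 1, "variable": 1, "latitude": 90, "longitude": 45}

# A lazy array of zeros: only its shape and chunking matter, never its values.
zeros = dask.array.zeros(
    shape=list(shape.values()),
    chunks=tuple(chunks[dim] for dim in shape),
)
skeleton = xr.DataArray(zeros, dims=list(shape.keys()), name="example_model")
# Writing with compute=False persists only the metadata, leaving an empty Zarr
# store whose chunks can later be filled by parallel regional writes.
# skeleton.to_dataset().to_zarr("example.zarr", compute=False, mode="w-")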
24 changes: 17 additions & 7 deletions src/nwp_consumer/internal/entities/modelmetadata.py
@@ -55,6 +55,15 @@ class ModelMetadata:
Which prints grid data from the grib file.
"""

chunk_count_overrides: dict[str, int] = dataclasses.field(default_factory=dict)
"""Mapping of dimension names to the desired number of chunks in that dimension.
Overrides the default chunking strategy.
See Also:
- `entities.coordinates.NWPDimensionCoordinateMap.chunking`
"""

def __str__(self) -> str:
"""Return a pretty-printed string representation of the metadata."""
pretty: str = "".join((
@@ -93,13 +102,14 @@ def with_region(self, region: str) -> "ModelMetadata":
log.warning(f"Unknown region '{region}', not cropping expected coordinates.")
return self

def set_maximum_number_of_chunks_in_one_dim(self, maximum_number_of_chunks_in_one_dim: int) \
-> "ModelMetadata":
"""Set the maximum number of chunks in one dimension."""
self.expected_coordinates.maximum_number_of_chunks_in_one_dim \
= maximum_number_of_chunks_in_one_dim
return self

def with_chunk_count_overrides(self, overrides: dict[str, int]) -> "ModelMetadata":
"""Returns metadata for the given model with the given chunk count overrides."""
if not set(overrides.keys()).issubset(self.expected_coordinates.dims):
log.warning(
"Chunk count overrides contain keys not in the expected coordinates. "
"These will not modify the chunking strategy.",
)
return dataclasses.replace(self, chunk_count_overrides=overrides)
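Example usage, mirroring the Met Office global repository change further down in this commit:

model = entities.Models.MO_UM_GLOBAL_17KM.with_chunk_count_overrides(
    {"latitude": 8, "longitude": 8},
)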

class Models:
"""Namespace containing known models."""
39 changes: 36 additions & 3 deletions src/nwp_consumer/internal/entities/tensorstore.py
@@ -16,7 +16,7 @@
import os
import pathlib
import shutil
from collections.abc import MutableMapping
from collections.abc import Mapping, MutableMapping
from typing import Any

import pandas as pd
@@ -77,6 +77,7 @@ def initialize_empty_store(
model: str,
repository: str,
coords: NWPDimensionCoordinateMap,
chunks: dict[str, int],
) -> ResultE["TensorStore"]:
"""Initialize a store for a given init time.
@@ -110,6 +111,7 @@ def initialize_empty_store(
This is also used as the name of the tensor.
repository: The name of the repository providing the tensor data.
coords: The coordinates of the store.
chunks: The chunk sizes for the store.
Returns:
An indicator of a successful store write containing the number of bytes written.
@@ -152,7 +154,8 @@ def initialize_empty_store(
# Write the coordinates to a skeleton Zarr store
# * 'compute=False' enables only saving metadata
# * 'mode="w-"' fails if it finds an existing store
da: xr.DataArray = coords.as_zeroed_dataarray(name=model)

da: xr.DataArray = coords.as_zeroed_dataarray(name=model, chunks=chunks)
encoding = {
model: {"write_empty_chunks": False},
"init_time": {"units": "nanoseconds since 1970-01-01"},
@@ -257,22 +260,52 @@ def write_to_region(
If the region dict is empty or not provided, the region is determined
via the 'determine_region' method.
This function should be thread safe, so a check is performed on the region
to ensure that it can be safely written to in parallel, i.e. that it covers
an integer number of chunks.
Args:
da: The data to write to the store.
region: The region to write to.
Returns:
An indicator of a successful store write containing the number of bytes written.
See Also:
- https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
"""
# Attempt to determine the region if missing
if region is None or region == {}:
region_result = NWPDimensionCoordinateMap.from_xarray(da).bind(
self.coordinate_map.determine_region,
)
if isinstance(region_result, Failure):
return Failure(region_result.failure())
return region_result
region = region_result.unwrap()

# For each dimensional slice defining the region, check the slice represents an
# integer number of chunks along that dimension.
# * This is to ensure that the data can safely be written in parallel.
# * The start and end of each slice should be divisible by the chunk size.
chunksizes: Mapping[Any, tuple[int, ...]] = xr.open_dataarray(
self.path, engine="zarr",
).chunksizes
for dim, slc in region.items():
chunk_size = chunksizes.get(dim, (1,))[0]
# TODO: Determine if this should return a full failure object
if slc.start % chunk_size != 0 or slc.stop % chunk_size != 0:
log.warning(
f"Determined region of raw data to be written for dimension '{dim}'"
f"does not align with chunk boundaries of the store. "
f"Dimension '{dim}' has a chunk size of {chunk_size}, "
"but the data to be written for this dimension starts at chunk "
f"{slc.start / chunk_size:.2f} (index {slc.start}) and ends at chunk "
f"{slc.stop / chunk_size:.2f} (index {slc.stop}). "
"As such, this region cannot be safely written in parallel. "
"Ensure the chunking is granular enough to cover the raw data region.",
)
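To illustrate the check above with hypothetical numbers: with a chunk size of 45 along longitude, a region slice is only safe for parallel writing when both its start and stop fall on multiples of 45.

# Hypothetical helper mirroring the modulo check above; not part of the commit.
def is_chunk_aligned(slc: slice, chunk_size: int) -> bool:
    """Return True when a region slice starts and stops on chunk boundaries."""
    return slc.start % chunk_size == 0 and slc.stop % chunk_size == 0

assert is_chunk_aligned(slice(45, 90), chunk_size=45)       # aligned: no warning
assert not is_chunk_aligned(slice(30, 75), chunk_size=45)   # crosses a boundary: warning logged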


# Perform the regional write
try:
da.to_zarr(store=self.path, region=region, consolidated=True)
1 change: 1 addition & 0 deletions src/nwp_consumer/internal/entities/test_tensorstore.py
@@ -89,6 +89,7 @@ def store(self, year: int) -> Generator[TensorStore, None, None]:
model="test_da",
repository="dummy_repository",
coords=test_coords,
chunks=test_coords.chunking(chunk_count_overrides={}),
)
self.assertIsInstance(init_result, Success, msg=init_result)
store = init_result.unwrap()
@@ -121,7 +121,10 @@ def repository() -> entities.RawRepositoryMetadata:
optional_env={},
postprocess_options=entities.PostProcessOptions(),
available_models={
"default": entities.Models.MO_UM_GLOBAL_17KM,
"default": entities.Models.MO_UM_GLOBAL_17KM.with_chunk_count_overrides({
"latitude": 8,
"longitude": 8,
}),
},
)

@@ -83,10 +83,8 @@ def repository() -> entities.RawRepositoryMetadata:
},
postprocess_options=entities.PostProcessOptions(),
available_models={
"default": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk")
.set_maximum_number_of_chunks_in_one_dim(2),
"hres-ifs-uk": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk")
.set_maximum_number_of_chunks_in_one_dim(2),
"default": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"),
"hres-ifs-uk": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"),
"hres-ifs-india": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("india"),
},
)
@@ -196,17 +194,18 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
).with_suffix(".grib").expanduser()

# Only download the file if not already present
log.info("Checking for local file: '%s'", local_path)
if not local_path.exists() or local_path.stat().st_size == 0:
if local_path.exists() and local_path.stat().st_size > 0:
log.debug("Skipping download for existing file at '%s'.", local_path.as_posix())
else:
local_path.parent.mkdir(parents=True, exist_ok=True)
log.debug("Requesting file from S3 at: '%s'", url)

try:
if not self._fs.exists(url):
raise FileNotFoundError(f"File not found at '{url}'")

log.debug("Writing file from '%s' to '%s'", url, local_path.as_posix())
with local_path.open("wb") as lf, self._fs.open(url, "rb") as rf:
log.info(f"Writing file from {url} to {local_path}")
for chunk in iter(lambda: rf.read(12 * 1024), b""):
lf.write(chunk)
lf.flush()
8 changes: 5 additions & 3 deletions src/nwp_consumer/internal/services/consumer_service.py
@@ -106,13 +106,12 @@ def _parallelize_generator[T](
max_connections: The maximum number of connections to use.
"""
# TODO: Change this based on threads instead of CPU count
# TODO: Enable choosing between threads and processes?
n_jobs: int = max(cpu_count() - 1, max_connections)
prefer = "threads"

concurrency = os.getenv("CONCURRENCY", "True").capitalize() == "False"
if concurrency:
if os.getenv("CONCURRENCY", "True").capitalize() == "False":
n_jobs = 1
prefer = "processes"

log.debug(f"Using {n_jobs} concurrent {prefer}")

@@ -156,6 +155,9 @@ def _create_suitable_store(
model_metadata.expected_coordinates,
init_time=its,
),
chunks=model_metadata.expected_coordinates.chunking(
chunk_count_overrides=model_metadata.chunk_count_overrides,
),
)

@override
