
Commit

fix coord issues and add datastore example plotting cli
leifdenby committed Sep 12, 2024
1 parent d04d15e commit 743d7a1
Showing 8 changed files with 515 additions and 92 deletions.
9 changes: 1 addition & 8 deletions neural_lam/create_graph.py
@@ -13,9 +13,8 @@
from torch_geometric.utils.convert import from_networkx

# Local
from .datastore import DATASTORES
from .datastore.base import BaseCartesianDatastore
from .datastore.mllam import MLLAMDatastore
from .datastore.npyfiles import NpyFilesDatastore


def plot_graph(graph, title=None):
@@ -532,12 +531,6 @@ def create_graph(
save_edges(pyg_m2g, "m2g", graph_dir_path)


DATASTORES = dict(
mllam=MLLAMDatastore,
npyfiles=NpyFilesDatastore,
)


def create_graph_from_datastore(
datastore: BaseCartesianDatastore,
output_root_path: str,
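With the registry consolidated into `neural_lam.datastore`, graph creation can resolve a datastore class by name. A minimal sketch of that flow, assuming the `MLLAMDatastore` constructor shown later in this commit; the config path and output directory are placeholders, and any further parameters of `create_graph_from_datastore` are collapsed in the view above and therefore omitted here:

```python
# Hedged sketch only: the config path is a placeholder, and any additional
# arguments of create_graph_from_datastore (collapsed above) are omitted.
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.datastore import DATASTORES

DatastoreClass = DATASTORES["mllam"]  # or DATASTORES["npyfiles"]
datastore = DatastoreClass(config_path="example.danra.yaml")  # hypothetical path

create_graph_from_datastore(
    datastore=datastore,
    output_root_path="graphs/example",  # hypothetical output directory
)
```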
42 changes: 23 additions & 19 deletions neural_lam/datastore/base.py
@@ -9,6 +9,7 @@
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
from pandas.core.indexes.multi import MultiIndex


class BaseDatastore(abc.ABC):
@@ -228,21 +229,13 @@ class CartesianGridShape:


class BaseCartesianDatastore(BaseDatastore):
"""Base class for weather
data stored on a Cartesian
grid. In addition to the
methods and attributes
required for weather data
in general (see
`BaseDatastore`) for
Cartesian gridded source
data each `grid_index`
coordinate value is assume
to have an associated `x`
and `y`-value so that the
processed data-arrays can
be reshaped back into into
2D xy-gridded arrays.
"""
Base class for weather data stored on a Cartesian grid. In addition to the
methods and attributes required for weather data in general (see
`BaseDatastore`), for Cartesian gridded source data each `grid_index`
coordinate value is assumed to have an associated `x` and `y`-value so that
the processed data-arrays can be reshaped back into 2D xy-gridded arrays.
In addition, the following attributes and methods are required:
- `coords_projection` (property): Projection object for the coordinates.
@@ -253,7 +246,7 @@ class BaseCartesianDatastore(BaseDatastore):
"""

CARTESIAN_COORDS = ["y", "x"]
CARTESIAN_COORDS = ["x", "y"]

@property
@abc.abstractmethod
@@ -347,9 +340,20 @@ def unstack_grid_coords(
The dataarray or dataset with the grid coordinates unstacked.
"""
return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack(
"grid_index"
)
# check whether `grid_index` is a multi-index
if not isinstance(da_or_ds.indexes.get("grid_index"), MultiIndex):
da_or_ds = da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS)

da_or_ds_unstacked = da_or_ds.unstack("grid_index")

# ensure that the x, y dimensions are in the correct order
dims = da_or_ds_unstacked.dims
xy_dim_order = [d for d in dims if d in self.CARTESIAN_COORDS]

if xy_dim_order != self.CARTESIAN_COORDS:
da_or_ds_unstacked = da_or_ds_unstacked.transpose("y", "x")

return da_or_ds_unstacked

def stack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
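A standalone sketch (not the class method itself) of the logic the fixed `unstack_grid_coords` now applies: the `grid_index` multi-index is only rebuilt when it is missing, and the spatial dims are transposed to `("y", "x")` whenever the unstacked order differs from `CARTESIAN_COORDS`. The toy 2x3 grid below is invented for illustration:

```python
import numpy as np
import pandas as pd
import xarray as xr

CARTESIAN_COORDS = ["x", "y"]


def unstack_grid_coords(da):
    # only (re)build the multi-index if `grid_index` does not already have one
    if not isinstance(da.indexes.get("grid_index"), pd.MultiIndex):
        da = da.set_index(grid_index=CARTESIAN_COORDS)
    da_unstacked = da.unstack("grid_index")
    # if the unstacked x/y order differs from CARTESIAN_COORDS, force ("y", "x")
    xy_dims = [d for d in da_unstacked.dims if d in CARTESIAN_COORDS]
    if xy_dims != CARTESIAN_COORDS:
        da_unstacked = da_unstacked.transpose("y", "x")
    return da_unstacked


da_2d = xr.DataArray(
    np.arange(6).reshape(2, 3),
    dims=("y", "x"),
    coords={"y": np.arange(2), "x": np.arange(3)},
)
da_stacked = da_2d.stack(grid_index=("x", "y"))  # already has a multi-index
da_flat = da_stacked.reset_index("grid_index")   # only plain x/y coords left

# both inputs now unstack without raising, yielding the same 2D array
print(unstack_grid_coords(da_stacked).dims)
print(unstack_grid_coords(da_flat).dims)
```

Presumably the added `isinstance` check exists because calling `set_index` on data that already carries the `grid_index` multi-index would fail, whereas data coming straight off disk only has plain `x`/`y` coordinates.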
50 changes: 44 additions & 6 deletions neural_lam/datastore/mllam.py
@@ -70,6 +70,19 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
if len(self.get_vars_names(category)) > 0:
print(f"{category}: {' '.join(self.get_vars_names(category))}")

# find out the dimension order for the stacking to grid-index
dim_order = None
for input_dataset in self._config.inputs.values():
dim_order_ = input_dataset.dim_mapping["grid_index"].dims
if dim_order is None:
dim_order = dim_order_
else:
assert (
dim_order == dim_order_
), "all inputs must have the same dimension order"

self.CARTESIAN_COORDS = dim_order

@property
def root_path(self) -> Path:
"""The root path of the dataset.
@@ -202,6 +215,14 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:

da_category = self._ds[category]

# set units on x y coordinates if missing
for coord in ["x", "y"]:
if "units" not in da_category[coord].attrs:
da_category[coord].attrs["units"] = "m"

# set multi-index for grid-index
da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS)

if "time" not in da_category.dims:
return da_category
else:
@@ -294,10 +315,26 @@ def coords_projection(self) -> ccrs.Projection:
The projection of the coordinates.
"""
# TODO: danra doesn't contain projection information yet, but the next
# version will for now we hardcode the projection
# XXX: this is wrong
return ccrs.PlateCarree()
# XXX: this should move to config
kwargs = {
"LoVInDegrees": 25.0,
"LaDInDegrees": 56.7,
"Latin1InDegrees": 56.7,
"Latin2InDegrees": 56.7,
}

lon_0 = kwargs["LoVInDegrees"]  # Longitude of the projection origin
lat_0 = kwargs["LaDInDegrees"]  # Latitude of the projection origin
lat_1 = kwargs["Latin1InDegrees"]  # First standard parallel
lat_2 = kwargs["Latin2InDegrees"]  # Second standard parallel

crs = ccrs.LambertConformal(
central_longitude=lon_0,
central_latitude=lat_0,
standard_parallels=(lat_1, lat_2),
)

return crs

@property
def grid_shape_state(self):
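The hard-coded parameters above (flagged with an XXX to move into the config) describe a Lambert conformal conic projection. A quick, hedged sanity check of what those numbers imply, using invented sample points; the projection origin should map to roughly (0, 0) in grid metres:

```python
import cartopy.crs as ccrs
import numpy as np

# same DANRA-like parameters as hard-coded above; treat them as placeholders
# until they are read from the dataset configuration
crs = ccrs.LambertConformal(
    central_longitude=25.0,
    central_latitude=56.7,
    standard_parallels=(56.7, 56.7),
)

lons = np.array([25.0, 10.0])  # invented sample points
lats = np.array([56.7, 55.0])
xyz = crs.transform_points(ccrs.PlateCarree(), lons, lats)
print(xyz[:, :2])  # projected x/y in metres; the first point lands near (0, 0)
```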
@@ -346,10 +383,11 @@ def get_xy(self, category: str, stacked: bool) -> ndarray:
da_xy = xr.concat([da_x, da_y], dim="grid_coord")

if stacked:
da_xy = da_xy.stack(grid_index=("y", "x")).transpose(
da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose(
"grid_coord", "grid_index"
)
else:
da_xy = da_xy.transpose("grid_coord", "y", "x")
dims = ["grid_coord", "y", "x"]
da_xy = da_xy.transpose(*dims)

return da_xy.values
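The order passed to `stack` now comes from `CARTESIAN_COORDS` instead of being hard-coded to `("y", "x")`. The stacking order determines how the flattened `grid_index` enumerates the grid, so it has to match the rest of the pipeline; a toy illustration with an invented 2x3 grid:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(6).reshape(2, 3),
    dims=("y", "x"),
    coords={"y": [0, 1], "x": [0, 1, 2]},
)

# stacking with ("x", "y") makes x the outer (slowest-varying) level,
# while ("y", "x") makes y the outer level
print(da.stack(grid_index=("x", "y")).values)  # [0 3 1 4 2 5]
print(da.stack(grid_index=("y", "x")).values)  # [0 1 2 3 4 5]
```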
48 changes: 27 additions & 21 deletions neural_lam/datastore/npyfiles/store.py
@@ -347,8 +347,7 @@ def _get_single_timeseries_dataarray(
"Member can only be specified for the 'state' category"
)

# XXX: we here assume that the grid shape is the same for all categories
grid_shape = self.grid_shape_state
concat_axis = 0

file_params = {}
add_feature_dim = False
@@ -387,7 +386,8 @@
fp_samples = self.root_path / "static"
elif features == ["x", "y"]:
filename_format = "nwp_xy.npy"
file_dims = ["y", "x", "feature"]
# NB: for x, y the feature dimension is the first one
file_dims = ["feature", "y", "x"]
features_vary_with_analysis_time = False
# XXX: x, y are the same for all splits, and so saved in static/
fp_samples = self.root_path / "static"
@@ -403,6 +403,12 @@

coords = {}
arr_shape = []

xs, ys = self.get_xy(category="state", stacked=False)
assert np.all(xs[0, :] == xs[-1, :])
assert np.all(ys[:, 0] == ys[:, -1])
x = xs[0, :]
y = ys[:, 0]
for d in dims:
if d == "elapsed_forecast_duration":
coord_values = (
@@ -413,9 +419,9 @@
elif d == "analysis_time":
coord_values = self._get_analysis_times(split=split)
elif d == "y":
coord_values = np.arange(grid_shape.y)
coord_values = y
elif d == "x":
coord_values = np.arange(grid_shape.x)
coord_values = x
elif d == "feature":
coord_values = features
else:
@@ -450,7 +456,7 @@
]

if features_vary_with_analysis_time:
arr_all = dask.array.stack(arrays, axis=0)
arr_all = dask.array.stack(arrays, axis=concat_axis)
else:
arr_all = arrays[0]
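For context, the per-analysis-time arrays stacked here appear to be dask-backed (the stack uses `dask.array.stack`), so stacking along `concat_axis` can stay lazy. A small sketch of that pattern with a fake loader in place of `np.load`; the shapes and names are invented:

```python
import dask
import dask.array
import numpy as np

shape = (4, 10, 3)  # e.g. (elapsed_forecast_duration, grid_index, feature)


@dask.delayed
def load_one(i):
    # stand-in for np.load() of one forecast file
    return np.full(shape, fill_value=float(i))


arrays = [
    dask.array.from_delayed(load_one(i), shape=shape, dtype=float)
    for i in range(5)  # five hypothetical analysis times
]
arr_all = dask.array.stack(arrays, axis=0)  # lazily stacked, like concat_axis=0
print(arr_all.shape)  # (5, 4, 10, 3); nothing is loaded until .compute()
```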

@@ -568,17 +574,17 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
Returns
-------
np.ndarray
The x, y coordinates of the dataset, returned differently based on
the value of `stacked`:
The x, y coordinates of the dataset (with x first then y second),
returned differently based on the value of `stacked`:
- `stacked==True`: shape `(2, n_grid_points)` where
n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
"""

# the array on disk has shape [2, N_x, N_y], but we want to return it
# as [2, N_y, N_x] so we swap the axes
arr = np.load(self.root_path / "static" / "nwp_xy.npy").swapaxes(1, 2)
# the array on disk has shape [2, N_y, N_x], with the first dimension
# being [x, y]
arr = np.load(self.root_path / "static" / "nwp_xy.npy")

assert arr.shape[0] == 2, "Expected 2D array"
grid_shape = self.grid_shape_state
@@ -611,7 +617,7 @@ def grid_shape_state(self) -> CartesianGridShape:
The shape of the cartesian grid for the state variables.
"""
nx, ny = self.config.grid_shape_state
ny, nx = self.config.grid_shape_state
return CartesianGridShape(x=nx, y=ny)

@property
@@ -626,10 +632,10 @@ def boundary_mask(self) -> xr.DataArray:
"""
xs, ys = self.get_xy(category="state", stacked=False)
assert np.all(xs[:, 0] == xs[:, -1])
assert np.all(ys[0, :] == ys[-1, :])
x = xs[:, 0]
y = ys[0, :]
assert np.all(xs[0, :] == xs[-1, :])
assert np.all(ys[:, 0] == ys[:, -1])
x = xs[0, :]
y = ys[:, 0]
values = np.load(self.root_path / "static" / "border_mask.npy")
da_mask = xr.DataArray(
values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
@@ -677,11 +683,11 @@ def load_pickled_tensor(fn):
std_values = np.array([flux_std, 1.0, 1.0, 1.0, 1.0, 1.0])

elif category == "static":
ds_static = self.get_dataarray(category="static", split="train")
ds_static_mean = ds_static.mean(dim=["grid_index"])
ds_static_std = ds_static.std(dim=["grid_index"])
mean_values = ds_static_mean["static_feature"].values
std_values = ds_static_std["static_feature"].values
da_static = self.get_dataarray(category="static", split="train")
da_static_mean = da_static.mean(dim=["grid_index"]).compute()
da_static_std = da_static.std(dim=["grid_index"]).compute()
mean_values = da_static_mean.values
std_values = da_static_std.values
else:
raise NotImplementedError(f"Category {category} not supported")

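The x/y handling above relies on `nwp_xy.npy` having shape `[2, N_y, N_x]` with the leading axis ordered `(x, y)`, and on the grid being regular so that a single row/column recovers the 1D coordinate vectors. A small sketch of that convention with `np.meshgrid`; the coordinate values are invented:

```python
import numpy as np

x = np.linspace(0.0, 4000.0, 5)   # N_x = 5
y = np.linspace(0.0, 2000.0, 3)   # N_y = 3
xs, ys = np.meshgrid(x, y)        # each of shape (N_y, N_x)

arr = np.stack([xs, ys], axis=0)  # (2, N_y, N_x), the assumed on-disk layout

# the same checks/slices as used in the diff to recover 1D coordinates
assert np.all(arr[0][0, :] == arr[0][-1, :])  # x constant along y
assert np.all(arr[1][:, 0] == arr[1][:, -1])  # y constant along x
x_1d = arr[0][0, :]  # length N_x
y_1d = arr[1][:, 0]  # length N_y

print(arr.shape, x_1d.shape, y_1d.shape)  # (2, 3, 5) (5,) (3,)
```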